From 25ed58328aa418e88cdc78b6961c9b0b015557e3 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Thu, 25 Sep 2025 16:29:14 +0200 Subject: [PATCH 001/120] [management] fix network map dns filter (#4547) --- management/server/dns.go | 30 ++---------------------------- management/server/dns_test.go | 9 --------- management/server/types/account.go | 1 - 3 files changed, 2 insertions(+), 38 deletions(-) diff --git a/management/server/dns.go b/management/server/dns.go index f6f0201d3..6b73dbd0e 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -20,29 +20,9 @@ import ( // DNSConfigCache is a thread-safe cache for DNS configuration components type DNSConfigCache struct { - CustomZones sync.Map NameServerGroups sync.Map } -// GetCustomZone retrieves a cached custom zone -func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) { - if c == nil { - return nil, false - } - if value, ok := c.CustomZones.Load(key); ok { - return value.(*proto.CustomZone), true - } - return nil, false -} - -// SetCustomZone stores a custom zone in the cache -func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) { - if c == nil { - return - } - c.CustomZones.Store(key, value) -} - // GetNameServerGroup retrieves a cached name server group func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { if c == nil { @@ -212,14 +192,8 @@ func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSC } for _, zone := range update.CustomZones { - cacheKey := zone.Domain - if cachedZone, exists := cache.GetCustomZone(cacheKey); exists { - protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone) - } else { - protoZone := convertToProtoCustomZone(zone) - cache.SetCustomZone(cacheKey, protoZone) - protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) - } + protoZone := convertToProtoCustomZone(zone) + protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) } for _, nsGroup := range update.NameServerGroups { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index d58689544..55a1bbe66 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -474,15 +474,6 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { t.Errorf("Results should be different for different inputs") } - // Verify that the cache contains elements from both configs - if _, exists := cache.GetCustomZone("example.com"); !exists { - t.Errorf("Cache should contain custom zone for example.com") - } - - if _, exists := cache.GetCustomZone("example.org"); !exists { - t.Errorf("Cache should contain custom zone for example.org") - } - if _, exists := cache.GetNameServerGroup("group1"); !exists { t.Errorf("Cache should contain name server group 'group1'") } diff --git a/management/server/types/account.go b/management/server/types/account.go index ca075b9f6..a69d3bb08 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -300,7 +300,6 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone - if peersCustomZone.Domain != "" { records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) zones = append(zones, nbdns.CustomZone{ From 17bab881f78bc49efe68354a1e16e8870c05b530 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Fri, 26 Sep 2025 16:42:18 +0700 Subject: [PATCH 002/120] [client] Add Windows DNS Policies To GPO Path Always (#4460) [client] Add Windows DNS Policies To GPO Path Always (#4460) --- client/internal/dns/host_windows.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index fdc2c3063..0d3f033fb 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -240,15 +240,17 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 for i, domain := range domains { - policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) - if r.gpo { - policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i) - } singleDomain := []string{domain} - if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil { - return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err) + if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, singleDomain, ip); err != nil { + return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) + } + + if r.gpo { + if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, singleDomain, ip); err != nil { + return i, fmt.Errorf("configure gpo DNS policy: %w", err) + } } log.Debugf("added NRPT entry for domain: %s", domain) @@ -401,6 +403,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error { if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err)) } + if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err)) } @@ -412,6 +415,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error { if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err)) } + if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err)) } From e8d301fdc9357b9ca9ecf9ce40d4d347255d3bf9 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 30 Sep 2025 15:31:18 +0200 Subject: [PATCH 003/120] [client] Fix/pkg loss (#3338) The Relayed connection setup is optimistic. It does not have any confirmation of an established end-to-end connection. Peers start sending WireGuard handshake packets immediately after the successful offer-answer handshake. Meanwhile, for successful P2P connection negotiation, we change the WireGuard endpoint address, but this change does not trigger new handshake initiation. Because the peer switched from Relayed connection to P2P, the packets from the Relay server are dropped and must wait for the next WireGuard handshake via P2P. To avoid this scenario, the relayed WireGuard proxy no longer drops the packets. Instead, it rewrites the source address to the new P2P endpoint and continues forwarding the packets. We still have one corner case: if the Relayed server negotiation chooses a server that has not been used before. In this case, one side of the peer connection will be slower to reach the Relay server, and the Relay server will drop the handshake packet. If everything goes well we should see exactly 5 seconds improvements between the WireGuard configuration time and the handshake time. --- client/iface/bind/endpoint.go | 14 +- client/iface/bind/ice_bind.go | 15 +- client/iface/iface_new_freebsd.go | 41 +++++ .../{iface_new_unix.go => iface_new_linux.go} | 2 +- client/iface/wgproxy/bind/proxy.go | 107 ++++++++---- client/iface/wgproxy/ebpf/proxy.go | 59 ++----- client/iface/wgproxy/ebpf/wrapper.go | 79 +++++---- client/iface/wgproxy/factory_kernel.go | 1 - .../iface/wgproxy/factory_kernel_freebsd.go | 31 ---- client/iface/wgproxy/proxy.go | 5 + client/iface/wgproxy/proxy_linux_test.go | 104 +++++++----- client/iface/wgproxy/proxy_seed_test.go | 39 +++++ client/iface/wgproxy/proxy_test.go | 152 ++++++++++++++---- client/iface/wgproxy/rawsocket/rawsocket.go | 50 ++++++ client/iface/wgproxy/udp/proxy.go | 94 ++++++++--- client/iface/wgproxy/udp/rawsocket.go | 101 ++++++++++++ client/internal/peer/conn.go | 62 ++++--- client/internal/peer/endpoint.go | 105 ++++++++++++ 18 files changed, 784 insertions(+), 277 deletions(-) create mode 100644 client/iface/iface_new_freebsd.go rename client/iface/{iface_new_unix.go => iface_new_linux.go} (97%) delete mode 100644 client/iface/wgproxy/factory_kernel_freebsd.go create mode 100644 client/iface/wgproxy/proxy_seed_test.go create mode 100644 client/iface/wgproxy/rawsocket/rawsocket.go create mode 100644 client/iface/wgproxy/udp/rawsocket.go create mode 100644 client/internal/peer/endpoint.go diff --git a/client/iface/bind/endpoint.go b/client/iface/bind/endpoint.go index 1926ff88f..caa92f05d 100644 --- a/client/iface/bind/endpoint.go +++ b/client/iface/bind/endpoint.go @@ -1,5 +1,17 @@ package bind -import wgConn "golang.zx2c4.com/wireguard/conn" +import ( + "net" + + wgConn "golang.zx2c4.com/wireguard/conn" +) type Endpoint = wgConn.StdNetEndpoint + +func EndpointToUDPAddr(e Endpoint) *net.UDPAddr { + return &net.UDPAddr{ + IP: e.Addr().AsSlice(), + Port: int(e.Port()), + Zone: e.Addr().Zone(), + } +} diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 577c7c0c4..ef630b9d0 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -1,6 +1,7 @@ package bind import ( + "context" "encoding/binary" "fmt" "net" @@ -42,7 +43,7 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD // use the port because in the Send function the wgConn.Endpoint the port info is not exported. type ICEBind struct { *wgConn.StdNetBind - RecvChan chan RecvMessage + recvChan chan RecvMessage transportNet transport.Net filterFn udpmux.FilterFn @@ -65,7 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, - RecvChan: make(chan RecvMessage, 1), + recvChan: make(chan RecvMessage, 1), transportNet: transportNet, filterFn: filterFn, endpoints: make(map[netip.Addr]net.Conn), @@ -155,6 +156,14 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { return nil } +func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) { + select { + case <-ctx.Done(): + return + case b.recvChan <- msg: + } +} + func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() @@ -271,7 +280,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo select { case <-c.closedChan: return 0, net.ErrClosed - case msg, ok := <-c.RecvChan: + case msg, ok := <-c.recvChan: if !ok { return 0, net.ErrClosed } diff --git a/client/iface/iface_new_freebsd.go b/client/iface/iface_new_freebsd.go new file mode 100644 index 000000000..86ed14ce1 --- /dev/null +++ b/client/iface/iface_new_freebsd.go @@ -0,0 +1,41 @@ +//go:build freebsd + +package iface + +import ( + "fmt" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + wgIFace := &WGIface{} + + if netstack.IsEnabled() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + return wgIFace, nil + } + + if device.ModuleTunIsLoaded() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + return wgIFace, nil + } + + return nil, fmt.Errorf("couldn't check or load tun module") +} diff --git a/client/iface/iface_new_unix.go b/client/iface/iface_new_linux.go similarity index 97% rename from client/iface/iface_new_unix.go rename to client/iface/iface_new_linux.go index 493144f13..77fd30fae 100644 --- a/client/iface/iface_new_unix.go +++ b/client/iface/iface_new_linux.go @@ -1,4 +1,4 @@ -//go:build (linux && !android) || freebsd +//go:build linux && !android package iface diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index bf6da72c2..dbc694e91 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -16,28 +16,37 @@ import ( "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) +type IceBind interface { + SetEndpoint(fakeIP netip.Addr, conn net.Conn) + RemoveEndpoint(fakeIP netip.Addr) + Recv(ctx context.Context, msg bind.RecvMessage) + MTU() uint16 +} + type ProxyBind struct { - Bind *bind.ICEBind + bind IceBind - fakeNetIP *netip.AddrPort - wgBindEndpoint *bind.Endpoint - remoteConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + // wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address + wgRelayedEndpoint *bind.Endpoint + wgCurrentUsed *bind.Endpoint + remoteConn net.Conn + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } -func NewProxyBind(bind *bind.ICEBind) *ProxyBind { +func NewProxyBind(bind IceBind) *ProxyBind { p := &ProxyBind{ - Bind: bind, + bind: bind, closeListener: listener.NewCloseListener(), + pausedCond: sync.NewCond(&sync.Mutex{}), } return p @@ -46,25 +55,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind { // AddTurnConn adds a new connection to the bind. // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // WireGuard configuration. +// +// Parameters: +// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages +// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address +// - remoteConn: The established TURN connection to the remote peer func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { fakeNetIP, err := fakeAddress(nbAddr) if err != nil { return err } - - p.fakeNetIP = fakeNetIP - p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} + p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) return nil } + func (p *ProxyBind) EndpointAddr() *net.UDPAddr { - return &net.UDPAddr{ - IP: p.fakeNetIP.Addr().AsSlice(), - Port: int(p.fakeNetIP.Port()), - Zone: p.fakeNetIP.Addr().Zone(), - } + return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint) } func (p *ProxyBind) SetDisconnectListener(disconnected func()) { @@ -76,17 +85,21 @@ func (p *ProxyBind) Work() { return } - p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn) + p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn) - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + + p.wgCurrentUsed = p.wgRelayedEndpoint // Start the proxy only once if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } func (p *ProxyBind) Pause() { @@ -94,9 +107,25 @@ func (p *ProxyBind) Pause() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + p.paused = false + + p.wgCurrentUsed = addrToEndpoint(endpoint) + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() +} + +func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { + ip, _ := netip.AddrFromSlice(addr.IP.To4()) + addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) + return &bind.Endpoint{AddrPort: addrPort} } func (p *ProxyBind) CloseConn() error { @@ -107,6 +136,10 @@ func (p *ProxyBind) CloseConn() error { } func (p *ProxyBind) close() error { + if p.remoteConn == nil { + return nil + } + p.closeMu.Lock() defer p.closeMu.Unlock() @@ -120,7 +153,12 @@ func (p *ProxyBind) close() error { p.cancel() - p.Bind.RemoveEndpoint(p.fakeNetIP.Addr()) + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + + p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr()) if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { return rErr @@ -136,7 +174,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { }() for { - buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead) + buf := make([]byte, p.bind.MTU()+bufsize.WGBufferOverhead) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { @@ -147,18 +185,17 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } msg := bind.RecvMessage{ - Endpoint: p.wgBindEndpoint, + Endpoint: p.wgCurrentUsed, Buffer: buf[:n], } - p.Bind.RecvChan <- msg - p.pausedMu.Unlock() + p.bind.Recv(ctx, msg) + p.pausedCond.L.Unlock() } } diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index b899f1694..858143091 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -6,9 +6,7 @@ import ( "context" "fmt" "net" - "os" "sync" - "syscall" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -18,6 +16,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/bufsize" + "github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" nbnet "github.com/netbirdio/netbird/client/net" @@ -27,6 +26,10 @@ const ( loopbackAddr = "127.0.0.1" ) +var ( + localHostNetIP = net.ParseIP("127.0.0.1") +) + // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { localWGListenPort int @@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error { return err } - p.rawConn, err = p.prepareSenderRawSocket() + p.rawConn, err = rawsocket.PrepareSenderRawSocket() if err != nil { return err } @@ -214,57 +217,17 @@ generatePort: return p.lastUsedPort, nil } -func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { - // Create a raw socket. - fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) - if err != nil { - return nil, fmt.Errorf("creating raw socket failed: %w", err) - } - - // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. - err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) - if err != nil { - return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) - } - - // Bind the socket to the "lo" interface. - err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") - if err != nil { - return nil, fmt.Errorf("binding to lo interface failed: %w", err) - } - - // Set the fwmark on the socket. - err = nbnet.SetSocketOpt(fd) - if err != nil { - return nil, fmt.Errorf("setting fwmark failed: %w", err) - } - - // Convert the file descriptor to a PacketConn. - file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) - if file == nil { - return nil, fmt.Errorf("converting fd to file failed") - } - packetConn, err := net.FilePacketConn(file) - if err != nil { - return nil, fmt.Errorf("converting file to packet conn failed: %w", err) - } - - return packetConn, nil -} - -func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { - localhost := net.ParseIP("127.0.0.1") - +func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { payload := gopacket.Payload(data) ipH := &layers.IPv4{ - DstIP: localhost, - SrcIP: localhost, + DstIP: localHostNetIP, + SrcIP: endpointAddr.IP, Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, } udpH := &layers.UDP{ - SrcPort: layers.UDPPort(port), + SrcPort: layers.UDPPort(endpointAddr.Port), DstPort: layers.UDPPort(p.localWGListenPort), } @@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { if err != nil { return fmt.Errorf("serialize layers: %w", err) } - if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { + if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil { return fmt.Errorf("write to raw conn: %w", err) } return nil diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index 3d71b01bd..ff44d30c0 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -18,41 +18,42 @@ import ( // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call type ProxyWrapper struct { - WgeBPFProxy *WGEBPFProxy + wgeBPFProxy *WGEBPFProxy remoteConn net.Conn ctx context.Context cancel context.CancelFunc - wgEndpointAddr *net.UDPAddr + wgRelayedEndpointAddr *net.UDPAddr + wgEndpointCurrentUsedAddr *net.UDPAddr - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } -func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper { +func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { return &ProxyWrapper{ - WgeBPFProxy: WgeBPFProxy, + wgeBPFProxy: proxy, + pausedCond: sync.NewCond(&sync.Mutex{}), closeListener: listener.NewCloseListener(), } } - func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { - addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) + addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) } p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) - p.wgEndpointAddr = addr + p.wgRelayedEndpointAddr = addr return err } func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { - return p.wgEndpointAddr + return p.wgRelayedEndpointAddr } func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) { @@ -64,14 +65,18 @@ func (p *ProxyWrapper) Work() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + + p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } func (p *ProxyWrapper) Pause() { @@ -80,45 +85,59 @@ func (p *ProxyWrapper) Pause() { } log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + p.paused = false + + p.wgEndpointCurrentUsedAddr = endpoint + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map -func (e *ProxyWrapper) CloseConn() error { - if e.cancel == nil { +func (p *ProxyWrapper) CloseConn() error { + if p.cancel == nil { return fmt.Errorf("proxy not started") } - e.cancel() + p.cancel() - e.closeListener.SetCloseListener(nil) + p.closeListener.SetCloseListener(nil) - if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - return fmt.Errorf("close remote conn: %w", err) + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + return fmt.Errorf("failed to close remote conn: %w", err) } return nil } func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { - defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) + defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port)) - buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead) + buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead) for { n, err := p.readFromRemote(ctx, buf) if err != nil { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } - err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) - p.pausedMu.Unlock() + err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr) + p.pausedCond.L.Unlock() if err != nil { if ctx.Err() != nil { @@ -137,7 +156,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err } p.closeListener.Notify() if !errors.Is(err, io.EOF) { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err) } return 0, err } diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index 63bc2ed24..ad2807546 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy { } return ebpf.NewProxyWrapper(w.ebpfProxy) - } func (w *KernelFactory) Free() error { diff --git a/client/iface/wgproxy/factory_kernel_freebsd.go b/client/iface/wgproxy/factory_kernel_freebsd.go deleted file mode 100644 index 039f1cd3a..000000000 --- a/client/iface/wgproxy/factory_kernel_freebsd.go +++ /dev/null @@ -1,31 +0,0 @@ -package wgproxy - -import ( - log "github.com/sirupsen/logrus" - - udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" -) - -// KernelFactory todo: check eBPF support on FreeBSD -type KernelFactory struct { - wgPort int - mtu uint16 -} - -func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory { - log.Infof("WireGuard Proxy Factory will produce UDP proxy") - f := &KernelFactory{ - wgPort: wgPort, - mtu: mtu, - } - - return f -} - -func (w *KernelFactory) GetProxy() Proxy { - return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu) -} - -func (w *KernelFactory) Free() error { - return nil -} diff --git a/client/iface/wgproxy/proxy.go b/client/iface/wgproxy/proxy.go index c2879877e..3c8dfd30e 100644 --- a/client/iface/wgproxy/proxy.go +++ b/client/iface/wgproxy/proxy.go @@ -11,6 +11,11 @@ type Proxy interface { EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint Work() // Work start or resume the proxy Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. + + //RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused + //and rewrite the src address to the endpoint address. + //With this logic can avoid the package loss from relayed connections. + RedirectAs(endpoint *net.UDPAddr) CloseConn() error SetDisconnectListener(disconnected func()) } diff --git a/client/iface/wgproxy/proxy_linux_test.go b/client/iface/wgproxy/proxy_linux_test.go index 5add503e1..9526e91d2 100644 --- a/client/iface/wgproxy/proxy_linux_test.go +++ b/client/iface/wgproxy/proxy_linux_test.go @@ -3,54 +3,82 @@ package wgproxy import ( - "context" - "os" - "testing" + "fmt" + "net" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/wgaddr" + bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/iface/wgproxy/udp" ) -func TestProxyCloseByRemoteConnEBPF(t *testing.T) { - if os.Getenv("GITHUB_ACTIONS") != "true" { - t.Skip("Skipping test as it requires root privileges") - } - ctx := context.Background() +func seedProxies() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) if err := ebpfProxy.Listen(); err != nil { - t.Fatalf("failed to initialize ebpf proxy: %s", err) + return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) } - defer func() { - if err := ebpfProxy.Free(); err != nil { - t.Errorf("failed to free ebpf proxy: %s", err) - } - }() - - tests := []struct { - name string - proxy Proxy - }{ - { - name: "ebpf proxy", - proxy: &ebpf.ProxyWrapper{ - WgeBPFProxy: ebpfProxy, - }, - }, + pEbpf := proxyInstance{ + name: "ebpf kernel proxy", + proxy: ebpf.NewProxyWrapper(ebpfProxy), + wgPort: 51831, + closeFn: ebpfProxy.Free, } + pl = append(pl, pEbpf) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) - if err != nil { - t.Errorf("error: %v", err) - } - - _ = relayedConn.Close() - if err := tt.proxy.CloseConn(); err != nil { - t.Errorf("error: %v", err) - } - }) + pUDP := proxyInstance{ + name: "udp kernel proxy", + proxy: udp.NewWGUDPProxy(51832, 1280), + wgPort: 51832, + closeFn: func() error { return nil }, } + pl = append(pl, pUDP) + return pl, nil +} + +func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) + + ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) + if err := ebpfProxy.Listen(); err != nil { + return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) + } + + pEbpf := proxyInstance{ + name: "ebpf kernel proxy", + proxy: ebpf.NewProxyWrapper(ebpfProxy), + wgPort: 51831, + closeFn: ebpfProxy.Free, + } + pl = append(pl, pEbpf) + + pUDP := proxyInstance{ + name: "udp kernel proxy", + proxy: udp.NewWGUDPProxy(51832, 1280), + wgPort: 51832, + closeFn: func() error { return nil }, + } + pl = append(pl, pUDP) + wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32") + if err != nil { + return nil, err + } + iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280) + endpointAddress := &net.UDPAddr{ + IP: net.IPv4(10, 0, 0, 1), + Port: 1234, + } + + pBind := proxyInstance{ + name: "bind proxy", + proxy: bindproxy.NewProxyBind(iceBind), + endpointAddr: endpointAddress, + closeFn: func() error { return nil }, + } + pl = append(pl, pBind) + + return pl, nil } diff --git a/client/iface/wgproxy/proxy_seed_test.go b/client/iface/wgproxy/proxy_seed_test.go new file mode 100644 index 000000000..4d244f18a --- /dev/null +++ b/client/iface/wgproxy/proxy_seed_test.go @@ -0,0 +1,39 @@ +//go:build !linux + +package wgproxy + +import ( + "net" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/wgaddr" + bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" +) + +func seedProxies() ([]proxyInstance, error) { + // todo extend with Bind proxy + pl := make([]proxyInstance, 0) + return pl, nil +} + +func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) + wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32") + if err != nil { + return nil, err + } + iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280) + endpointAddress := &net.UDPAddr{ + IP: net.IPv4(10, 0, 0, 1), + Port: 1234, + } + + pBind := proxyInstance{ + name: "bind proxy", + proxy: bindproxy.NewProxyBind(iceBind), + endpointAddr: endpointAddress, + closeFn: func() error { return nil }, + } + pl = append(pl, pBind) + return pl, nil +} diff --git a/client/iface/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go index 76e5ed6f7..1aeab66b7 100644 --- a/client/iface/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -1,5 +1,3 @@ -//go:build linux - package wgproxy import ( @@ -7,12 +5,9 @@ import ( "io" "net" "os" - "runtime" "testing" "time" - "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" - udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" "github.com/netbirdio/netbird/util" ) @@ -22,6 +17,14 @@ func TestMain(m *testing.M) { os.Exit(code) } +type proxyInstance struct { + name string + proxy Proxy + wgPort int + endpointAddr *net.UDPAddr + closeFn func() error +} + type mocConn struct { closeChan chan struct{} closed bool @@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error { func TestProxyCloseByRemoteConn(t *testing.T) { ctx := context.Background() - tests := []struct { - name string - proxy Proxy - }{ - { - name: "userspace proxy", - proxy: udpProxy.NewWGUDPProxy(51830, 1280), - }, + tests, err := seedProxyForProxyCloseByRemoteConn() + if err != nil { + t.Fatalf("error: %v", err) } - if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { - ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) - if err := ebpfProxy.Listen(); err != nil { - t.Fatalf("failed to initialize ebpf proxy: %s", err) - } - defer func() { - if err := ebpfProxy.Free(); err != nil { - t.Errorf("failed to free ebpf proxy: %s", err) - } - }() - proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy) - - tests = append(tests, struct { - name string - proxy Proxy - }{ - name: "ebpf proxy", - proxy: proxyWrapper, - }) - } + relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") + defer func() { + _ = relayedConn.Close() + }() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892") relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) + err := tt.proxy.AddTurnConn(ctx, addr, relayedConn) if err != nil { t.Errorf("error: %v", err) } @@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) { }) } } + +// TestProxyRedirect todo extend the proxies with Bind proxy +func TestProxyRedirect(t *testing.T) { + tests, err := seedProxies() + if err != nil { + t.Fatalf("error: %v", err) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr) + if err := tt.closeFn(); err != nil { + t.Errorf("error: %v", err) + } + }) + } +} + +func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) { + t.Helper() + + msgHelloFromRelay := []byte("hello from relay") + msgRedirected := [][]byte{ + []byte("hello 1. to p2p"), + []byte("hello 2. to p2p"), + []byte("hello 3. to p2p"), + } + + dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: wgPort}) + if err != nil { + t.Fatalf("failed to listen on udp port: %s", err) + } + + relayedServer, _ := net.ListenUDP("udp", + &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 1234, + }, + ) + + relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") + + defer func() { + _ = dummyWgListener.Close() + _ = relayedConn.Close() + _ = relayedServer.Close() + }() + + if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil { + t.Errorf("error: %v", err) + } + defer func() { + if err := proxy.CloseConn(); err != nil { + t.Errorf("error: %v", err) + } + }() + + proxy.Work() + + if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil { + t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err) + } + + n, err := dummyWgListener.Read(make([]byte, 1024)) + if err != nil { + t.Errorf("error: %v", err) + } + + if n != len(msgHelloFromRelay) { + t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n) + } + + p2pEndpointAddr := &net.UDPAddr{ + IP: net.IPv4(192, 168, 0, 56), + Port: 1234, + } + proxy.RedirectAs(p2pEndpointAddr) + + for _, msg := range msgRedirected { + if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil { + t.Errorf("error: %v", err) + } + } + + for i := 0; i < len(msgRedirected); i++ { + buf := make([]byte, 1024) + n, rAddr, err := dummyWgListener.ReadFrom(buf) + if err != nil { + t.Errorf("error: %v", err) + } + + if rAddr.String() != p2pEndpointAddr.String() { + t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String()) + } + if string(buf[:n]) != string(msgRedirected[i]) { + t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n])) + } + } +} diff --git a/client/iface/wgproxy/rawsocket/rawsocket.go b/client/iface/wgproxy/rawsocket/rawsocket.go new file mode 100644 index 000000000..a11ac46d5 --- /dev/null +++ b/client/iface/wgproxy/rawsocket/rawsocket.go @@ -0,0 +1,50 @@ +//go:build linux && !android + +package rawsocket + +import ( + "fmt" + "net" + "os" + "syscall" + + nbnet "github.com/netbirdio/netbird/client/net" +) + +func PrepareSenderRawSocket() (net.PacketConn, error) { + // Create a raw socket. + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) + if err != nil { + return nil, fmt.Errorf("creating raw socket failed: %w", err) + } + + // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. + err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) + if err != nil { + return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) + } + + // Bind the socket to the "lo" interface. + err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") + if err != nil { + return nil, fmt.Errorf("binding to lo interface failed: %w", err) + } + + // Set the fwmark on the socket. + err = nbnet.SetSocketOpt(fd) + if err != nil { + return nil, fmt.Errorf("setting fwmark failed: %w", err) + } + + // Convert the file descriptor to a PacketConn. + file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) + if file == nil { + return nil, fmt.Errorf("converting fd to file failed") + } + packetConn, err := net.FilePacketConn(file) + if err != nil { + return nil, fmt.Errorf("converting file to packet conn failed: %w", err) + } + + return packetConn, nil +} diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index be65e2b27..4ef2f19c4 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -1,3 +1,5 @@ +//go:build linux && !android + package udp import ( @@ -21,16 +23,18 @@ type WGUDPProxy struct { localWGListenPort int mtu uint16 - remoteConn net.Conn - localConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + remoteConn net.Conn + localConn net.Conn + srcFakerConn *SrcFaker + sendPkg func(data []byte) (int, error) + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } @@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy { p := &WGUDPProxy{ localWGListenPort: wgPort, mtu: mtu, + pausedCond: sync.NewCond(&sync.Mutex{}), closeListener: listener.NewCloseListener(), } return p @@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem p.ctx, p.cancel = context.WithCancel(ctx) p.localConn = localConn + p.sendPkg = p.localConn.Write p.remoteConn = remoteConn return err @@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + p.sendPkg = p.localConn.Write + + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + log.Errorf("failed to close src faker conn: %s", err) + } + p.srcFakerConn = nil + } if !p.isStarted { p.isStarted = true go p.proxyToRemote(p.ctx) go p.proxyToLocal(p.ctx) } + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } // Pause pauses the proxy from receiving data from the remote peer @@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +// RedirectAs start to use the fake sourced raw socket as package sender +func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + defer func() { + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + }() + + p.paused = false + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + log.Errorf("failed to close src faker conn: %s", err) + } + p.srcFakerConn = nil + } + srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint) + if err != nil { + log.Errorf("failed to create src faker conn: %s", err) + // fallback to continue without redirecting + p.paused = true + return + } + p.srcFakerConn = srcFakerConn + p.sendPkg = p.srcFakerConn.SendPkg } // CloseConn close the localConn @@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error { } func (p *WGUDPProxy) close() error { + var result *multierror.Error + p.closeMu.Lock() defer p.closeMu.Unlock() @@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error { p.cancel() - var result *multierror.Error + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) } @@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error { if err := p.localConn.Close(); err != nil { result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) } + + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err)) + } + } + return cerrors.FormatErrorOrNil(result) } @@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } - - _, err = p.localConn.Write(buf[:n]) - p.pausedMu.Unlock() + _, err = p.sendPkg(buf[:n]) + p.pausedCond.L.Unlock() if err != nil { if ctx.Err() != nil { diff --git a/client/iface/wgproxy/udp/rawsocket.go b/client/iface/wgproxy/udp/rawsocket.go new file mode 100644 index 000000000..fdc911463 --- /dev/null +++ b/client/iface/wgproxy/udp/rawsocket.go @@ -0,0 +1,101 @@ +//go:build linux && !android + +package udp + +import ( + "fmt" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket" +) + +var ( + serializeOpts = gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + localHostNetIPAddr = &net.IPAddr{ + IP: net.ParseIP("127.0.0.1"), + } +) + +type SrcFaker struct { + srcAddr *net.UDPAddr + + rawSocket net.PacketConn + ipH gopacket.SerializableLayer + udpH gopacket.SerializableLayer + layerBuffer gopacket.SerializeBuffer +} + +func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) { + rawSocket, err := rawsocket.PrepareSenderRawSocket() + if err != nil { + return nil, err + } + + ipH, udpH, err := prepareHeaders(dstPort, srcAddr) + if err != nil { + return nil, err + } + + f := &SrcFaker{ + srcAddr: srcAddr, + rawSocket: rawSocket, + ipH: ipH, + udpH: udpH, + layerBuffer: gopacket.NewSerializeBuffer(), + } + + return f, nil +} + +func (f *SrcFaker) Close() error { + return f.rawSocket.Close() +} + +func (f *SrcFaker) SendPkg(data []byte) (int, error) { + defer func() { + if err := f.layerBuffer.Clear(); err != nil { + log.Errorf("failed to clear layer buffer: %s", err) + } + }() + + payload := gopacket.Payload(data) + + err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload) + if err != nil { + return 0, fmt.Errorf("serialize layers: %w", err) + } + n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr) + if err != nil { + return 0, fmt.Errorf("write to raw conn: %w", err) + } + return n, nil +} + +func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) { + ipH := &layers.IPv4{ + DstIP: net.ParseIP("127.0.0.1"), + SrcIP: srcAddr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + udpH := &layers.UDP{ + SrcPort: layers.UDPPort(srcAddr.Port), + DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port + } + + err := udpH.SetNetworkLayerForChecksum(ipH) + if err != nil { + return nil, nil, fmt.Errorf("set network layer for checksum: %w", err) + } + + return ipH, udpH, nil +} diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 86e4596d4..8db9e58f4 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -28,10 +28,6 @@ import ( semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) -const ( - defaultWgKeepAlive = 25 * time.Second -) - type ServiceDependencies struct { StatusRecorder *Status Signaler *Signaler @@ -117,6 +113,8 @@ type Conn struct { // debug purpose dumpState *stateDump + + endpointUpdater *EndpointUpdater } // NewConn creates a new not opened Conn to the remote peer. @@ -129,17 +127,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { connLog := log.WithField("peer", config.Key) var conn = &Conn{ - Log: connLog, - config: config, - statusRecorder: services.StatusRecorder, - signaler: services.Signaler, - iFaceDiscover: services.IFaceDiscover, - relayManager: services.RelayManager, - srWatcher: services.SrWatcher, - semaphore: services.Semaphore, - statusRelay: worker.NewAtomicStatus(), - statusICE: worker.NewAtomicStatus(), - dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + Log: connLog, + config: config, + statusRecorder: services.StatusRecorder, + signaler: services.Signaler, + iFaceDiscover: services.IFaceDiscover, + relayManager: services.RelayManager, + srWatcher: services.SrWatcher, + semaphore: services.Semaphore, + statusRelay: worker.NewAtomicStatus(), + statusICE: worker.NewAtomicStatus(), + dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), } return conn, nil @@ -249,7 +248,7 @@ func (conn *Conn) Close(signalToRemote bool) { conn.wgProxyICE = nil } - if err := conn.removeWgPeer(); err != nil { + if err := conn.endpointUpdater.RemoveWgPeer(); err != nil { conn.Log.Errorf("failed to remove wg endpoint: %v", err) } @@ -375,12 +374,19 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn wgProxy.Work() } - if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil { + conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) + presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey) + if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil { conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() + if conn.wgProxyRelay != nil { + conn.Log.Debugf("redirect packets from relayed conn to WireGuard") + conn.wgProxyRelay.RedirectAs(ep) + } + conn.currentConnPriority = priority conn.statusICE.SetConnected() conn.updateIceState(iceConnInfo) @@ -409,7 +415,8 @@ func (conn *Conn) onICEStateDisconnected() { conn.dumpState.SwitchToRelay() conn.wgProxyRelay.Work() - if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { + presharedKey := conn.presharedKey(conn.rosenpassRemoteKey) + if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil { conn.Log.Errorf("failed to switch to relay conn: %v", err) } @@ -418,6 +425,7 @@ func (conn *Conn) onICEStateDisconnected() { defer conn.wgWatcherWg.Done() conn.workerRelay.EnableWgWatcher(conn.ctx) }() + conn.wgProxyRelay.Work() conn.currentConnPriority = conntype.Relay } else { conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) @@ -477,7 +485,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { } wgProxy.Work() - if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { + presharedKey := conn.presharedKey(rci.rosenpassPubKey) + if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.Log.Warnf("Failed to close relay connection: %v", err) } @@ -545,17 +554,6 @@ func (conn *Conn) onGuardEvent() { } } -func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error { - presharedKey := conn.presharedKey(remoteRPKey) - return conn.config.WgConfig.WgInterface.UpdatePeer( - conn.config.WgConfig.RemoteKey, - conn.config.WgConfig.AllowedIps, - defaultWgKeepAlive, - addr, - presharedKey, - ) -} - func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { peerState := State{ PubKey: conn.config.Key, @@ -698,10 +696,6 @@ func (conn *Conn) isICEActive() bool { return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected } -func (conn *Conn) removeWgPeer() error { - return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) -} - func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { conn.Log.Warnf("Failed to update wg peer configuration: %v", err) if wgProxy != nil { diff --git a/client/internal/peer/endpoint.go b/client/internal/peer/endpoint.go new file mode 100644 index 000000000..39cb95591 --- /dev/null +++ b/client/internal/peer/endpoint.go @@ -0,0 +1,105 @@ +package peer + +import ( + "context" + "net" + "sync" + "time" + + "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +const ( + defaultWgKeepAlive = 25 * time.Second + fallbackDelay = 5 * time.Second +) + +type EndpointUpdater struct { + log *logrus.Entry + wgConfig WgConfig + initiator bool + + // mu protects updateWireGuardPeer and cancelFunc + mu sync.Mutex + cancelFunc func() + updateWg sync.WaitGroup +} + +func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *EndpointUpdater { + return &EndpointUpdater{ + log: log, + wgConfig: wgConfig, + initiator: initiator, + } +} + +// ConfigureWGEndpoint sets up the WireGuard endpoint configuration. +// The initiator immediately configures the endpoint, while the non-initiator +// waits for a fallback period before configuring to avoid handshake congestion. +func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error { + e.mu.Lock() + defer e.mu.Unlock() + + if e.initiator { + e.log.Debugf("configure up WireGuard as initiatr") + return e.updateWireGuardPeer(addr, presharedKey) + } + + // prevent to run new update while cancel the previous update + e.waitForCloseTheDelayedUpdate() + + var ctx context.Context + ctx, e.cancelFunc = context.WithCancel(context.Background()) + e.updateWg.Add(1) + go e.scheduleDelayedUpdate(ctx, addr, presharedKey) + + e.log.Debugf("configure up WireGuard and wait for handshake") + return e.updateWireGuardPeer(nil, presharedKey) +} + +func (e *EndpointUpdater) RemoveWgPeer() error { + e.mu.Lock() + defer e.mu.Unlock() + + e.waitForCloseTheDelayedUpdate() + return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey) +} + +func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() { + if e.cancelFunc == nil { + return + } + + e.cancelFunc() + e.cancelFunc = nil + e.updateWg.Wait() +} + +// scheduleDelayedUpdate waits for the fallback period before updating the endpoint +func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, presharedKey *wgtypes.Key) { + defer e.updateWg.Done() + t := time.NewTimer(fallbackDelay) + defer t.Stop() + + select { + case <-ctx.Done(): + return + case <-t.C: + e.mu.Lock() + if err := e.updateWireGuardPeer(addr, presharedKey); err != nil { + e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err) + } + e.mu.Unlock() + } +} + +func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKey *wgtypes.Key) error { + return e.wgConfig.WgInterface.UpdatePeer( + e.wgConfig.RemoteKey, + e.wgConfig.AllowedIps, + defaultWgKeepAlive, + endpoint, + presharedKey, + ) +} From 5e1a40c33fdd83427d9c5e4389e3338ffdbdf3bd Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 30 Sep 2025 23:40:46 +0200 Subject: [PATCH 004/120] [client] Order the list of candidates for proper comparison (#4561) Order the list of candidates for proper comparison --- client/internal/peer/guard/ice_monitor.go | 25 +++++++++++++++-------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index 70850e6eb..09cf9ae63 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -3,6 +3,8 @@ package guard import ( "context" "fmt" + "slices" + "sort" "sync" "time" @@ -24,8 +26,8 @@ type ICEMonitor struct { iFaceDiscover stdnet.ExternalIFaceDiscover iceConfig icemaker.Config - currentCandidates []ice.Candidate - candidatesMu sync.Mutex + currentCandidatesAddress []string + candidatesMu sync.Mutex } func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { @@ -115,16 +117,21 @@ func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool { cm.candidatesMu.Lock() defer cm.candidatesMu.Unlock() - if len(cm.currentCandidates) != len(newCandidates) { - cm.currentCandidates = newCandidates + newAddresses := make([]string, len(newCandidates)) + for i, c := range newCandidates { + newAddresses[i] = c.Address() + } + sort.Strings(newAddresses) + + if len(cm.currentCandidatesAddress) != len(newAddresses) { + cm.currentCandidatesAddress = newAddresses return true } - for i, candidate := range cm.currentCandidates { - if candidate.Address() != newCandidates[i].Address() { - cm.currentCandidates = newCandidates - return true - } + // Compare elements + if !slices.Equal(cm.currentCandidatesAddress, newAddresses) { + cm.currentCandidatesAddress = newAddresses + return true } return false From b5daec3b51ee01ea779f727c74a0aa394a7a3d5d Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 1 Oct 2025 20:10:11 +0200 Subject: [PATCH 005/120] [client,signal,management] Add browser client support (#4415) --- .github/workflows/golangci-lint.yml | 2 +- .github/workflows/wasm-build-validation.yml | 67 +++++ .gitmodules | 0 .goreleaser.yaml | 17 ++ client/cmd/debug_js.go | 8 + client/cmd/testutil_test.go | 4 +- client/embed/embed.go | 60 +++- client/grpc/dialer.go | 42 +-- client/grpc/dialer_generic.go | 44 +++ client/grpc/dialer_js.go | 12 + client/iface/bind/error.go | 7 + client/iface/bind/ice_bind.go | 58 ++-- client/iface/bind/recv_msg.go | 6 + client/iface/bind/relay_bind.go | 125 ++++++++ client/iface/configurer/name.go | 2 +- client/iface/configurer/uapi.go | 2 +- client/iface/configurer/uapi_js.go | 23 ++ client/iface/device/device_netstack.go | 28 +- client/iface/device/device_netstack_test.go | 27 ++ client/iface/iface_destroy_js.go | 6 + client/iface/iface_new_android.go | 4 +- client/iface/iface_new_darwin.go | 2 +- client/iface/iface_new_freebsd.go | 4 +- client/iface/iface_new_ios.go | 2 +- client/iface/iface_new_js.go | 27 ++ client/iface/iface_new_linux.go | 4 +- client/iface/iface_new_windows.go | 2 +- client/iface/netstack/env.go | 2 + client/iface/netstack/env_js.go | 12 + client/iface/wgproxy/bind/proxy.go | 23 +- client/iface/wgproxy/factory_usp.go | 11 +- client/iface/wgproxy/proxy_linux_test.go | 2 +- client/iface/wgproxy/proxy_seed_test.go | 2 +- client/internal/dns/server_js.go | 5 + client/internal/dns/unclean_shutdown_js.go | 19 ++ client/internal/engine.go | 20 +- client/internal/engine_generic.go | 19 ++ client/internal/engine_js.go | 18 ++ client/internal/engine_test.go | 8 +- .../networkmonitor/check_change_js.go | 12 + .../routemanager/systemops/systemops_js.go | 48 ++++ .../systemops/systemops_nonlinux.go | 2 +- client/server/server_test.go | 17 +- client/ssh/client.go | 2 + client/ssh/login.go | 2 + client/ssh/server.go | 2 + client/ssh/server_mock.go | 2 + client/ssh/server_test.go | 2 + client/ssh/ssh_js.go | 137 +++++++++ client/ssh/util.go | 2 + client/system/info_js.go | 231 +++++++++++++++ client/wasm/cmd/main.go | 245 ++++++++++++++++ client/wasm/internal/http/http.go | 100 +++++++ client/wasm/internal/rdp/cert_validation.go | 96 +++++++ client/wasm/internal/rdp/rdcleanpath.go | 271 ++++++++++++++++++ .../wasm/internal/rdp/rdcleanpath_handlers.go | 251 ++++++++++++++++ client/wasm/internal/ssh/client.go | 213 ++++++++++++++ client/wasm/internal/ssh/handlers.go | 78 +++++ client/wasm/internal/ssh/key.go | 50 ++++ encryption/route53.go | 2 + flow/client/client.go | 5 +- go.mod | 2 +- go.sum | 4 +- management/internals/server/controllers.go | 8 +- management/internals/server/modules.go | 4 + management/internals/server/server.go | 32 ++- management/server/account.go | 6 + management/server/account/manager.go | 4 +- management/server/account_test.go | 38 +-- management/server/dns_test.go | 4 +- management/server/grpcserver.go | 5 +- .../http/handlers/peers/peers_handler.go | 83 ++++++ management/server/management_proto_test.go | 3 +- management/server/management_test.go | 3 +- management/server/mock_server/account_mock.go | 12 +- management/server/nameserver_test.go | 4 +- .../server/networks/resources/manager.go | 4 +- management/server/peer.go | 58 +++- management/server/peer/peer.go | 12 + management/server/peer_test.go | 74 ++--- .../server/peers/ephemeral/interface.go | 14 + .../ephemeral/manager}/ephemeral.go | 2 +- .../ephemeral/manager}/ephemeral_test.go | 59 +++- management/server/policy.go | 12 + management/server/store/sql_store.go | 19 ++ management/server/store/store.go | 1 + management/server/types/account.go | 32 ++- management/server/types/policy.go | 87 ++++++ management/server/types/resource.go | 13 +- management/server/user_test.go | 4 +- shared/management/client/client_test.go | 4 +- shared/management/http/api/openapi.yml | 81 +++++- shared/management/http/api/types.gen.go | 28 ++ shared/management/proto/management.pb.go | 2 +- shared/relay/client/client.go | 12 +- shared/relay/client/dialer/ws/conn.go | 3 +- .../client/dialer/ws/dialopts_generic.go | 11 + shared/relay/client/dialer/ws/dialopts_js.go | 10 + shared/relay/client/dialer/ws/ws.go | 4 +- shared/relay/client/dialers_generic.go | 19 ++ shared/relay/client/dialers_js.go | 13 + signal/cmd/run.go | 52 +++- util/util_js.go | 8 + util/wsproxy/client/dialer_js.go | 171 +++++++++++ util/wsproxy/constants.go | 13 + util/wsproxy/server/metrics.go | 118 ++++++++ util/wsproxy/server/proxy.go | 227 +++++++++++++++ 107 files changed, 3591 insertions(+), 284 deletions(-) create mode 100644 .github/workflows/wasm-build-validation.yml create mode 100644 .gitmodules create mode 100644 client/cmd/debug_js.go create mode 100644 client/grpc/dialer_generic.go create mode 100644 client/grpc/dialer_js.go create mode 100644 client/iface/bind/error.go create mode 100644 client/iface/bind/recv_msg.go create mode 100644 client/iface/bind/relay_bind.go create mode 100644 client/iface/configurer/uapi_js.go create mode 100644 client/iface/device/device_netstack_test.go create mode 100644 client/iface/iface_destroy_js.go create mode 100644 client/iface/iface_new_js.go create mode 100644 client/iface/netstack/env_js.go create mode 100644 client/internal/dns/server_js.go create mode 100644 client/internal/dns/unclean_shutdown_js.go create mode 100644 client/internal/engine_generic.go create mode 100644 client/internal/engine_js.go create mode 100644 client/internal/networkmonitor/check_change_js.go create mode 100644 client/internal/routemanager/systemops/systemops_js.go create mode 100644 client/ssh/ssh_js.go create mode 100644 client/system/info_js.go create mode 100644 client/wasm/cmd/main.go create mode 100644 client/wasm/internal/http/http.go create mode 100644 client/wasm/internal/rdp/cert_validation.go create mode 100644 client/wasm/internal/rdp/rdcleanpath.go create mode 100644 client/wasm/internal/rdp/rdcleanpath_handlers.go create mode 100644 client/wasm/internal/ssh/client.go create mode 100644 client/wasm/internal/ssh/handlers.go create mode 100644 client/wasm/internal/ssh/key.go create mode 100644 management/server/peers/ephemeral/interface.go rename management/server/{ => peers/ephemeral/manager}/ephemeral.go (99%) rename management/server/{ => peers/ephemeral/manager}/ephemeral_test.go (75%) create mode 100644 shared/relay/client/dialer/ws/dialopts_generic.go create mode 100644 shared/relay/client/dialer/ws/dialopts_js.go create mode 100644 shared/relay/client/dialers_generic.go create mode 100644 shared/relay/client/dialers_js.go create mode 100644 util/util_js.go create mode 100644 util/wsproxy/client/dialer_js.go create mode 100644 util/wsproxy/constants.go create mode 100644 util/wsproxy/server/metrics.go create mode 100644 util/wsproxy/server/proxy.go diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 7e6583cc6..2845b05a5 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe + ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros skip: go.mod,go.sum golangci: strategy: diff --git a/.github/workflows/wasm-build-validation.yml b/.github/workflows/wasm-build-validation.yml new file mode 100644 index 000000000..e4ac799bc --- /dev/null +++ b/.github/workflows/wasm-build-validation.yml @@ -0,0 +1,67 @@ +name: Wasm + +on: + push: + branches: + - main + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} + cancel-in-progress: true + +jobs: + js_lint: + name: "JS / Lint" + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev + - name: Install golangci-lint + uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc + with: + version: latest + install-mode: binary + skip-cache: true + skip-pkg-cache: true + skip-build-cache: true + - name: Run golangci-lint for WASM + run: | + GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/... + continue-on-error: true + + js_build: + name: "JS / Build" + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + - name: Build Wasm client + run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd + env: + CGO_ENABLED: 0 + - name: Check Wasm build size + run: | + echo "Wasm build size:" + ls -lh netbird.wasm + + SIZE=$(stat -c%s netbird.wasm) + SIZE_MB=$((SIZE / 1024 / 1024)) + + echo "Size: ${SIZE} bytes (${SIZE_MB} MB)" + + if [ ${SIZE} -gt 52428800 ]; then + echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!" + exit 1 + fi + diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..e69de29bb diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 59a95c89a..952e946dc 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -2,6 +2,18 @@ version: 2 project_name: netbird builds: + - id: netbird-wasm + dir: client/wasm/cmd + binary: netbird + env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0] + goos: + - js + goarch: + - wasm + ldflags: + - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser + mod_timestamp: "{{ .CommitTimestamp }}" + - id: netbird dir: client binary: netbird @@ -115,6 +127,11 @@ archives: - builds: - netbird - netbird-static + - id: netbird-wasm + builds: + - netbird-wasm + name_template: "{{ .ProjectName }}_{{ .Version }}" + format: binary nfpms: - maintainer: Netbird diff --git a/client/cmd/debug_js.go b/client/cmd/debug_js.go new file mode 100644 index 000000000..d06fb8efc --- /dev/null +++ b/client/cmd/debug_js.go @@ -0,0 +1,8 @@ +package cmd + +import "context" + +// SetupDebugHandler is a no-op for WASM +func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) { + // Debug handler not needed for WASM +} diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 99ccb1539..bd3209605 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/management-integrations/integrations" + clientProto "github.com/netbirdio/netbird/client/proto" client "github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/management/internals/server/config" @@ -20,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -114,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{}) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}) if err != nil { t.Fatal(err) } diff --git a/client/embed/embed.go b/client/embed/embed.go index 0bfc7a37c..e918235ed 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -23,23 +23,29 @@ import ( var ErrClientAlreadyStarted = errors.New("client already started") var ErrClientNotStarted = errors.New("client not started") +var ErrConfigNotInitialized = errors.New("config not initialized") -// Client manages a netbird embedded client instance +// Client manages a netbird embedded client instance. type Client struct { deviceName string config *profilemanager.Config mu sync.Mutex cancel context.CancelFunc setupKey string + jwtToken string connect *internal.ConnectClient } -// Options configures a new Client +// Options configures a new Client. type Options struct { // DeviceName is this peer's name in the network DeviceName string // SetupKey is used for authentication SetupKey string + // JWTToken is used for JWT-based authentication + JWTToken string + // PrivateKey is used for direct private key authentication + PrivateKey string // ManagementURL overrides the default management server URL ManagementURL string // PreSharedKey is the pre-shared key for the WireGuard interface @@ -58,8 +64,35 @@ type Options struct { DisableClientRoutes bool } -// New creates a new netbird embedded client +// validateCredentials checks that exactly one credential type is provided +func (opts *Options) validateCredentials() error { + credentialsProvided := 0 + if opts.SetupKey != "" { + credentialsProvided++ + } + if opts.JWTToken != "" { + credentialsProvided++ + } + if opts.PrivateKey != "" { + credentialsProvided++ + } + + if credentialsProvided == 0 { + return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided") + } + if credentialsProvided > 1 { + return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified") + } + + return nil +} + +// New creates a new netbird embedded client. func New(opts Options) (*Client, error) { + if err := opts.validateCredentials(); err != nil { + return nil, err + } + if opts.LogOutput != nil { logrus.SetOutput(opts.LogOutput) } @@ -107,9 +140,14 @@ func New(opts Options) (*Client, error) { return nil, fmt.Errorf("create config: %w", err) } + if opts.PrivateKey != "" { + config.PrivateKey = opts.PrivateKey + } + return &Client{ deviceName: opts.DeviceName, setupKey: opts.SetupKey, + jwtToken: opts.JWTToken, config: config, }, nil } @@ -126,7 +164,7 @@ func (c *Client) Start(startCtx context.Context) error { ctx := internal.CtxInitState(context.Background()) // nolint:staticcheck ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) - if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil { + if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil { return fmt.Errorf("login: %w", err) } @@ -187,6 +225,16 @@ func (c *Client) Stop(ctx context.Context) error { } } +// GetConfig returns a copy of the internal client config. +func (c *Client) GetConfig() (profilemanager.Config, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.config == nil { + return profilemanager.Config{}, ErrConfigNotInitialized + } + return *c.config, nil +} + // Dial dials a network address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { @@ -211,7 +259,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e return nsnet.DialContext(ctx, network, address) } -// ListenTCP listens on the given address in the netbird network +// ListenTCP listens on the given address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) ListenTCP(address string) (net.Listener, error) { nsnet, addr, err := c.getNet() @@ -232,7 +280,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) { return nsnet.ListenTCP(tcpAddr) } -// ListenUDP listens on the given address in the netbird network +// ListenUDP listens on the given address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) ListenUDP(address string) (net.PacketConn, error) { nsnet, addr, err := c.getNet() diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 69e3f088c..7cb38fbff 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,15 +4,9 @@ import ( "context" "crypto/tls" "crypto/x509" - "fmt" - "net" - "os/user" "runtime" "time" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -20,37 +14,10 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" - nbnet "github.com/netbirdio/netbird/client/net" - "github.com/netbirdio/netbird/util/embeddedroots" ) -func WithCustomDialer() grpc.DialOption { - return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - if runtime.GOOS == "linux" { - currentUser, err := user.Current() - if err != nil { - return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err) - } - - // the custom dialer requires root permissions which are not required for use cases run as non-root - if currentUser.Uid != "0" { - log.Debug("Not running as root, using standard dialer") - dialer := &net.Dialer{} - return dialer.DialContext(ctx, "tcp", addr) - } - } - - conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) - if err != nil { - log.Errorf("Failed to dial: %s", err) - return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) - } - return conn, nil - }) -} - -// grpcDialBackoff is the backoff mechanism for the grpc calls +// Backoff returns a backoff configuration for gRPC calls func Backoff(ctx context.Context) backoff.BackOff { b := backoff.NewExponentialBackOff() b.MaxElapsedTime = 10 * time.Second @@ -58,6 +25,7 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } +// CreateConnection creates a gRPC client connection with the appropriate transport options func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { @@ -68,7 +36,9 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc. } transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ - RootCAs: certPool, + // for js, outer websocket layer takes care of tls verification via WithCustomDialer + InsecureSkipVerify: runtime.GOOS == "js", + RootCAs: certPool, })) } @@ -79,7 +49,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc. connCtx, addr, transportOption, - WithCustomDialer(), + WithCustomDialer(tlsEnabled), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go new file mode 100644 index 000000000..a0d6cee0b --- /dev/null +++ b/client/grpc/dialer_generic.go @@ -0,0 +1,44 @@ +//go:build !js + +package grpc + +import ( + "context" + "fmt" + "net" + "os/user" + "runtime" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + + nbnet "github.com/netbirdio/netbird/client/net" +) + +func WithCustomDialer(tlsEnabled bool) grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + if runtime.GOOS == "linux" { + currentUser, err := user.Current() + if err != nil { + return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err) + } + + // the custom dialer requires root permissions which are not required for use cases run as non-root + if currentUser.Uid != "0" { + log.Debug("Not running as root, using standard dialer") + dialer := &net.Dialer{} + return dialer.DialContext(ctx, "tcp", addr) + } + } + + conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) + if err != nil { + log.Errorf("Failed to dial: %s", err) + return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) + } + return conn, nil + }) +} diff --git a/client/grpc/dialer_js.go b/client/grpc/dialer_js.go new file mode 100644 index 000000000..e132c0098 --- /dev/null +++ b/client/grpc/dialer_js.go @@ -0,0 +1,12 @@ +package grpc + +import ( + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/util/wsproxy/client" +) + +// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments. +func WithCustomDialer(tlsEnabled bool) grpc.DialOption { + return client.WithWebSocketDialer(tlsEnabled) +} diff --git a/client/iface/bind/error.go b/client/iface/bind/error.go new file mode 100644 index 000000000..db7c23144 --- /dev/null +++ b/client/iface/bind/error.go @@ -0,0 +1,7 @@ +package bind + +import "fmt" + +var ( + ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM") +) diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index ef630b9d0..dfb22ecde 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -1,3 +1,5 @@ +//go:build !js + package bind import ( @@ -21,11 +23,6 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -type RecvMessage struct { - Endpoint *Endpoint - Buffer []byte -} - type receiverCreator struct { iceBind *ICEBind } @@ -43,37 +40,38 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD // use the port because in the Send function the wgConn.Endpoint the port info is not exported. type ICEBind struct { *wgConn.StdNetBind - recvChan chan RecvMessage transportNet transport.Net filterFn udpmux.FilterFn - endpoints map[netip.Addr]net.Conn - endpointsMu sync.Mutex + address wgaddr.Address + mtu uint16 + + endpoints map[netip.Addr]net.Conn + endpointsMu sync.Mutex + recvChan chan recvMessage // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a // new closed channel. With the closedChanMu we can safely close the channel and create a new one - closedChan chan struct{} - closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. - closed bool - - muUDPMux sync.Mutex - udpMux *udpmux.UniversalUDPMuxDefault - address wgaddr.Address - mtu uint16 + closedChan chan struct{} + closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. + closed bool activityRecorder *ActivityRecorder + + muUDPMux sync.Mutex + udpMux *udpmux.UniversalUDPMuxDefault } func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, - recvChan: make(chan RecvMessage, 1), transportNet: transportNet, filterFn: filterFn, + address: address, + mtu: mtu, endpoints: make(map[netip.Addr]net.Conn), + recvChan: make(chan recvMessage, 1), closedChan: make(chan struct{}), closed: true, - mtu: mtu, - address: address, activityRecorder: NewActivityRecorder(), } @@ -84,10 +82,6 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg return ib } -func (s *ICEBind) MTU() uint16 { - return s.mtu -} - func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { s.closed = false s.closedChanMu.Lock() @@ -140,6 +134,16 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) { delete(b.endpoints, fakeIP) } +func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) { + select { + case <-b.closedChan: + return + case <-ctx.Done(): + return + case b.recvChan <- recvMessage{ep, buf}: + } +} + func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { b.endpointsMu.Lock() conn, ok := b.endpoints[ep.DstIP()] @@ -156,14 +160,6 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { return nil } -func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) { - select { - case <-ctx.Done(): - return - case b.recvChan <- msg: - } -} - func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() diff --git a/client/iface/bind/recv_msg.go b/client/iface/bind/recv_msg.go new file mode 100644 index 000000000..65baffaac --- /dev/null +++ b/client/iface/bind/recv_msg.go @@ -0,0 +1,6 @@ +package bind + +type recvMessage struct { + Endpoint *Endpoint + Buffer []byte +} diff --git a/client/iface/bind/relay_bind.go b/client/iface/bind/relay_bind.go new file mode 100644 index 000000000..4c179d6a5 --- /dev/null +++ b/client/iface/bind/relay_bind.go @@ -0,0 +1,125 @@ +package bind + +import ( + "context" + "net" + "net/netip" + "sync" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/conn" + + "github.com/netbirdio/netbird/client/iface/udpmux" +) + +// RelayBindJS is a conn.Bind implementation for WebAssembly environments. +// Do not limit to build only js, because we want to be able to run tests +type RelayBindJS struct { + *conn.StdNetBind + + recvChan chan recvMessage + endpoints map[netip.Addr]net.Conn + endpointsMu sync.Mutex + activityRecorder *ActivityRecorder + ctx context.Context + cancel context.CancelFunc +} + +func NewRelayBindJS() *RelayBindJS { + return &RelayBindJS{ + recvChan: make(chan recvMessage, 100), + endpoints: make(map[netip.Addr]net.Conn), + activityRecorder: NewActivityRecorder(), + } +} + +// Open creates a receive function for handling relay packets in WASM. +func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { + log.Debugf("Open: creating receive function for port %d", uport) + + s.ctx, s.cancel = context.WithCancel(context.Background()) + + receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { + select { + case <-s.ctx.Done(): + return 0, net.ErrClosed + case msg, ok := <-s.recvChan: + if !ok { + return 0, net.ErrClosed + } + copy(bufs[0], msg.Buffer) + sizes[0] = len(msg.Buffer) + eps[0] = conn.Endpoint(msg.Endpoint) + return 1, nil + } + } + + log.Debugf("Open: receive function created, returning port %d", uport) + return []conn.ReceiveFunc{receiveFn}, uport, nil +} + +func (s *RelayBindJS) Close() error { + if s.cancel == nil { + return nil + } + log.Debugf("close RelayBindJS") + s.cancel() + return nil +} + +func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) { + select { + case <-s.ctx.Done(): + return + case <-ctx.Done(): + return + case s.recvChan <- recvMessage{ep, buf}: + } +} + +// Send forwards packets through the relay connection for WASM. +func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error { + if ep == nil { + return nil + } + + fakeIP := ep.DstIP() + + s.endpointsMu.Lock() + relayConn, ok := s.endpoints[fakeIP] + s.endpointsMu.Unlock() + + if !ok { + return nil + } + + for _, buf := range bufs { + if _, err := relayConn.Write(buf); err != nil { + return err + } + } + + return nil +} + +func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { + b.endpointsMu.Lock() + b.endpoints[fakeIP] = conn + b.endpointsMu.Unlock() +} + +func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) { + s.endpointsMu.Lock() + defer s.endpointsMu.Unlock() + + delete(s.endpoints, fakeIP) +} + +// GetICEMux returns the ICE UDPMux that was created and used by ICEBind +func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) { + return nil, ErrUDPMUXNotSupported +} + +func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder { + return s.activityRecorder +} diff --git a/client/iface/configurer/name.go b/client/iface/configurer/name.go index 3b9abc0e8..a8469e0b4 100644 --- a/client/iface/configurer/name.go +++ b/client/iface/configurer/name.go @@ -1,4 +1,4 @@ -//go:build linux || windows || freebsd +//go:build linux || windows || freebsd || js || wasip1 package configurer diff --git a/client/iface/configurer/uapi.go b/client/iface/configurer/uapi.go index 4801841de..f85c7852a 100644 --- a/client/iface/configurer/uapi.go +++ b/client/iface/configurer/uapi.go @@ -1,4 +1,4 @@ -//go:build !windows +//go:build !windows && !js package configurer diff --git a/client/iface/configurer/uapi_js.go b/client/iface/configurer/uapi_js.go new file mode 100644 index 000000000..d0188eb35 --- /dev/null +++ b/client/iface/configurer/uapi_js.go @@ -0,0 +1,23 @@ +package configurer + +import ( + "net" +) + +type noopListener struct{} + +func (n *noopListener) Accept() (net.Conn, error) { + return nil, net.ErrClosed +} + +func (n *noopListener) Close() error { + return nil +} + +func (n *noopListener) Addr() net.Addr { + return nil +} + +func openUAPI(deviceName string) (net.Listener, error) { + return &noopListener{}, nil +} diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index a6ef47027..e37321b68 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -1,9 +1,11 @@ package device import ( + "errors" "fmt" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" @@ -15,6 +17,12 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) +type Bind interface { + conn.Bind + GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) + ActivityRecorder() *bind.ActivityRecorder +} + type TunNetstackDevice struct { name string address wgaddr.Address @@ -22,7 +30,7 @@ type TunNetstackDevice struct { key string mtu uint16 listenAddress string - iceBind *bind.ICEBind + bind Bind device *device.Device filteredDevice *FilteredDevice @@ -33,7 +41,7 @@ type TunNetstackDevice struct { net *netstack.Net } -func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { +func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice { return &TunNetstackDevice{ name: name, address: address, @@ -41,7 +49,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri key: key, mtu: mtu, listenAddress: listenAddress, - iceBind: iceBind, + bind: bind, } } @@ -66,11 +74,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { t.device = device.NewDevice( t.filteredDevice, - t.iceBind, + t.bind, device.NewLogger(wgLogLevel(), "[netbird] "), ) - t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder()) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder()) err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { _ = tunIface.Close() @@ -91,11 +99,15 @@ func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { return nil, err } - udpMux, err := t.iceBind.GetICEMux() - if err != nil { + udpMux, err := t.bind.GetICEMux() + if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) { return nil, err } - t.udpMux = udpMux + + if udpMux != nil { + t.udpMux = udpMux + } + log.Debugf("netstack device is ready to use") return udpMux, nil } diff --git a/client/iface/device/device_netstack_test.go b/client/iface/device/device_netstack_test.go new file mode 100644 index 000000000..52059602f --- /dev/null +++ b/client/iface/device/device_netstack_test.go @@ -0,0 +1,27 @@ +package device + +import ( + "testing" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" +) + +func TestNewNetstackDevice(t *testing.T) { + privateKey, _ := wgtypes.GeneratePrivateKey() + wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24") + + relayBind := bind.NewRelayBindJS() + nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr()) + + cfgr, err := nsTun.Create() + if err != nil { + t.Fatalf("failed to create netstack device: %v", err) + } + if cfgr == nil { + t.Fatal("expected non-nil configurer") + } +} diff --git a/client/iface/iface_destroy_js.go b/client/iface/iface_destroy_js.go new file mode 100644 index 000000000..b443273c3 --- /dev/null +++ b/client/iface/iface_destroy_js.go @@ -0,0 +1,6 @@ +package iface + +// Destroy is a no-op on WASM +func (w *WGIface) Destroy() error { + return nil +} diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go index 26952f48d..3b68f63f2 100644 --- a/client/iface/iface_new_android.go +++ b/client/iface/iface_new_android.go @@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } @@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } diff --git a/client/iface/iface_new_darwin.go b/client/iface/iface_new_darwin.go index 7dd74d571..9f21ec950 100644 --- a/client/iface/iface_new_darwin.go +++ b/client/iface/iface_new_darwin.go @@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: tun, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } diff --git a/client/iface/iface_new_freebsd.go b/client/iface/iface_new_freebsd.go index 86ed14ce1..a342bd579 100644 --- a/client/iface/iface_new_freebsd.go +++ b/client/iface/iface_new_freebsd.go @@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } @@ -33,7 +33,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } diff --git a/client/iface/iface_new_ios.go b/client/iface/iface_new_ios.go index 06ccf0be1..5d6a32e39 100644 --- a/client/iface/iface_new_ios.go +++ b/client/iface/iface_new_ios.go @@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), userspaceBind: true, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } diff --git a/client/iface/iface_new_js.go b/client/iface/iface_new_js.go new file mode 100644 index 000000000..ad913ab04 --- /dev/null +++ b/client/iface/iface_new_js.go @@ -0,0 +1,27 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode) +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + relayBind := bind.NewRelayBindJS() + + wgIface := &WGIface{ + tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()), + userspaceBind: true, + wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU), + } + + return wgIface, nil +} diff --git a/client/iface/iface_new_linux.go b/client/iface/iface_new_linux.go index 77fd30fae..d84035403 100644 --- a/client/iface/iface_new_linux.go +++ b/client/iface/iface_new_linux.go @@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } @@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } diff --git a/client/iface/iface_new_windows.go b/client/iface/iface_new_windows.go index 349c5b33b..dfd9028e7 100644 --- a/client/iface/iface_new_windows.go +++ b/client/iface/iface_new_windows.go @@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: tun, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil diff --git a/client/iface/netstack/env.go b/client/iface/netstack/env.go index cdbf975b1..dd8cf29a3 100644 --- a/client/iface/netstack/env.go +++ b/client/iface/netstack/env.go @@ -1,3 +1,5 @@ +//go:build !js + package netstack import ( diff --git a/client/iface/netstack/env_js.go b/client/iface/netstack/env_js.go new file mode 100644 index 000000000..05c20f036 --- /dev/null +++ b/client/iface/netstack/env_js.go @@ -0,0 +1,12 @@ +package netstack + +const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE" + +// IsEnabled always returns true for js since it's the only mode available +func IsEnabled() bool { + return true +} + +func ListenAddr() string { + return "" +} diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index dbc694e91..eb585d8a2 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -16,15 +16,14 @@ import ( "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) -type IceBind interface { - SetEndpoint(fakeIP netip.Addr, conn net.Conn) - RemoveEndpoint(fakeIP netip.Addr) - Recv(ctx context.Context, msg bind.RecvMessage) - MTU() uint16 +type Bind interface { + SetEndpoint(addr netip.Addr, conn net.Conn) + RemoveEndpoint(addr netip.Addr) + ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte) } type ProxyBind struct { - bind IceBind + bind Bind // wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address wgRelayedEndpoint *bind.Endpoint @@ -40,13 +39,15 @@ type ProxyBind struct { isStarted bool closeListener *listener.CloseListener + mtu uint16 } -func NewProxyBind(bind IceBind) *ProxyBind { +func NewProxyBind(bind Bind, mtu uint16) *ProxyBind { p := &ProxyBind{ bind: bind, closeListener: listener.NewCloseListener(), pausedCond: sync.NewCond(&sync.Mutex{}), + mtu: mtu + bufsize.WGBufferOverhead, } return p @@ -174,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { }() for { - buf := make([]byte, p.bind.MTU()+bufsize.WGBufferOverhead) + buf := make([]byte, p.mtu) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { @@ -190,11 +191,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { p.pausedCond.Wait() } - msg := bind.RecvMessage{ - Endpoint: p.wgCurrentUsed, - Buffer: buf[:n], - } - p.bind.Recv(ctx, msg) + p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n]) p.pausedCond.L.Unlock() } } diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go index 141b4c1f9..a1b1c34d7 100644 --- a/client/iface/wgproxy/factory_usp.go +++ b/client/iface/wgproxy/factory_usp.go @@ -3,24 +3,25 @@ package wgproxy import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/iface/bind" proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" ) type USPFactory struct { - bind *bind.ICEBind + bind proxyBind.Bind + mtu uint16 } -func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { +func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory { log.Infof("WireGuard Proxy Factory will produce bind proxy") f := &USPFactory{ - bind: iceBind, + bind: bind, + mtu: mtu, } return f } func (w *USPFactory) GetProxy() Proxy { - return proxyBind.NewProxyBind(w.bind) + return proxyBind.NewProxyBind(w.bind, w.mtu) } func (w *USPFactory) Free() error { diff --git a/client/iface/wgproxy/proxy_linux_test.go b/client/iface/wgproxy/proxy_linux_test.go index 9526e91d2..dd24d1cdc 100644 --- a/client/iface/wgproxy/proxy_linux_test.go +++ b/client/iface/wgproxy/proxy_linux_test.go @@ -74,7 +74,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { pBind := proxyInstance{ name: "bind proxy", - proxy: bindproxy.NewProxyBind(iceBind), + proxy: bindproxy.NewProxyBind(iceBind, 0), endpointAddr: endpointAddress, closeFn: func() error { return nil }, } diff --git a/client/iface/wgproxy/proxy_seed_test.go b/client/iface/wgproxy/proxy_seed_test.go index 4d244f18a..ad375ccde 100644 --- a/client/iface/wgproxy/proxy_seed_test.go +++ b/client/iface/wgproxy/proxy_seed_test.go @@ -30,7 +30,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { pBind := proxyInstance{ name: "bind proxy", - proxy: bindproxy.NewProxyBind(iceBind), + proxy: bindproxy.NewProxyBind(iceBind, 0), endpointAddr: endpointAddress, closeFn: func() error { return nil }, } diff --git a/client/internal/dns/server_js.go b/client/internal/dns/server_js.go new file mode 100644 index 000000000..a8bc35d09 --- /dev/null +++ b/client/internal/dns/server_js.go @@ -0,0 +1,5 @@ +package dns + +func (s *DefaultServer) initialize() (hostManager, error) { + return &noopHostConfigurator{}, nil +} diff --git a/client/internal/dns/unclean_shutdown_js.go b/client/internal/dns/unclean_shutdown_js.go new file mode 100644 index 000000000..378ffc164 --- /dev/null +++ b/client/internal/dns/unclean_shutdown_js.go @@ -0,0 +1,19 @@ +package dns + +import ( + "context" +) + +type ShutdownState struct{} + +func (s *ShutdownState) Name() string { + return "dns_state" +} + +func (s *ShutdownState) Cleanup() error { + return nil +} + +func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error { + return nil +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 828bc6e94..3fa0b58a8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -453,8 +453,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("up wg interface: %w", err) } - - // if inbound conns are blocked there is no need to create the ACL manager if e.firewall != nil && !e.config.BlockInbound { e.acl = acl.NewDefaultManager(e.firewall) @@ -466,14 +464,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("initialize dns server: %w", err) } - iceCfg := icemaker.Config{ - StunTurn: &e.stunTurn, - InterfaceBlackList: e.config.IFaceBlackList, - DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.SingleSocketUDPMux, - UDPMuxSrflx: e.udpMux, - NATExternalIPs: e.parseNATExternalIPMappings(), - } + iceCfg := e.createICEConfig() e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface) e.connMgr.Start(e.ctx) @@ -1347,14 +1338,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV Addr: e.getRosenpassAddr(), PermissiveMode: e.config.RosenpassPermissive, }, - ICEConfig: icemaker.Config{ - StunTurn: &e.stunTurn, - InterfaceBlackList: e.config.IFaceBlackList, - DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.SingleSocketUDPMux, - UDPMuxSrflx: e.udpMux, - NATExternalIPs: e.parseNATExternalIPMappings(), - }, + ICEConfig: e.createICEConfig(), } serviceDependencies := peer.ServiceDependencies{ diff --git a/client/internal/engine_generic.go b/client/internal/engine_generic.go new file mode 100644 index 000000000..34a75e45b --- /dev/null +++ b/client/internal/engine_generic.go @@ -0,0 +1,19 @@ +//go:build !js + +package internal + +import ( + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" +) + +// createICEConfig creates ICE configuration for non-WASM environments +func (e *Engine) createICEConfig() icemaker.Config { + return icemaker.Config{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + UDPMux: e.udpMux.SingleSocketUDPMux, + UDPMuxSrflx: e.udpMux, + NATExternalIPs: e.parseNATExternalIPMappings(), + } +} diff --git a/client/internal/engine_js.go b/client/internal/engine_js.go new file mode 100644 index 000000000..dce3c57fb --- /dev/null +++ b/client/internal/engine_js.go @@ -0,0 +1,18 @@ +//go:build js + +package internal + +import ( + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" +) + +// createICEConfig creates ICE configuration for WASM environment. +func (e *Engine) createICEConfig() icemaker.Config { + cfg := icemaker.Config{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + NATExternalIPs: e.parseNATExternalIPMappings(), + } + return cfg +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 4d2e81f43..344104405 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -27,6 +27,10 @@ import ( "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" @@ -42,10 +46,8 @@ import ( "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" @@ -1584,7 +1586,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/client/internal/networkmonitor/check_change_js.go b/client/internal/networkmonitor/check_change_js.go new file mode 100644 index 000000000..640cf7184 --- /dev/null +++ b/client/internal/networkmonitor/check_change_js.go @@ -0,0 +1,12 @@ +package networkmonitor + +import ( + "context" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + // No-op for WASM - network changes don't apply + return nil +} diff --git a/client/internal/routemanager/systemops/systemops_js.go b/client/internal/routemanager/systemops/systemops_js.go new file mode 100644 index 000000000..808507fc9 --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_js.go @@ -0,0 +1,48 @@ +package systemops + +import ( + "errors" + "net" + "net/netip" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +var ErrRouteNotSupported = errors.New("route operations not supported on js") + +func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + return ErrRouteNotSupported +} + +func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + return ErrRouteNotSupported +} + +func GetRoutesFromTable() ([]netip.Prefix, error) { + return []netip.Prefix{}, nil +} + +func hasSeparateRouting() ([]netip.Prefix, error) { + return []netip.Prefix{}, nil +} + +// GetDetailedRoutesFromTable returns empty routes for WASM. +func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { + return []DetailedRoute{}, nil +} + +func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + return ErrRouteNotSupported +} + +func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + return ErrRouteNotSupported +} + +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, _ bool) error { + return nil +} + +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, _ bool) error { + return nil +} diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go index 83b64e82b..905a7bc12 100644 --- a/client/internal/routemanager/systemops/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops/systemops_nonlinux.go @@ -1,4 +1,4 @@ -//go:build !linux && !ios +//go:build !linux && !ios && !js package systemops diff --git a/client/server/server_test.go b/client/server/server_test.go index 755925003..e0a4805f6 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -10,23 +10,26 @@ import ( "time" "github.com/golang/mock/gomock" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" - "google.golang.org/grpc" - "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" daemonProto "github.com/netbirdio/netbird/client/proto" - "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" @@ -314,7 +317,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/client/ssh/client.go b/client/ssh/client.go index 2dc70e8fc..afba347f8 100644 --- a/client/ssh/client.go +++ b/client/ssh/client.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/login.go b/client/ssh/login.go index d1d56ceb0..cb2615e55 100644 --- a/client/ssh/login.go +++ b/client/ssh/login.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/server.go b/client/ssh/server.go index 1f2001d0f..8c5db2547 100644 --- a/client/ssh/server.go +++ b/client/ssh/server.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/server_mock.go b/client/ssh/server_mock.go index cc080ffdb..76f43fd4e 100644 --- a/client/ssh/server_mock.go +++ b/client/ssh/server_mock.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import "context" diff --git a/client/ssh/server_test.go b/client/ssh/server_test.go index 5caca1834..1f310c2bb 100644 --- a/client/ssh/server_test.go +++ b/client/ssh/server_test.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/ssh_js.go b/client/ssh/ssh_js.go new file mode 100644 index 000000000..8cea88702 --- /dev/null +++ b/client/ssh/ssh_js.go @@ -0,0 +1,137 @@ +package ssh + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "errors" + "strings" + + "golang.org/x/crypto/ssh" +) + +var ErrSSHNotSupported = errors.New("SSH is not supported in WASM environment") + +// Server is a dummy SSH server interface for WASM. +type Server interface { + Start() error + Stop() error + EnableSSH(enabled bool) + AddAuthorizedKey(peer string, key string) error + RemoveAuthorizedKey(key string) +} + +type dummyServer struct{} + +func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) { + return &dummyServer{}, nil +} + +func NewServer(addr string) Server { + return &dummyServer{} +} + +func (s *dummyServer) Start() error { + return ErrSSHNotSupported +} + +func (s *dummyServer) Stop() error { + return nil +} + +func (s *dummyServer) EnableSSH(enabled bool) { +} + +func (s *dummyServer) AddAuthorizedKey(peer string, key string) error { + return nil +} + +func (s *dummyServer) RemoveAuthorizedKey(key string) { +} + +type Client struct{} + +func NewClient(ctx context.Context, addr string, config interface{}, recorder *SessionRecorder) (*Client, error) { + return nil, ErrSSHNotSupported +} + +func (c *Client) Close() error { + return nil +} + +func (c *Client) Run(command []string) error { + return ErrSSHNotSupported +} + +type SessionRecorder struct{} + +func NewSessionRecorder() *SessionRecorder { + return &SessionRecorder{} +} + +func (r *SessionRecorder) Record(session string, data []byte) { +} + +func GetUserShell() string { + return "/bin/sh" +} + +func LookupUserInfo(username string) (string, string, error) { + return "", "", ErrSSHNotSupported +} + +const DefaultSSHPort = 44338 + +const ED25519 = "ed25519" + +func isRoot() bool { + return false +} + +func GeneratePrivateKey(keyType string) ([]byte, error) { + if keyType != ED25519 { + return nil, errors.New("only ED25519 keys are supported in WASM") + } + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, err + } + + pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return nil, err + } + + pemBlock := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Bytes, + } + + pemBytes := pem.EncodeToMemory(pemBlock) + return pemBytes, nil +} + +func GeneratePublicKey(privateKey []byte) ([]byte, error) { + signer, err := ssh.ParsePrivateKey(privateKey) + if err != nil { + block, _ := pem.Decode(privateKey) + if block != nil { + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + signer, err = ssh.NewSignerFromKey(key) + if err != nil { + return nil, err + } + } else { + return nil, err + } + } + + pubKeyBytes := ssh.MarshalAuthorizedKey(signer.PublicKey()) + return []byte(strings.TrimSpace(string(pubKeyBytes))), nil +} diff --git a/client/ssh/util.go b/client/ssh/util.go index cf5f1396e..a54a609bc 100644 --- a/client/ssh/util.go +++ b/client/ssh/util.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/system/info_js.go b/client/system/info_js.go new file mode 100644 index 000000000..994d439a7 --- /dev/null +++ b/client/system/info_js.go @@ -0,0 +1,231 @@ +package system + +import ( + "context" + "runtime" + "strings" + "syscall/js" + + "github.com/netbirdio/netbird/version" +) + +// UpdateStaticInfoAsync is a no-op on JS as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + +// GetInfo retrieves system information for WASM environment +func GetInfo(_ context.Context) *Info { + info := &Info{ + GoOS: runtime.GOOS, + Kernel: runtime.GOARCH, + KernelVersion: runtime.GOARCH, + Platform: runtime.GOARCH, + OS: runtime.GOARCH, + Hostname: "wasm-client", + CPUs: runtime.NumCPU(), + NetbirdVersion: version.NetbirdVersion(), + } + + collectBrowserInfo(info) + collectLocationInfo(info) + collectSystemInfo(info) + return info +} + +func collectBrowserInfo(info *Info) { + navigator := js.Global().Get("navigator") + if navigator.IsUndefined() { + return + } + + collectUserAgent(info, navigator) + collectPlatform(info, navigator) + collectCPUInfo(info, navigator) +} + +func collectUserAgent(info *Info, navigator js.Value) { + ua := navigator.Get("userAgent") + if ua.IsUndefined() { + return + } + + userAgent := ua.String() + os, osVersion := parseOSFromUserAgent(userAgent) + if os != "" { + info.OS = os + } + if osVersion != "" { + info.OSVersion = osVersion + } +} + +func collectPlatform(info *Info, navigator js.Value) { + // Try regular platform property + if plat := navigator.Get("platform"); !plat.IsUndefined() { + if platStr := plat.String(); platStr != "" { + info.Platform = platStr + } + } + + // Try newer userAgentData API for more accurate platform + userAgentData := navigator.Get("userAgentData") + if userAgentData.IsUndefined() { + return + } + + platformInfo := userAgentData.Get("platform") + if !platformInfo.IsUndefined() { + if platStr := platformInfo.String(); platStr != "" { + info.Platform = platStr + } + } +} + +func collectCPUInfo(info *Info, navigator js.Value) { + hardwareConcurrency := navigator.Get("hardwareConcurrency") + if !hardwareConcurrency.IsUndefined() { + info.CPUs = hardwareConcurrency.Int() + } +} + +func collectLocationInfo(info *Info) { + location := js.Global().Get("location") + if location.IsUndefined() { + return + } + + if host := location.Get("hostname"); !host.IsUndefined() { + hostnameStr := host.String() + if hostnameStr != "" && hostnameStr != "localhost" { + info.Hostname = hostnameStr + } + } +} + +func checkFileAndProcess(_ []string) ([]File, error) { + return []File{}, nil +} + +func collectSystemInfo(info *Info) { + navigator := js.Global().Get("navigator") + if navigator.IsUndefined() { + return + } + + if vendor := navigator.Get("vendor"); !vendor.IsUndefined() { + info.SystemManufacturer = vendor.String() + } + + if product := navigator.Get("product"); !product.IsUndefined() { + info.SystemProductName = product.String() + } + + if userAgent := navigator.Get("userAgent"); !userAgent.IsUndefined() { + ua := userAgent.String() + info.Environment = detectEnvironmentFromUA(ua) + } +} + +func parseOSFromUserAgent(userAgent string) (string, string) { + if userAgent == "" { + return "", "" + } + + switch { + case strings.Contains(userAgent, "Windows NT"): + return parseWindowsVersion(userAgent) + case strings.Contains(userAgent, "Mac OS X"): + return parseMacOSVersion(userAgent) + case strings.Contains(userAgent, "FreeBSD"): + return "FreeBSD", "" + case strings.Contains(userAgent, "OpenBSD"): + return "OpenBSD", "" + case strings.Contains(userAgent, "NetBSD"): + return "NetBSD", "" + case strings.Contains(userAgent, "Linux"): + return parseLinuxVersion(userAgent) + case strings.Contains(userAgent, "iPhone") || strings.Contains(userAgent, "iPad"): + return parseiOSVersion(userAgent) + case strings.Contains(userAgent, "CrOS"): + return "ChromeOS", "" + default: + return "", "" + } +} + +func parseWindowsVersion(userAgent string) (string, string) { + switch { + case strings.Contains(userAgent, "Windows NT 10.0; Win64; x64"): + return "Windows", "10/11" + case strings.Contains(userAgent, "Windows NT 10.0"): + return "Windows", "10" + case strings.Contains(userAgent, "Windows NT 6.3"): + return "Windows", "8.1" + case strings.Contains(userAgent, "Windows NT 6.2"): + return "Windows", "8" + case strings.Contains(userAgent, "Windows NT 6.1"): + return "Windows", "7" + default: + return "Windows", "Unknown" + } +} + +func parseMacOSVersion(userAgent string) (string, string) { + idx := strings.Index(userAgent, "Mac OS X ") + if idx == -1 { + return "macOS", "Unknown" + } + + versionStart := idx + len("Mac OS X ") + versionEnd := strings.Index(userAgent[versionStart:], ")") + if versionEnd <= 0 { + return "macOS", "Unknown" + } + + ver := userAgent[versionStart : versionStart+versionEnd] + ver = strings.ReplaceAll(ver, "_", ".") + return "macOS", ver +} + +func parseLinuxVersion(userAgent string) (string, string) { + if strings.Contains(userAgent, "Android") { + return "Android", extractAndroidVersion(userAgent) + } + if strings.Contains(userAgent, "Ubuntu") { + return "Ubuntu", "" + } + return "Linux", "" +} + +func parseiOSVersion(userAgent string) (string, string) { + idx := strings.Index(userAgent, "OS ") + if idx == -1 { + return "iOS", "Unknown" + } + + versionStart := idx + 3 + versionEnd := strings.Index(userAgent[versionStart:], " ") + if versionEnd <= 0 { + return "iOS", "Unknown" + } + + ver := userAgent[versionStart : versionStart+versionEnd] + ver = strings.ReplaceAll(ver, "_", ".") + return "iOS", ver +} + +func extractAndroidVersion(userAgent string) string { + if idx := strings.Index(userAgent, "Android "); idx != -1 { + versionStart := idx + len("Android ") + versionEnd := strings.IndexAny(userAgent[versionStart:], ";)") + if versionEnd > 0 { + return userAgent[versionStart : versionStart+versionEnd] + } + } + return "Unknown" +} + +func detectEnvironmentFromUA(_ string) Environment { + return Environment{} +} diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go new file mode 100644 index 000000000..d542e2739 --- /dev/null +++ b/client/wasm/cmd/main.go @@ -0,0 +1,245 @@ +//go:build js + +package main + +import ( + "context" + "fmt" + "syscall/js" + "time" + + log "github.com/sirupsen/logrus" + + netbird "github.com/netbirdio/netbird/client/embed" + "github.com/netbirdio/netbird/client/wasm/internal/http" + "github.com/netbirdio/netbird/client/wasm/internal/rdp" + "github.com/netbirdio/netbird/client/wasm/internal/ssh" + "github.com/netbirdio/netbird/util" +) + +const ( + clientStartTimeout = 30 * time.Second + clientStopTimeout = 10 * time.Second + defaultLogLevel = "warn" +) + +func main() { + js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor)) + + select {} +} + +func startClient(ctx context.Context, nbClient *netbird.Client) error { + log.Info("Starting NetBird client...") + if err := nbClient.Start(ctx); err != nil { + return err + } + log.Info("NetBird client started successfully") + return nil +} + +// parseClientOptions extracts NetBird options from JavaScript object +func parseClientOptions(jsOptions js.Value) (netbird.Options, error) { + options := netbird.Options{ + DeviceName: "dashboard-client", + LogLevel: defaultLogLevel, + } + + if jwtToken := jsOptions.Get("jwtToken"); !jwtToken.IsNull() && !jwtToken.IsUndefined() { + options.JWTToken = jwtToken.String() + } + + if setupKey := jsOptions.Get("setupKey"); !setupKey.IsNull() && !setupKey.IsUndefined() { + options.SetupKey = setupKey.String() + } + + if privateKey := jsOptions.Get("privateKey"); !privateKey.IsNull() && !privateKey.IsUndefined() { + options.PrivateKey = privateKey.String() + } + + if mgmtURL := jsOptions.Get("managementURL"); !mgmtURL.IsNull() && !mgmtURL.IsUndefined() { + mgmtURLStr := mgmtURL.String() + if mgmtURLStr != "" { + options.ManagementURL = mgmtURLStr + } + } + + if logLevel := jsOptions.Get("logLevel"); !logLevel.IsNull() && !logLevel.IsUndefined() { + options.LogLevel = logLevel.String() + } + + if deviceName := jsOptions.Get("deviceName"); !deviceName.IsNull() && !deviceName.IsUndefined() { + options.DeviceName = deviceName.String() + } + + return options, nil +} + +// createStartMethod creates the start method for the client +func createStartMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + return createPromise(func(resolve, reject js.Value) { + ctx, cancel := context.WithTimeout(context.Background(), clientStartTimeout) + defer cancel() + + if err := startClient(ctx, client); err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + + resolve.Invoke(js.ValueOf(true)) + }) + }) +} + +// createStopMethod creates the stop method for the client +func createStopMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + return createPromise(func(resolve, reject js.Value) { + ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout) + defer cancel() + + if err := client.Stop(ctx); err != nil { + log.Errorf("Error stopping client: %v", err) + reject.Invoke(js.ValueOf(err.Error())) + return + } + + log.Info("NetBird client stopped") + resolve.Invoke(js.ValueOf(true)) + }) + }) +} + +// createSSHMethod creates the SSH connection method +func createSSHMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf("error: requires host and port") + } + + host := args[0].String() + port := args[1].Int() + username := "root" + if len(args) > 2 && args[2].String() != "" { + username = args[2].String() + } + + return createPromise(func(resolve, reject js.Value) { + sshClient := ssh.NewClient(client) + + if err := sshClient.Connect(host, port, username); err != nil { + reject.Invoke(err.Error()) + return + } + + if err := sshClient.StartSession(80, 24); err != nil { + if closeErr := sshClient.Close(); closeErr != nil { + log.Errorf("Error closing SSH client: %v", closeErr) + } + reject.Invoke(err.Error()) + return + } + + jsInterface := ssh.CreateJSInterface(sshClient) + resolve.Invoke(jsInterface) + }) + }) +} + +// createProxyRequestMethod creates the proxyRequest method +func createProxyRequestMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf("error: request details required") + } + + request := args[0] + + return createPromise(func(resolve, reject js.Value) { + response, err := http.ProxyRequest(client, request) + if err != nil { + reject.Invoke(err.Error()) + return + } + resolve.Invoke(response) + }) + }) +} + +// createRDPProxyMethod creates the RDP proxy method +func createRDPProxyMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf("error: hostname and port required") + } + + proxy := rdp.NewRDCleanPathProxy(client) + return proxy.CreateProxy(args[0].String(), args[1].String()) + }) +} + +// createPromise is a helper to create JavaScript promises +func createPromise(handler func(resolve, reject js.Value)) js.Value { + return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { + resolve := promiseArgs[0] + reject := promiseArgs[1] + + go handler(resolve, reject) + + return nil + })) +} + +// createClientObject wraps the NetBird client in a JavaScript object +func createClientObject(client *netbird.Client) js.Value { + obj := make(map[string]interface{}) + + obj["start"] = createStartMethod(client) + obj["stop"] = createStopMethod(client) + obj["createSSHConnection"] = createSSHMethod(client) + obj["proxyRequest"] = createProxyRequestMethod(client) + obj["createRDPProxy"] = createRDPProxyMethod(client) + + return js.ValueOf(obj) +} + +// netBirdClientConstructor acts as a JavaScript constructor function +func netBirdClientConstructor(this js.Value, args []js.Value) any { + return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any { + resolve := promiseArgs[0] + reject := promiseArgs[1] + + if len(args) < 1 { + reject.Invoke(js.ValueOf("Options object required")) + return nil + } + + go func() { + options, err := parseClientOptions(args[0]) + if err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + + if err := util.InitLog(options.LogLevel, util.LogConsole); err != nil { + log.Warnf("Failed to initialize logging: %v", err) + } + + log.Infof("Creating NetBird client with options: deviceName=%s, hasJWT=%v, hasSetupKey=%v, mgmtURL=%s", + options.DeviceName, options.JWTToken != "", options.SetupKey != "", options.ManagementURL) + + client, err := netbird.New(options) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("create client: %v", err))) + return + } + + clientObj := createClientObject(client) + log.Info("NetBird client created successfully") + resolve.Invoke(clientObj) + }() + + return nil + })) +} diff --git a/client/wasm/internal/http/http.go b/client/wasm/internal/http/http.go new file mode 100644 index 000000000..cddc9e681 --- /dev/null +++ b/client/wasm/internal/http/http.go @@ -0,0 +1,100 @@ +//go:build js + +package http + +import ( + "fmt" + "io" + log "github.com/sirupsen/logrus" + "net/http" + "strings" + "syscall/js" + "time" + + netbird "github.com/netbirdio/netbird/client/embed" +) + +const ( + httpTimeout = 30 * time.Second + maxResponseSize = 1024 * 1024 // 1MB +) + +// performRequest executes an HTTP request through NetBird and returns the response and body +func performRequest(nbClient *netbird.Client, method, url string, headers map[string]string, body []byte) (*http.Response, []byte, error) { + httpClient := nbClient.NewHTTPClient() + httpClient.Timeout = httpTimeout + + req, err := http.NewRequest(method, url, strings.NewReader(string(body))) + if err != nil { + return nil, nil, fmt.Errorf("create request: %w", err) + } + + for key, value := range headers { + req.Header.Set(key, value) + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("request failed: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Errorf("failed to close response body: %v", err) + } + }() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize)) + if err != nil { + return nil, nil, fmt.Errorf("read response: %w", err) + } + + return resp, respBody, nil +} + +// ProxyRequest performs a proxied HTTP request through NetBird and returns a JavaScript object +func ProxyRequest(nbClient *netbird.Client, request js.Value) (js.Value, error) { + url := request.Get("url").String() + if url == "" { + return js.Undefined(), fmt.Errorf("URL is required") + } + + method := "GET" + if methodVal := request.Get("method"); !methodVal.IsNull() && !methodVal.IsUndefined() { + method = strings.ToUpper(methodVal.String()) + } + + var requestBody []byte + if bodyVal := request.Get("body"); !bodyVal.IsNull() && !bodyVal.IsUndefined() { + requestBody = []byte(bodyVal.String()) + } + + requestHeaders := make(map[string]string) + if headersVal := request.Get("headers"); !headersVal.IsNull() && !headersVal.IsUndefined() && headersVal.Type() == js.TypeObject { + headerKeys := js.Global().Get("Object").Call("keys", headersVal) + for i := 0; i < headerKeys.Length(); i++ { + key := headerKeys.Index(i).String() + value := headersVal.Get(key).String() + requestHeaders[key] = value + } + } + + resp, body, err := performRequest(nbClient, method, url, requestHeaders, requestBody) + if err != nil { + return js.Undefined(), err + } + + result := js.Global().Get("Object").New() + result.Set("status", resp.StatusCode) + result.Set("statusText", resp.Status) + result.Set("body", string(body)) + + headers := js.Global().Get("Object").New() + for key, values := range resp.Header { + if len(values) > 0 { + headers.Set(strings.ToLower(key), values[0]) + } + } + result.Set("headers", headers) + + return result, nil +} diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go new file mode 100644 index 000000000..4a23a4bc8 --- /dev/null +++ b/client/wasm/internal/rdp/cert_validation.go @@ -0,0 +1,96 @@ +//go:build js + +package rdp + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "syscall/js" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + certValidationTimeout = 60 * time.Second +) + +func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) { + if !conn.wsHandlers.Get("onCertificateRequest").Truthy() { + return false, fmt.Errorf("certificate validation handler not configured") + } + + certInfo := js.Global().Get("Object").New() + certInfo.Set("ServerAddr", conn.destination) + + certArray := js.Global().Get("Array").New() + for i, certBytes := range certChain { + uint8Array := js.Global().Get("Uint8Array").New(len(certBytes)) + js.CopyBytesToJS(uint8Array, certBytes) + certArray.SetIndex(i, uint8Array) + } + certInfo.Set("ServerCertChain", certArray) + if len(certChain) > 0 { + cert, err := x509.ParseCertificate(certChain[0]) + if err == nil { + info := js.Global().Get("Object").New() + info.Set("subject", cert.Subject.String()) + info.Set("issuer", cert.Issuer.String()) + info.Set("validFrom", cert.NotBefore.Format(time.RFC3339)) + info.Set("validTo", cert.NotAfter.Format(time.RFC3339)) + info.Set("serialNumber", cert.SerialNumber.String()) + certInfo.Set("CertificateInfo", info) + } + } + + promise := conn.wsHandlers.Call("onCertificateRequest", certInfo) + + resultChan := make(chan bool) + errorChan := make(chan error) + + promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + result := args[0].Bool() + resultChan <- result + return nil + })).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + errorChan <- fmt.Errorf("certificate validation failed") + return nil + })) + + select { + case result := <-resultChan: + if result { + log.Info("Certificate accepted by user") + } else { + log.Info("Certificate rejected by user") + } + return result, nil + case err := <-errorChan: + return false, err + case <-time.After(certValidationTimeout): + return false, fmt.Errorf("certificate validation timeout") + } +} + +func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config { + return &tls.Config{ + InsecureSkipVerify: true, // We'll validate manually after handshake + VerifyConnection: func(cs tls.ConnectionState) error { + var certChain [][]byte + for _, cert := range cs.PeerCertificates { + certChain = append(certChain, cert.Raw) + } + + accepted, err := p.validateCertificateWithJS(conn, certChain) + if err != nil { + return err + } + if !accepted { + return fmt.Errorf("certificate rejected by user") + } + + return nil + }, + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go new file mode 100644 index 000000000..8062a05cc --- /dev/null +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -0,0 +1,271 @@ +//go:build js + +package rdp + +import ( + "context" + "crypto/tls" + "encoding/asn1" + "fmt" + "io" + "net" + "sync" + "syscall/js" + + log "github.com/sirupsen/logrus" +) + +const ( + RDCleanPathVersion = 3390 + RDCleanPathProxyHost = "rdcleanpath.proxy.local" + RDCleanPathProxyScheme = "ws" +) + +type RDCleanPathPDU struct { + Version int64 `asn1:"tag:0,explicit"` + Error []byte `asn1:"tag:1,explicit,optional"` + Destination string `asn1:"utf8,tag:2,explicit,optional"` + ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` + ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` + PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` + X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` + ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` + ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` +} + +type RDCleanPathProxy struct { + nbClient interface { + Dial(ctx context.Context, network, address string) (net.Conn, error) + } + activeConnections map[string]*proxyConnection + destinations map[string]string + mu sync.Mutex +} + +type proxyConnection struct { + id string + destination string + rdpConn net.Conn + tlsConn *tls.Conn + wsHandlers js.Value + ctx context.Context + cancel context.CancelFunc +} + +// NewRDCleanPathProxy creates a new RDCleanPath proxy +func NewRDCleanPathProxy(client interface { + Dial(ctx context.Context, network, address string) (net.Conn, error) +}) *RDCleanPathProxy { + return &RDCleanPathProxy{ + nbClient: client, + activeConnections: make(map[string]*proxyConnection), + } +} + +// CreateProxy creates a new proxy endpoint for the given destination +func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { + destination := fmt.Sprintf("%s:%s", hostname, port) + + return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any { + resolve := args[0] + + go func() { + proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections)) + + p.mu.Lock() + if p.destinations == nil { + p.destinations = make(map[string]string) + } + p.destinations[proxyID] = destination + p.mu.Unlock() + + proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID) + + // Register the WebSocket handler for this specific proxy + js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf("error: requires WebSocket argument") + } + + ws := args[0] + p.HandleWebSocketConnection(ws, proxyID) + return nil + })) + + log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination) + resolve.Invoke(proxyURL) + }() + + return nil + })) +} + +// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP +func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) { + p.mu.Lock() + destination := p.destinations[proxyID] + p.mu.Unlock() + + if destination == "" { + log.Errorf("No destination found for proxy ID: %s", proxyID) + return + } + + ctx, cancel := context.WithCancel(context.Background()) + // Don't defer cancel here - it will be called by cleanupConnection + + conn := &proxyConnection{ + id: proxyID, + destination: destination, + wsHandlers: ws, + ctx: ctx, + cancel: cancel, + } + + p.mu.Lock() + p.activeConnections[proxyID] = conn + p.mu.Unlock() + + p.setupWebSocketHandlers(ws, conn) + + log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID) +} + +func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) { + ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return nil + } + + data := args[0] + go p.handleWebSocketMessage(conn, data) + return nil + })) + + ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any { + log.Debug("WebSocket closed by JavaScript") + conn.cancel() + return nil + })) +} + +func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) { + if !data.InstanceOf(js.Global().Get("Uint8Array")) { + return + } + + length := data.Get("length").Int() + bytes := make([]byte, length) + js.CopyBytesToGo(bytes, data) + + if conn.rdpConn != nil || conn.tlsConn != nil { + p.forwardToRDP(conn, bytes) + return + } + + var pdu RDCleanPathPDU + _, err := asn1.Unmarshal(bytes, &pdu) + if err != nil { + log.Warnf("Failed to parse RDCleanPath PDU: %v", err) + n := len(bytes) + if n > 20 { + n = 20 + } + log.Warnf("First %d bytes: %x", n, bytes[:n]) + + if len(bytes) > 0 && bytes[0] == 0x03 { + log.Debug("Received raw RDP packet instead of RDCleanPath PDU") + go p.handleDirectRDP(conn, bytes) + return + } + return + } + + go p.processRDCleanPathPDU(conn, pdu) +} + +func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) { + var writer io.Writer + var connType string + + if conn.tlsConn != nil { + writer = conn.tlsConn + connType = "TLS" + } else if conn.rdpConn != nil { + writer = conn.rdpConn + connType = "TCP" + } else { + log.Error("No RDP connection available") + return + } + + if _, err := writer.Write(bytes); err != nil { + log.Errorf("Failed to write to %s: %v", connType, err) + } +} + +func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) { + defer p.cleanupConnection(conn) + + destination := conn.destination + log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination) + + rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + if err != nil { + log.Errorf("Failed to connect to %s: %v", destination, err) + return + } + conn.rdpConn = rdpConn + + _, err = rdpConn.Write(firstPacket) + if err != nil { + log.Errorf("Failed to write first packet: %v", err) + return + } + + response := make([]byte, 1024) + n, err := rdpConn.Read(response) + if err != nil { + log.Errorf("Failed to read X.224 response: %v", err) + return + } + + p.sendToWebSocket(conn, response[:n]) + + go p.forwardWSToConn(conn, conn.rdpConn, "TCP") + go p.forwardConnToWS(conn, conn.rdpConn, "TCP") +} + +func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) { + log.Debugf("Cleaning up connection %s", conn.id) + conn.cancel() + if conn.tlsConn != nil { + log.Debug("Closing TLS connection") + if err := conn.tlsConn.Close(); err != nil { + log.Debugf("Error closing TLS connection: %v", err) + } + conn.tlsConn = nil + } + if conn.rdpConn != nil { + log.Debug("Closing TCP connection") + if err := conn.rdpConn.Close(); err != nil { + log.Debugf("Error closing TCP connection: %v", err) + } + conn.rdpConn = nil + } + p.mu.Lock() + delete(p.activeConnections, conn.id) + p.mu.Unlock() +} + +func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { + if conn.wsHandlers.Get("receiveFromGo").Truthy() { + uint8Array := js.Global().Get("Uint8Array").New(len(data)) + js.CopyBytesToJS(uint8Array, data) + conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer")) + } else if conn.wsHandlers.Get("send").Truthy() { + uint8Array := js.Global().Get("Uint8Array").New(len(data)) + js.CopyBytesToJS(uint8Array, data) + conn.wsHandlers.Call("send", uint8Array.Get("buffer")) + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath_handlers.go b/client/wasm/internal/rdp/rdcleanpath_handlers.go new file mode 100644 index 000000000..010efa5ea --- /dev/null +++ b/client/wasm/internal/rdp/rdcleanpath_handlers.go @@ -0,0 +1,251 @@ +//go:build js + +package rdp + +import ( + "crypto/tls" + "encoding/asn1" + "io" + "syscall/js" + + log "github.com/sirupsen/logrus" +) + +func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { + log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination) + + if pdu.Version != RDCleanPathVersion { + p.sendRDCleanPathError(conn, "Unsupported version") + return + } + + destination := conn.destination + if pdu.Destination != "" { + destination = pdu.Destination + } + + rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + if err != nil { + log.Errorf("Failed to connect to %s: %v", destination, err) + p.sendRDCleanPathError(conn, "Connection failed") + p.cleanupConnection(conn) + return + } + conn.rdpConn = rdpConn + + // RDP always starts with X.224 negotiation, then determines if TLS is needed + // Modern RDP (since Windows Vista/2008) typically requires TLS + // The X.224 Connection Confirm response will indicate if TLS is required + // For now, we'll attempt TLS for all connections as it's the modern default + p.setupTLSConnection(conn, pdu) +} + +func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) { + var x224Response []byte + if len(pdu.X224ConnectionPDU) > 0 { + log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) + _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) + if err != nil { + log.Errorf("Failed to write X.224 PDU: %v", err) + p.sendRDCleanPathError(conn, "Failed to forward X.224") + return + } + + response := make([]byte, 1024) + n, err := conn.rdpConn.Read(response) + if err != nil { + log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, "Failed to read X.224 response") + return + } + x224Response = response[:n] + log.Debugf("Received X.224 Connection Confirm (%d bytes)", n) + } + + tlsConfig := p.getTLSConfigWithValidation(conn) + + tlsConn := tls.Client(conn.rdpConn, tlsConfig) + conn.tlsConn = tlsConn + + if err := tlsConn.Handshake(); err != nil { + log.Errorf("TLS handshake failed: %v", err) + p.sendRDCleanPathError(conn, "TLS handshake failed") + return + } + + log.Info("TLS handshake successful") + + // Certificate validation happens during handshake via VerifyConnection callback + var certChain [][]byte + connState := tlsConn.ConnectionState() + if len(connState.PeerCertificates) > 0 { + for _, cert := range connState.PeerCertificates { + certChain = append(certChain, cert.Raw) + } + log.Debugf("Extracted %d certificates from TLS connection", len(certChain)) + } + + responsePDU := RDCleanPathPDU{ + Version: RDCleanPathVersion, + ServerAddr: conn.destination, + ServerCertChain: certChain, + } + + if len(x224Response) > 0 { + responsePDU.X224ConnectionPDU = x224Response + } + + p.sendRDCleanPathPDU(conn, responsePDU) + + log.Debug("Starting TLS forwarding") + go p.forwardConnToWS(conn, conn.tlsConn, "TLS") + go p.forwardWSToConn(conn, conn.tlsConn, "TLS") + + <-conn.ctx.Done() + log.Debug("TLS connection context done, cleaning up") + p.cleanupConnection(conn) +} + +func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) { + if len(pdu.X224ConnectionPDU) > 0 { + log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) + _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) + if err != nil { + log.Errorf("Failed to write X.224 PDU: %v", err) + p.sendRDCleanPathError(conn, "Failed to forward X.224") + return + } + + response := make([]byte, 1024) + n, err := conn.rdpConn.Read(response) + if err != nil { + log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, "Failed to read X.224 response") + return + } + + responsePDU := RDCleanPathPDU{ + Version: RDCleanPathVersion, + X224ConnectionPDU: response[:n], + ServerAddr: conn.destination, + } + + p.sendRDCleanPathPDU(conn, responsePDU) + } else { + responsePDU := RDCleanPathPDU{ + Version: RDCleanPathVersion, + ServerAddr: conn.destination, + } + p.sendRDCleanPathPDU(conn, responsePDU) + } + + go p.forwardConnToWS(conn, conn.rdpConn, "TCP") + go p.forwardWSToConn(conn, conn.rdpConn, "TCP") + + <-conn.ctx.Done() + log.Debug("TCP connection context done, cleaning up") + p.cleanupConnection(conn) +} + +func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal RDCleanPath PDU: %v", err) + return + } + + log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data)) + p.sendToWebSocket(conn, data) +} + +func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) { + pdu := RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: []byte(errorMsg), + } + + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal error PDU: %v", err) + return + } + + p.sendToWebSocket(conn, data) +} + +func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) { + msgChan := make(chan []byte) + errChan := make(chan error) + + handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + if len(args) < 1 { + errChan <- io.EOF + return nil + } + + data := args[0] + if data.InstanceOf(js.Global().Get("Uint8Array")) { + length := data.Get("length").Int() + bytes := make([]byte, length) + js.CopyBytesToGo(bytes, data) + msgChan <- bytes + } + return nil + }) + defer handler.Release() + + conn.wsHandlers.Set("onceGoMessage", handler) + + select { + case msg := <-msgChan: + return msg, nil + case err := <-errChan: + return nil, err + case <-conn.ctx.Done(): + return nil, conn.ctx.Err() + } +} + +func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) { + for { + if conn.ctx.Err() != nil { + return + } + + msg, err := p.readWebSocketMessage(conn) + if err != nil { + if err != io.EOF { + log.Errorf("Failed to read from WebSocket: %v", err) + } + return + } + + _, err = dst.Write(msg) + if err != nil { + log.Errorf("Failed to write to %s: %v", connType, err) + return + } + } +} + +func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) { + buffer := make([]byte, 32*1024) + + for { + if conn.ctx.Err() != nil { + return + } + + n, err := src.Read(buffer) + if err != nil { + if err != io.EOF { + log.Errorf("Failed to read from %s: %v", connType, err) + } + return + } + + if n > 0 { + p.sendToWebSocket(conn, buffer[:n]) + } + } +} diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go new file mode 100644 index 000000000..ca35525eb --- /dev/null +++ b/client/wasm/internal/ssh/client.go @@ -0,0 +1,213 @@ +//go:build js + +package ssh + +import ( + "context" + "fmt" + "io" + "sync" + "time" + + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + + netbird "github.com/netbirdio/netbird/client/embed" +) + +const ( + sshDialTimeout = 30 * time.Second +) + +func closeWithLog(c io.Closer, resource string) { + if c != nil { + if err := c.Close(); err != nil { + logrus.Debugf("Failed to close %s: %v", resource, err) + } + } +} + +type Client struct { + nbClient *netbird.Client + sshClient *ssh.Client + session *ssh.Session + stdin io.WriteCloser + stdout io.Reader + stderr io.Reader + mu sync.RWMutex +} + +// NewClient creates a new SSH client +func NewClient(nbClient *netbird.Client) *Client { + return &Client{ + nbClient: nbClient, + } +} + +// Connect establishes an SSH connection through NetBird network +func (c *Client) Connect(host string, port int, username string) error { + addr := fmt.Sprintf("%s:%d", host, port) + logrus.Infof("SSH: Connecting to %s as %s", addr, username) + + var authMethods []ssh.AuthMethod + + nbConfig, err := c.nbClient.GetConfig() + if err != nil { + return fmt.Errorf("get NetBird config: %w", err) + } + if nbConfig.SSHKey == "" { + return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization") + } + + signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey)) + if err != nil { + return fmt.Errorf("parse NetBird SSH private key: %w", err) + } + + pubKey := signer.PublicKey() + logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type()) + + authMethods = append(authMethods, ssh.PublicKeys(signer)) + + config := &ssh.ClientConfig{ + User: username, + Auth: authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: sshDialTimeout, + } + + ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout) + defer cancel() + + conn, err := c.nbClient.Dial(ctx, "tcp", addr) + if err != nil { + return fmt.Errorf("dial %s: %w", addr, err) + } + + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + closeWithLog(conn, "connection after handshake error") + return fmt.Errorf("SSH handshake: %w", err) + } + + c.sshClient = ssh.NewClient(sshConn, chans, reqs) + logrus.Infof("SSH: Connected to %s", addr) + + return nil +} + +// StartSession starts an SSH session with PTY +func (c *Client) StartSession(cols, rows int) error { + if c.sshClient == nil { + return fmt.Errorf("SSH client not connected") + } + + session, err := c.sshClient.NewSession() + if err != nil { + return fmt.Errorf("create session: %w", err) + } + + c.mu.Lock() + defer c.mu.Unlock() + c.session = session + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + ssh.VINTR: 3, + ssh.VQUIT: 28, + ssh.VERASE: 127, + } + + if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil { + closeWithLog(session, "session after PTY error") + return fmt.Errorf("PTY request: %w", err) + } + + c.stdin, err = session.StdinPipe() + if err != nil { + closeWithLog(session, "session after stdin error") + return fmt.Errorf("get stdin: %w", err) + } + + c.stdout, err = session.StdoutPipe() + if err != nil { + closeWithLog(session, "session after stdout error") + return fmt.Errorf("get stdout: %w", err) + } + + c.stderr, err = session.StderrPipe() + if err != nil { + closeWithLog(session, "session after stderr error") + return fmt.Errorf("get stderr: %w", err) + } + + if err := session.Shell(); err != nil { + closeWithLog(session, "session after shell error") + return fmt.Errorf("start shell: %w", err) + } + + logrus.Info("SSH: Session started with PTY") + return nil +} + +// Write sends data to the SSH session +func (c *Client) Write(data []byte) (int, error) { + c.mu.RLock() + stdin := c.stdin + c.mu.RUnlock() + + if stdin == nil { + return 0, fmt.Errorf("SSH session not started") + } + return stdin.Write(data) +} + +// Read reads data from the SSH session +func (c *Client) Read(buffer []byte) (int, error) { + c.mu.RLock() + stdout := c.stdout + c.mu.RUnlock() + + if stdout == nil { + return 0, fmt.Errorf("SSH session not started") + } + return stdout.Read(buffer) +} + +// Resize updates the terminal size +func (c *Client) Resize(cols, rows int) error { + c.mu.RLock() + session := c.session + c.mu.RUnlock() + + if session == nil { + return fmt.Errorf("SSH session not started") + } + return session.WindowChange(rows, cols) +} + +// Close closes the SSH connection +func (c *Client) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.session != nil { + closeWithLog(c.session, "SSH session") + c.session = nil + } + if c.stdin != nil { + closeWithLog(c.stdin, "stdin") + c.stdin = nil + } + c.stdout = nil + c.stderr = nil + + if c.sshClient != nil { + err := c.sshClient.Close() + c.sshClient = nil + return err + } + return nil +} diff --git a/client/wasm/internal/ssh/handlers.go b/client/wasm/internal/ssh/handlers.go new file mode 100644 index 000000000..ea64eb0aa --- /dev/null +++ b/client/wasm/internal/ssh/handlers.go @@ -0,0 +1,78 @@ +//go:build js + +package ssh + +import ( + "io" + "syscall/js" + + "github.com/sirupsen/logrus" +) + +// CreateJSInterface creates a JavaScript interface for the SSH client +func CreateJSInterface(client *Client) js.Value { + jsInterface := js.Global().Get("Object").Call("create", js.Null()) + + jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf(false) + } + + data := args[0] + var bytes []byte + + if data.Type() == js.TypeString { + bytes = []byte(data.String()) + } else { + uint8Array := js.Global().Get("Uint8Array").New(data) + length := uint8Array.Get("length").Int() + bytes = make([]byte, length) + js.CopyBytesToGo(bytes, uint8Array) + } + + _, err := client.Write(bytes) + return js.ValueOf(err == nil) + })) + + jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf(false) + } + cols := args[0].Int() + rows := args[1].Int() + err := client.Resize(cols, rows) + return js.ValueOf(err == nil) + })) + + jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any { + client.Close() + return js.Undefined() + })) + + go readLoop(client, jsInterface) + + return jsInterface +} + +func readLoop(client *Client, jsInterface js.Value) { + buffer := make([]byte, 4096) + for { + n, err := client.Read(buffer) + if err != nil { + if err != io.EOF { + logrus.Debugf("SSH read error: %v", err) + } + if onclose := jsInterface.Get("onclose"); !onclose.IsUndefined() { + onclose.Invoke() + } + client.Close() + return + } + + if ondata := jsInterface.Get("ondata"); !ondata.IsUndefined() { + uint8Array := js.Global().Get("Uint8Array").New(n) + js.CopyBytesToJS(uint8Array, buffer[:n]) + ondata.Invoke(uint8Array) + } + } +} diff --git a/client/wasm/internal/ssh/key.go b/client/wasm/internal/ssh/key.go new file mode 100644 index 000000000..4868ba30a --- /dev/null +++ b/client/wasm/internal/ssh/key.go @@ -0,0 +1,50 @@ +//go:build js + +package ssh + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "strings" + + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" +) + +// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format +func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) { + keyStr := string(keyPEM) + if !strings.Contains(keyStr, "-----BEGIN") { + keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----") + } + + signer, err := ssh.ParsePrivateKey(keyPEM) + if err == nil { + return signer, nil + } + logrus.Debugf("SSH: Failed to parse as SSH format: %v", err) + + block, _ := pem.Decode(keyPEM) + if block == nil { + keyPreview := string(keyPEM) + if len(keyPreview) > 100 { + keyPreview = keyPreview[:100] + } + return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview) + } + + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err) + if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return ssh.NewSignerFromKey(rsaKey) + } + if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return ssh.NewSignerFromKey(ecKey) + } + return nil, fmt.Errorf("parse private key: %w", err) + } + + return ssh.NewSignerFromKey(key) +} diff --git a/encryption/route53.go b/encryption/route53.go index 3c81ab103..48c7a3a1b 100644 --- a/encryption/route53.go +++ b/encryption/route53.go @@ -1,3 +1,5 @@ +//go:build !js + package encryption import ( diff --git a/flow/client/client.go b/flow/client/client.go index 603fd6882..03a4accaf 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -38,7 +38,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl return nil, fmt.Errorf("parsing url: %w", err) } var opts []grpc.DialOption - if parsedURL.Scheme == "https" { + tlsEnabled := parsedURL.Scheme == "https" + if tlsEnabled { certPool, err := x509.SystemCertPool() if err != nil || certPool == nil { log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) @@ -53,7 +54,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl } opts = append(opts, - nbgrpc.WithCustomDialer(), + nbgrpc.WithCustomDialer(tlsEnabled), grpc.WithIdleTimeout(interval*2), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/go.mod b/go.mod index 23aa45277..c4b629993 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,7 @@ require ( github.com/c-robinson/iplib v1.0.3 github.com/caddyserver/certmagic v0.21.3 github.com/cilium/ebpf v0.15.0 - github.com/coder/websocket v1.8.12 + github.com/coder/websocket v1.8.13 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 github.com/eko/gocache/lib/v4 v4.2.0 diff --git a/go.sum b/go.sum index 7096be3fe..13838b82d 100644 --- a/go.sum +++ b/go.sum @@ -140,8 +140,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= -github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= +github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII= github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 984a56a39..ddd81daa2 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -10,6 +10,8 @@ import ( "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" ) func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { @@ -56,8 +58,8 @@ func (s *BaseServer) AuthManager() auth.Manager { }) } -func (s *BaseServer) EphemeralManager() *server.EphemeralManager { - return Create(s, func() *server.EphemeralManager { - return server.NewEphemeralManager(s.Store(), s.AccountManager()) +func (s *BaseServer) EphemeralManager() ephemeral.Manager { + return Create(s, func() ephemeral.Manager { + return manager.NewEphemeralManager(s.Store(), s.AccountManager()) }) } diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 70f0f93a9..daec4ef6f 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -65,6 +65,10 @@ func (s *BaseServer) AccountManager() account.Manager { if err != nil { log.Fatalf("failed to create account manager: %v", err) } + + s.AfterInit(func(s *BaseServer) { + accountManager.SetEphemeralManager(s.EphemeralManager()) + }) return accountManager }) } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index e868c2529..ae9ac4a60 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -6,12 +6,14 @@ import ( "fmt" "net" "net/http" + "net/netip" "strings" "sync" "time" "github.com/google/uuid" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" "golang.org/x/crypto/acme/autocert" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -22,6 +24,8 @@ import ( "github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/wsproxy" + wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server" "github.com/netbirdio/netbird/version" ) @@ -92,12 +96,6 @@ func (s *BaseServer) Start(ctx context.Context) error { s.PeersManager() s.GeoLocationManager() - for _, fn := range s.afterInit { - if fn != nil { - fn(s) - } - } - err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics") if err != nil { return fmt.Errorf("failed to expose metrics: %v", err) @@ -147,7 +145,7 @@ func (s *BaseServer) Start(ctx context.Context) error { log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) } - rootHandler := handlerFunc(s.GRPCServer(), s.APIHandler()) + rootHandler := s.handlerFunc(s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter()) switch { case s.certManager != nil: // a call to certManager.Listener() always creates a new listener so we do it once @@ -176,6 +174,12 @@ func (s *BaseServer) Start(ctx context.Context) error { } } + for _, fn := range s.afterInit { + if fn != nil { + fn(s) + } + } + log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion()) log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String()) s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled) @@ -247,13 +251,17 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) return util.DirectWriteJson(ctx, path, config) } -func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler { +func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { + wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter)) + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") || - strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto") - if request.ProtoMajor == 2 && grpcHeader { + switch { + case request.ProtoMajor == 2 && (strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") || + strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")): gRPCHandler.ServeHTTP(writer, request) - } else { + case request.URL.Path == wsproxy.ProxyPath: + wsProxy.Handler().ServeHTTP(writer, request) + default: httpHandler.ServeHTTP(writer, request) } }) diff --git a/management/server/account.go b/management/server/account.go index ee9f294a4..dca105ddf 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -35,6 +35,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -74,6 +75,7 @@ type DefaultAccountManager struct { ctx context.Context eventStore activity.Store geo geolocation.Geolocation + ephemeralManager ephemeral.Manager requestBuffer *AccountRequestBuffer @@ -261,6 +263,10 @@ func BuildManager( return am, nil } +func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) { + am.ephemeralManager = em +} + func (am *DefaultAccountManager) startWarmup(ctx context.Context) { var initialInterval int64 intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS") diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 30fbbbc3e..a1ed9498b 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -12,6 +12,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -56,7 +57,7 @@ type Manager interface { UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) - AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) @@ -125,5 +126,6 @@ type Manager interface { UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + SetEphemeralManager(em ephemeral.Manager) AllowSync(string, uint64) bool } diff --git a/management/server/account_test.go b/management/server/account_test.go index 81a921bf9..07d2f2383 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -66,7 +66,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account setupKey = key.Key } - _, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer) + _, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false) if err != nil { t.Error("expected to add new peer successfully after creating new account, but failed", err) } @@ -1048,10 +1048,10 @@ func TestAccountManager_AddPeer(t *testing.T) { } expectedPeerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -1112,10 +1112,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { expectedPeerKey := key.PublicKey().String() expectedUserID := userID - peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy) return @@ -1429,10 +1429,10 @@ func TestAccountManager_DeletePeer(t *testing.T) { peerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey, Meta: nbpeer.PeerSystemMeta{Hostname: peerKey}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -1805,11 +1805,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, - }) + }, false) require.NoError(t, err, "unable to add peer") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1861,11 +1861,11 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, - }) + }, false) require.NoError(t, err, "unable to add peer") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, @@ -1904,11 +1904,11 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, - }) + }, false) require.NoError(t, err, "unable to add peer") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -2952,14 +2952,14 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, } expectedPeerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, Status: &nbpeer.PeerStatus{ Connected: true, LastSeen: time.Now().UTC(), }, - }) + }, false) if err != nil { t.Fatalf("expecting peer to be added, got failure %v", err) } @@ -3552,16 +3552,16 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { key2, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) require.NoError(t, err, "unable to add peer1") - peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) require.NoError(t, err, "unable to add peer2") t.Run("update peer IP successfully", func(t *testing.T) { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 55a1bbe66..a2a2ce529 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -281,11 +281,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account return nil, err } - savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1) + savedPeer1, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false) if err != nil { return nil, err } - _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2) + _, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false) if err != nil { return nil, err } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 60a00207e..1177eefff 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -22,6 +22,7 @@ import ( integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/store" @@ -55,7 +56,7 @@ type GRPCServer struct { config *nbconfig.Config secretsManager SecretsManager appMetrics telemetry.AppMetrics - ephemeralManager *EphemeralManager + ephemeralManager ephemeral.Manager peerLocks sync.Map authManager auth.Manager @@ -73,7 +74,7 @@ func NewServer( peersUpdateManager *PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, - ephemeralManager *EphemeralManager, + ephemeralManager ephemeral.Manager, authManager auth.Manager, integratedPeerValidator integrated_validator.IntegratedValidator, ) (*GRPCServer, error) { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index af501e151..4b33495de 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -32,6 +32,7 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) { router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS") } // NewHandler creates a new peers Handler @@ -318,6 +319,88 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } +func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + peerID := vars["peerId"] + if len(peerID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w) + return + } + + var req api.PeerTemporaryAccessRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + newPeer := &nbpeer.Peer{} + newPeer.FromAPITemporaryAccessRequest(&req) + + targetPeer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + peer, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + for _, rule := range req.Rules { + protocol, portRange, err := types.ParseRuleString(rule) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + policy := &types.Policy{ + AccountID: userAuth.AccountId, + Description: "Temporary access policy for peer " + peer.Name, + Name: "Temporary access policy for peer " + peer.Name, + Enabled: true, + Rules: []*types.PolicyRule{{ + Name: "Temporary access rule", + Description: "Temporary access rule", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + SourceResource: types.Resource{ + Type: types.ResourceTypePeer, + ID: peer.ID, + }, + DestinationResource: types.Resource{ + Type: types.ResourceTypePeer, + ID: targetPeer.ID, + }, + Bidirectional: false, + Protocol: protocol, + PortRanges: []types.RulePortRange{portRange}, + }}, + } + + _, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + } + + resp := &api.PeerTemporaryAccessResponse{ + Id: peer.ID, + Name: peer.Name, + Rules: req.Rules, + } + + util.WriteJSONObject(r.Context(), w, resp) +} + func toAccessiblePeers(netMap *types.NetworkMap, dnsDomain string) []api.AccessiblePeer { accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers)) for _, p := range netMap.Peers { diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index ba4997d22..a34d2086b 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -26,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -460,7 +461,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - ephemeralMgr := NewEphemeralManager(store, accountManager) + ephemeralMgr := manager.NewEphemeralManager(store, accountManager) mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) if err != nil { return nil, nil, "", cleanup, err diff --git a/management/server/management_test.go b/management/server/management_test.go index 61dc46d87..1a5e47354 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -25,6 +25,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -228,7 +229,7 @@ func startServer( peersUpdateManager, secretsManager, nil, - nil, + &manager.EphemeralManager{}, nil, server.MockIntegratedValidator{}, ) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 003385eb5..d160e7269 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -15,6 +15,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -41,7 +42,7 @@ type MockAccountManager struct { DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) - AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) @@ -351,12 +352,14 @@ func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string // AddPeer mock implementation of AddPeer from server.AccountManager interface func (am *MockAccountManager) AddPeer( ctx context.Context, + accountID string, setupKey string, userId string, peer *nbpeer.Peer, + temporary bool, ) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.AddPeerFunc != nil { - return am.AddPeerFunc(ctx, setupKey, userId, peer) + return am.AddPeerFunc(ctx, accountID, setupKey, userId, peer, temporary) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented") } @@ -972,6 +975,11 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } +// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface +func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) { + // Mock implementation - does nothing +} + func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { if am.AllowSyncFunc != nil { return am.AllowSyncFunc(key, hash) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 959e7856a..6c985410c 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -876,11 +876,11 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, return nil, err } - _, _, _, err = am.AddPeer(context.Background(), "", userID, peer1) + _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false) if err != nil { return nil, err } - _, _, _, err = am.AddPeer(context.Background(), "", userID, peer2) + _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false) if err != nil { return nil, err } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 294f51676..66484d120 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -132,7 +132,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc res := nbtypes.Resource{ ID: resource.ID, - Type: resource.Type.String(), + Type: nbtypes.ResourceType(resource.Type.String()), } for _, groupID := range resource.GroupIDs { event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res) @@ -265,7 +265,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction store.Store, userID string, newResource, oldResource *types.NetworkResource) ([]func(), error) { res := nbtypes.Resource{ ID: newResource.ID, - Type: newResource.Type.String(), + Type: nbtypes.ResourceType(newResource.Type.String()), } oldResourceGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, oldResource.AccountID, oldResource.ID) diff --git a/management/server/peer.go b/management/server/peer.go index 81f037499..ea4617af0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -450,7 +450,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri // to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further -func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if setupKey == "" && userID == "" { // no auth method provided => reject access return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") @@ -482,8 +482,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s var ephemeral bool var groupsToAdd []string var allowExtraDNSLabels bool - var accountID string - var isEphemeral bool if addedByUser { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { @@ -492,10 +490,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if user.PendingApproval { return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers") } - groupsToAdd = user.AutoGroups + if temporary { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create) + if err != nil { + return nil, nil, nil, status.NewPermissionValidationError(err) + } + + if !allowed { + return nil, nil, nil, status.NewPermissionDeniedError() + } + } else { + accountID = user.AccountID + groupsToAdd = user.AutoGroups + } opEvent.InitiatorID = userID opEvent.Activity = activity.PeerAddedByUser - accountID = user.AccountID } else { // Validate the setup key sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey) @@ -516,13 +525,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s setupKeyName = sk.Name allowExtraDNSLabels = sk.AllowExtraDNSLabels accountID = sk.AccountID - isEphemeral = sk.Ephemeral if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels") } } opEvent.AccountID = accountID + if temporary { + ephemeral = true + } + if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" { if am.idpManager != nil { userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) @@ -549,10 +561,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s SSHKey: peer.SSHKey, LastLogin: ®istrationTime, CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, + LoginExpirationEnabled: addedByUser && !temporary, Ephemeral: ephemeral, Location: peer.Location, - InactivityExpirationEnabled: addedByUser, + InactivityExpirationEnabled: addedByUser && !temporary, ExtraDNSLabels: peer.ExtraDNSLabels, AllowExtraDNSLabels: allowExtraDNSLabels, } @@ -588,7 +600,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var freeLabel string - if isEphemeral || attempt > 1 { + if ephemeral || attempt > 1 { freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname) if err != nil { return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) @@ -622,6 +634,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed adding peer to All group: %w", err) } + if temporary { + // we are running the on disconnect handler so that it is considered not connected as we are adding the peer manually + am.ephemeralManager.OnPeerDisconnected(ctx, newPeer) + } + if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { @@ -790,7 +807,7 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo ExtraDNSLabels: login.ExtraDNSLabels, } - return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) + return am.AddPeer(ctx, "", login.SetupKey, login.UserID, newPeer, false) } log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) @@ -877,6 +894,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer if peer.SSHKey != login.SSHKey { peer.SSHKey = login.SSHKey shouldStorePeer = true + updateRemotePeers = true } if !peer.AllowExtraDNSLabels && len(login.ExtraDNSLabels) > 0 { @@ -1540,6 +1558,26 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return nil, err } + peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peer.ID) + if err != nil { + return nil, err + } + for _, rule := range peerPolicyRules { + policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID) + if err != nil { + return nil, err + } + + err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID) + if err != nil { + return nil, err + } + + peerDeletedEvents = append(peerDeletedEvents, func() { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + }) + } + if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { return nil, err } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 6a6d1c91d..f89f10dac 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -8,6 +8,7 @@ import ( "time" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/http/api" ) // Peer represents a machine connected to the network. @@ -334,6 +335,17 @@ func (p *Peer) UpdateLastLogin() *Peer { return p } +func (p *Peer) FromAPITemporaryAccessRequest(a *api.PeerTemporaryAccessRequest) { + p.Ephemeral = true + p.Name = a.Name + p.Key = a.WgPubKey + p.Meta = PeerSystemMeta{ + Hostname: a.Name, + GoOS: "js", + OS: "js", + } +} + func (f Flags) isEqual(other Flags) bool { return f.RosenpassEnabled == other.RosenpassEnabled && f.RosenpassPermissive == other.RosenpassPermissive && diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 31c309430..734536d7b 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -193,10 +193,10 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -207,10 +207,10 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) return } - _, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) @@ -266,10 +266,10 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -280,10 +280,10 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { t.Fatal(err) return } - peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -442,10 +442,10 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -456,10 +456,10 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) return } - _, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) @@ -514,10 +514,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -530,10 +530,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // the second peer added with a setup key - peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Fatal(err) return @@ -702,19 +702,19 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - _, _, _, err = manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return } - _, _, _, err = manager.AddPeer(context.Background(), "", adminUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", adminUser, &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -1300,7 +1300,7 @@ func Test_RegisterPeerByUser(t *testing.T) { }, } - addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer) + addedPeer, _, _, err := am.AddPeer(context.Background(), "", "", existingUserID, newPeer, false) require.NoError(t, err) assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels) @@ -1422,7 +1422,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ExtraDNSLabels: newPeerTemplate.ExtraDNSLabels, } - addedPeer, _, _, err := am.AddPeer(context.Background(), tc.existingSetupKeyID, "", currentPeer) + addedPeer, _, _, err := am.AddPeer(context.Background(), "", tc.existingSetupKeyID, "", currentPeer, false) if tc.expectAddPeerError { require.Error(t, err, "Expected an error when adding peer with setup key: %s", tc.existingSetupKeyID) @@ -1523,7 +1523,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { SSHEnabled: false, } - _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) + _, _, _, err = am.AddPeer(context.Background(), "", faultyKey, "", newPeer, false) require.Error(t, err) _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key) @@ -1658,7 +1658,7 @@ func Test_LoginPeer(t *testing.T) { if sk.AllowExtraDNSLabels { currentPeer.ExtraDNSLabels = newPeerTemplate.ExtraDNSLabels } - _, _, _, err = am.AddPeer(context.Background(), tc.setupKey, "", currentPeer) + _, _, _, err = am.AddPeer(context.Background(), "", tc.setupKey, "", currentPeer, false) require.NoError(t, err, "Expected no error when adding peer with setup key: %s", tc.setupKey) loginInput := types.PeerLogin{ @@ -1797,10 +1797,10 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -1918,11 +1918,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{ Key: expectedPeerKey, LoginExpirationEnabled: true, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -1982,11 +1982,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer5, _, _, err = manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + peer5, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{ Key: expectedPeerKey, LoginExpirationEnabled: true, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -2037,11 +2037,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer6, _, _, err = manager.AddPeer(context.Background(), "", "regularUser3", &nbpeer.Peer{ + peer6, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser3", &nbpeer.Peer{ Key: expectedPeerKey, LoginExpirationEnabled: true, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -2208,7 +2208,7 @@ func Test_AddPeer(t *testing.T) { <-start - _, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer) + _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", newPeer, false) if err != nil { errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err) return @@ -2416,7 +2416,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { }, } - _, _, _, err = manager.AddPeer(context.Background(), "", pendingUser.Id, peer) + _, _, _, err = manager.AddPeer(context.Background(), "", "", pendingUser.Id, peer, false) require.Error(t, err) assert.Contains(t, err.Error(), "user pending approval cannot add peers") } @@ -2451,7 +2451,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { }, } - _, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer) + _, _, _, err = manager.AddPeer(context.Background(), "", "", regularUser.Id, peer, false) require.NoError(t, err, "Regular user should be able to add peers") } @@ -2494,7 +2494,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { WtVersion: "0.28.0", }, } - existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer) + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", pendingUser.Id, newPeer, false) require.NoError(t, err) // Now set the user back to pending approval after peer was created @@ -2550,7 +2550,7 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { WtVersion: "0.28.0", }, } - existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer) + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", regularUser.Id, newPeer, false) require.NoError(t, err) // Try to login with regular user diff --git a/management/server/peers/ephemeral/interface.go b/management/server/peers/ephemeral/interface.go new file mode 100644 index 000000000..a1605b3b9 --- /dev/null +++ b/management/server/peers/ephemeral/interface.go @@ -0,0 +1,14 @@ +package ephemeral + +import ( + "context" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +type Manager interface { + LoadInitialPeers(ctx context.Context) + Stop() + OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) + OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) +} diff --git a/management/server/ephemeral.go b/management/server/peers/ephemeral/manager/ephemeral.go similarity index 99% rename from management/server/ephemeral.go rename to management/server/peers/ephemeral/manager/ephemeral.go index e3cb5459a..062ba69d2 100644 --- a/management/server/ephemeral.go +++ b/management/server/peers/ephemeral/manager/ephemeral.go @@ -1,4 +1,4 @@ -package server +package manager import ( "context" diff --git a/management/server/ephemeral_test.go b/management/server/peers/ephemeral/manager/ephemeral_test.go similarity index 75% rename from management/server/ephemeral_test.go rename to management/server/peers/ephemeral/manager/ephemeral_test.go index d07b9a422..fc7525c29 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/peers/ephemeral/manager/ephemeral_test.go @@ -1,4 +1,4 @@ -package server +package manager import ( "context" @@ -7,12 +7,15 @@ import ( "testing" "time" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + nbdns "github.com/netbirdio/netbird/dns" nbAccount "github.com/netbirdio/netbird/management/server/account" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" ) type MockStore struct { @@ -223,3 +226,57 @@ func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) store.account.Peers[p.ID] = p } } + +// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id +func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account { + log.WithContext(ctx).Debugf("creating new account") + + network := types.NewNetwork() + peers := make(map[string]*nbpeer.Peer) + users := make(map[string]*types.User) + routes := make(map[route.ID]*route.Route) + setupKeys := map[string]*types.SetupKey{} + nameServersGroups := make(map[string]*nbdns.NameServerGroup) + + owner := types.NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + + dnsSettings := types.DNSSettings{ + DisabledManagementGroups: make([]string, 0), + } + log.WithContext(ctx).Debugf("created new account %s", accountID) + + acc := &types.Account{ + Id: accountID, + CreatedAt: time.Now().UTC(), + SetupKeys: setupKeys, + Network: network, + Peers: peers, + Users: users, + CreatedBy: userID, + Domain: domain, + Routes: routes, + NameServerGroups: nameServersGroups, + DNSSettings: dnsSettings, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + RoutingPeerDNSResolutionEnabled: true, + }, + Onboarding: types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, + } + + if err := acc.AddAllGroup(disableDefaultPolicy); err != nil { + log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) + } + return acc +} diff --git a/management/server/policy.go b/management/server/policy.go index 3adee6397..9e4b3f73a 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -151,6 +151,12 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a return false, nil } + for _, rule := range existingPolicy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) if err != nil { return false, err @@ -161,6 +167,12 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a } } + for _, rule := range policy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 027938320..382d026c8 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2037,6 +2037,25 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) }) } +func (s *SqlStore) GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, resourceID string) ([]*types.PolicyRule, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var policyRules []*types.PolicyRule + resourceIDPattern := `%"ID":"` + resourceID + `"%` + result := tx.Where("source_resource LIKE ? OR destination_resource LIKE ?", resourceIDPattern, resourceIDPattern). + Find(&policyRules) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get policy rules for resource id from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get policy rules for resource id from store") + } + + return policyRules, nil +} + // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { tx := s.db diff --git a/management/server/store/store.go b/management/server/store/store.go index 3c9d896b0..21b660d96 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -202,6 +202,7 @@ type Store interface { IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) MarkAccountPrimary(ctx context.Context, accountID string) error UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error + GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error) } const ( diff --git a/management/server/types/account.go b/management/server/types/account.go index a69d3bb08..f830023c7 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1001,8 +1001,20 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P continue } - sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap) + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers, peerInSources = a.getPeerFromResource(rule.SourceResource, peer.ID) + } else { + sourcePeers, peerInSources = a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + destinationPeers, peerInDestinations = a.getPeerFromResource(rule.DestinationResource, peer.ID) + } else { + destinationPeers, peerInDestinations = a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap) + } if rule.Bidirectional { if peerInSources { @@ -1124,6 +1136,15 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe return filteredPeers, peerInGroups } +func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) { + peer := a.GetPeer(resource.ID) + if peer == nil { + return []*nbpeer.Peer{}, false + } + + return []*nbpeer.Peer{peer}, resource.ID == peerID +} + // validatePostureChecksOnPeer validates the posture checks on a peer func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { peer, ok := a.Peers[peerID] @@ -1379,7 +1400,12 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st addedResourceRoute := false for _, policy := range resourcePolicies[resource.ID] { - peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + var peers []string + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peers = []string{policy.Rules[0].SourceResource.ID} + } else { + peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + } if addSourcePeers { for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { allSourcePeers[pID] = struct{}{} diff --git a/management/server/types/policy.go b/management/server/types/policy.go index 17964ed1f..5e86a87c6 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -1,5 +1,12 @@ package types +import ( + "errors" + "fmt" + "strconv" + "strings" +) + const ( // PolicyTrafficActionAccept indicates that the traffic is accepted PolicyTrafficActionAccept = PolicyTrafficActionType("accept") @@ -134,3 +141,83 @@ func (p *Policy) SourceGroups() []string { return groupIDs } + +func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error) { + rule = strings.TrimSpace(strings.ToLower(rule)) + if rule == "all" { + return PolicyRuleProtocolALL, RulePortRange{}, nil + } + if rule == "icmp" { + return PolicyRuleProtocolICMP, RulePortRange{}, nil + } + + split := strings.Split(rule, "/") + if len(split) != 2 { + return "", RulePortRange{}, errors.New("invalid rule format: expected protocol/port or protocol/port-range") + } + + protoStr := strings.TrimSpace(split[0]) + portStr := strings.TrimSpace(split[1]) + + var protocol PolicyRuleProtocolType + switch protoStr { + case "tcp": + protocol = PolicyRuleProtocolTCP + case "udp": + protocol = PolicyRuleProtocolUDP + case "icmp": + return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'") + default: + return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr) + } + + portRange, err := parsePortRange(portStr) + if err != nil { + return "", RulePortRange{}, err + } + + return protocol, portRange, nil +} + +func parsePortRange(portStr string) (RulePortRange, error) { + if strings.Contains(portStr, "-") { + rangeParts := strings.Split(portStr, "-") + if len(rangeParts) != 2 { + return RulePortRange{}, fmt.Errorf("invalid port range %q", portStr) + } + start, err := parsePort(strings.TrimSpace(rangeParts[0])) + if err != nil { + return RulePortRange{}, err + } + end, err := parsePort(strings.TrimSpace(rangeParts[1])) + if err != nil { + return RulePortRange{}, err + } + if start > end { + return RulePortRange{}, fmt.Errorf("invalid port range: start %d > end %d", start, end) + } + return RulePortRange{Start: uint16(start), End: uint16(end)}, nil + } + + p, err := parsePort(portStr) + if err != nil { + return RulePortRange{}, err + } + + return RulePortRange{Start: uint16(p), End: uint16(p)}, nil +} + +func parsePort(portStr string) (int, error) { + + if portStr == "" { + return 0, errors.New("empty port") + } + p, err := strconv.Atoi(portStr) + if err != nil { + return 0, fmt.Errorf("invalid port %q: %w", portStr, err) + } + if p < 1 || p > 65535 { + return 0, fmt.Errorf("port out of range (1–65535): %d", p) + } + return p, nil +} diff --git a/management/server/types/resource.go b/management/server/types/resource.go index 84d8e4b88..8347d8c03 100644 --- a/management/server/types/resource.go +++ b/management/server/types/resource.go @@ -4,9 +4,18 @@ import ( "github.com/netbirdio/netbird/shared/management/http/api" ) +type ResourceType string + +const ( + ResourceTypePeer ResourceType = "peer" + ResourceTypeDomain ResourceType = "domain" + ResourceTypeHost ResourceType = "host" + ResourceTypeSubnet ResourceType = "subnet" +) + type Resource struct { ID string - Type string + Type ResourceType } func (r *Resource) ToAPIResponse() *api.Resource { @@ -26,5 +35,5 @@ func (r *Resource) FromAPIRequest(req *api.Resource) { } r.ID = req.Id - r.Type = string(req.Type) + r.Type = ResourceType(req.Type) } diff --git a/management/server/user_test.go b/management/server/user_test.go index 9638559f9..5920a2a33 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1439,10 +1439,10 @@ func TestUserAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer4, _, _, err := manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + peer4, _, _, err := manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) // updating user with linked peers should update account peers and send peer update diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index becc10ded..d4a9f1823 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/internals/server/config" @@ -27,6 +28,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -117,7 +119,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { groupsManager := groups.NewManagerMock() secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{}) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}) if err != nil { t.Fatal(err) } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 9a531b2ff..93578b1ae 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -507,6 +507,48 @@ components: - serial_number - extra_dns_labels - ephemeral + PeerTemporaryAccessRequest: + type: object + properties: + name: + description: Peer's hostname + type: string + example: temp-host-1 + wg_pub_key: + description: Peer's WireGuard public key + type: string + example: "n0r3pL4c3h0ld3rK3y==" + rules: + description: List of temporary access rules + type: array + items: + type: string + example: "tcp/80" + required: + - name + - wg_pub_key + - rules + PeerTemporaryAccessResponse: + type: object + properties: + name: + description: Peer's hostname + type: string + example: temp-host-1 + id: + description: Peer ID + type: string + example: chacbco6lnnbn6cg5s90 + rules: + description: List of temporary access rules + type: array + items: + type: string + example: "tcp/80" + required: + - name + - id + - rules AccessiblePeer: allOf: - $ref: '#/components/schemas/PeerMinimum' @@ -1404,7 +1446,8 @@ components: allOf: - $ref: '#/components/schemas/NetworkResourceType' - type: string - example: host + enum: ["peer"] + example: peer NetworkRequest: type: object properties: @@ -2793,6 +2836,42 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/peers/{peerId}/temporary-access: + post: + summary: Create a Temporary Access Peer + description: Creates a temporary access peer that can be used to access this peer and this peer only. The temporary access peer and its access policies will be automatically deleted after it disconnects. + tags: [ Peers ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: peerId + required: true + schema: + type: string + description: The unique identifier of a peer + requestBody: + description: Temporary Access Peer create request + content: + 'application/json': + schema: + $ref: '#/components/schemas/PeerTemporaryAccessRequest' + responses: + '200': + description: Temporary Access Peer response + content: + application/json: + schema: + $ref: '#/components/schemas/PeerTemporaryAccessResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/peers/{peerId}/ingress/ports: get: x-cloud-only: true diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 28b89633c..3dbb32ef6 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -168,6 +168,7 @@ const ( const ( ResourceTypeDomain ResourceType = "domain" ResourceTypeHost ResourceType = "host" + ResourceTypePeer ResourceType = "peer" ResourceTypeSubnet ResourceType = "subnet" ) @@ -1221,6 +1222,30 @@ type PeerRequest struct { SshEnabled bool `json:"ssh_enabled"` } +// PeerTemporaryAccessRequest defines model for PeerTemporaryAccessRequest. +type PeerTemporaryAccessRequest struct { + // Name Peer's hostname + Name string `json:"name"` + + // Rules List of temporary access rules + Rules []string `json:"rules"` + + // WgPubKey Peer's WireGuard public key + WgPubKey string `json:"wg_pub_key"` +} + +// PeerTemporaryAccessResponse defines model for PeerTemporaryAccessResponse. +type PeerTemporaryAccessResponse struct { + // Id Peer ID + Id string `json:"id"` + + // Name Peer's hostname + Name string `json:"name"` + + // Rules List of temporary access rules + Rules []string `json:"rules"` +} + // PersonalAccessToken defines model for PersonalAccessToken. type PersonalAccessToken struct { // CreatedAt Date the token was created @@ -1949,6 +1974,9 @@ type PostApiPeersPeerIdIngressPortsJSONRequestBody = IngressPortAllocationReques // PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody defines body for PutApiPeersPeerIdIngressPortsAllocationId for application/json ContentType. type PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody = IngressPortAllocationRequest +// PostApiPeersPeerIdTemporaryAccessJSONRequestBody defines body for PostApiPeersPeerIdTemporaryAccess for application/json ContentType. +type PostApiPeersPeerIdTemporaryAccessJSONRequestBody = PeerTemporaryAccessRequest + // PostApiPoliciesJSONRequestBody defines body for PostApiPolicies for application/json ContentType. type PostApiPoliciesJSONRequestBody = PolicyUpdate diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index bf614e8aa..8381d6682 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.24.3 +// protoc v6.32.0 // source: management.proto package proto diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go index 5dabc5742..57a98614d 100644 --- a/shared/relay/client/client.go +++ b/shared/relay/client/client.go @@ -9,11 +9,8 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/iface" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" "github.com/netbirdio/netbird/shared/relay/client/dialer" - "github.com/netbirdio/netbird/shared/relay/client/dialer/quic" - "github.com/netbirdio/netbird/shared/relay/client/dialer/ws" "github.com/netbirdio/netbird/shared/relay/healthcheck" "github.com/netbirdio/netbird/shared/relay/messages" ) @@ -296,14 +293,7 @@ func (c *Client) Close() error { } func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { - // Force WebSocket for MTUs larger than default to avoid QUIC DATAGRAM frame size issues - var dialers []dialer.DialeFn - if c.mtu > 0 && c.mtu > iface.DefaultMTU { - c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU) - dialers = []dialer.DialeFn{ws.Dialer{}} - } else { - dialers = []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}} - } + dialers := c.getDialers() rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) conn, err := rd.Dial() diff --git a/shared/relay/client/dialer/ws/conn.go b/shared/relay/client/dialer/ws/conn.go index 0086b702b..d5b719f51 100644 --- a/shared/relay/client/dialer/ws/conn.go +++ b/shared/relay/client/dialer/ws/conn.go @@ -38,8 +38,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { } func (c *Conn) Write(b []byte) (n int, err error) { - err = c.Conn.Write(c.ctx, websocket.MessageBinary, b) - return 0, err + return 0, c.Conn.Write(c.ctx, websocket.MessageBinary, b) } func (c *Conn) RemoteAddr() net.Addr { diff --git a/shared/relay/client/dialer/ws/dialopts_generic.go b/shared/relay/client/dialer/ws/dialopts_generic.go new file mode 100644 index 000000000..9dfe698d0 --- /dev/null +++ b/shared/relay/client/dialer/ws/dialopts_generic.go @@ -0,0 +1,11 @@ +//go:build !js + +package ws + +import "github.com/coder/websocket" + +func createDialOptions() *websocket.DialOptions { + return &websocket.DialOptions{ + HTTPClient: httpClientNbDialer(), + } +} diff --git a/shared/relay/client/dialer/ws/dialopts_js.go b/shared/relay/client/dialer/ws/dialopts_js.go new file mode 100644 index 000000000..7eac27531 --- /dev/null +++ b/shared/relay/client/dialer/ws/dialopts_js.go @@ -0,0 +1,10 @@ +//go:build js + +package ws + +import "github.com/coder/websocket" + +func createDialOptions() *websocket.DialOptions { + // WASM version doesn't support HTTPClient + return &websocket.DialOptions{} +} diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index ef6bd6b3c..66fff3447 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -32,9 +32,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { return nil, err } - opts := &websocket.DialOptions{ - HTTPClient: httpClientNbDialer(), - } + opts := createDialOptions() parsedURL, err := url.Parse(wsURL) if err != nil { diff --git a/shared/relay/client/dialers_generic.go b/shared/relay/client/dialers_generic.go new file mode 100644 index 000000000..a8ed79961 --- /dev/null +++ b/shared/relay/client/dialers_generic.go @@ -0,0 +1,19 @@ +//go:build !js + +package client + +import ( + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/shared/relay/client/dialer" + "github.com/netbirdio/netbird/shared/relay/client/dialer/quic" + "github.com/netbirdio/netbird/shared/relay/client/dialer/ws" +) + +// getDialers returns the list of dialers to use for connecting to the relay server. +func (c *Client) getDialers() []dialer.DialeFn { + if c.mtu > 0 && c.mtu > iface.DefaultMTU { + c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU) + return []dialer.DialeFn{ws.Dialer{}} + } + return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}} +} diff --git a/shared/relay/client/dialers_js.go b/shared/relay/client/dialers_js.go new file mode 100644 index 000000000..6bd0e6696 --- /dev/null +++ b/shared/relay/client/dialers_js.go @@ -0,0 +1,13 @@ +//go:build js + +package client + +import ( + "github.com/netbirdio/netbird/shared/relay/client/dialer" + "github.com/netbirdio/netbird/shared/relay/client/dialer/ws" +) + +func (c *Client) getDialers() []dialer.DialeFn { + // JS/WASM build only uses WebSocket transport + return []dialer.DialeFn{ws.Dialer{}} +} diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 1d76fa4e4..e2a69a75b 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -8,14 +8,16 @@ import ( "fmt" "net" "net/http" - // nolint:gosec _ "net/http/pprof" - "strings" + "net/netip" "time" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "go.opentelemetry.io/otel/metric" "golang.org/x/crypto/acme/autocert" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" "github.com/netbirdio/netbird/signal/metrics" @@ -23,6 +25,8 @@ import ( "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/server" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/wsproxy" + wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server" "github.com/netbirdio/netbird/version" log "github.com/sirupsen/logrus" @@ -32,6 +36,8 @@ import ( "google.golang.org/grpc/keepalive" ) +const legacyGRPCPort = 10000 + var ( signalPort int metricsPort int @@ -113,7 +119,7 @@ var ( } proto.RegisterSignalExchangeServer(grpcServer, srv) - grpcRootHandler := grpcHandlerFunc(grpcServer) + grpcRootHandler := grpcHandlerFunc(grpcServer, metricsServer.Meter) if certManager != nil { startServerWithCertManager(certManager, grpcRootHandler) @@ -123,19 +129,30 @@ var ( var grpcListener net.Listener var httpListener net.Listener - // If certManager is configured and signalPort == 443, then the gRPC server has already been started - if certManager == nil || signalPort != 443 { - grpcListener, err = serveGRPC(grpcServer, signalPort) + // Start the main server - always serve HTTP with WebSocket proxy support + // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager + if certManager == nil { + // Without TLS, serve plain HTTP + httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { return err } - log.Infof("running gRPC server: %s", grpcListener.Addr().String()) + log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) + serveHTTP(httpListener, grpcRootHandler) + } else if signalPort != 443 { + // With TLS but not on port 443, serve HTTPS + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + if err != nil { + return err + } + log.Infof("running HTTPS server with WebSocket proxy: %s", httpListener.Addr().String()) + serveHTTP(httpListener, grpcRootHandler) } - if signalPort != 10000 { + if signalPort != legacyGRPCPort { // The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal // are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000. - compatListener, err = serveGRPC(grpcServer, 10000) + compatListener, err = serveGRPC(grpcServer, legacyGRPCPort) if err != nil { return err } @@ -236,11 +253,14 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h } } -func grpcHandlerFunc(grpcServer *grpc.Server) http.Handler { +func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { + wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter)) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - grpcHeader := strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") || - strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc+proto") - if r.ProtoMajor == 2 && grpcHeader { + switch { + case r.URL.Path == wsproxy.ProxyPath: + wsProxy.Handler().ServeHTTP(w, r) + default: grpcServer.ServeHTTP(w, r) } }) @@ -257,7 +277,11 @@ func notifyStop(msg string) { func serveHTTP(httpListener net.Listener, handler http.Handler) { go func() { - err := http.Serve(httpListener, handler) + // Use h2c to support HTTP/2 without TLS (needed for gRPC) + h1s := &http.Server{ + Handler: h2c.NewHandler(handler, &http2.Server{}), + } + err := h1s.Serve(httpListener) if err != nil { notifyStop(fmt.Sprintf("failed running HTTP server %v", err)) } diff --git a/util/util_js.go b/util/util_js.go new file mode 100644 index 000000000..8c243cab3 --- /dev/null +++ b/util/util_js.go @@ -0,0 +1,8 @@ +//go:build js + +package util + +// IsAdmin returns false for WASM as there's no admin concept in browser +func IsAdmin() bool { + return false +} diff --git a/util/wsproxy/client/dialer_js.go b/util/wsproxy/client/dialer_js.go new file mode 100644 index 000000000..2caeed025 --- /dev/null +++ b/util/wsproxy/client/dialer_js.go @@ -0,0 +1,171 @@ +package client + +import ( + "context" + "fmt" + "net" + "sync" + "syscall/js" + "time" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/util/wsproxy" +) + +const dialTimeout = 30 * time.Second + +// websocketConn wraps a JavaScript WebSocket to implement net.Conn +type websocketConn struct { + ws js.Value + remoteAddr string + messages chan []byte + readBuf []byte + ctx context.Context + cancel context.CancelFunc + mu sync.Mutex +} + +func (c *websocketConn) Read(b []byte) (int, error) { + c.mu.Lock() + if len(c.readBuf) > 0 { + n := copy(b, c.readBuf) + c.readBuf = c.readBuf[n:] + c.mu.Unlock() + return n, nil + } + c.mu.Unlock() + + select { + case data := <-c.messages: + n := copy(b, data) + if n < len(data) { + c.mu.Lock() + c.readBuf = data[n:] + c.mu.Unlock() + } + return n, nil + case <-c.ctx.Done(): + return 0, c.ctx.Err() + } +} + +func (c *websocketConn) Write(b []byte) (int, error) { + select { + case <-c.ctx.Done(): + return 0, c.ctx.Err() + default: + } + + uint8Array := js.Global().Get("Uint8Array").New(len(b)) + js.CopyBytesToJS(uint8Array, b) + c.ws.Call("send", uint8Array) + return len(b), nil +} + +func (c *websocketConn) Close() error { + c.cancel() + c.ws.Call("close") + return nil +} + +func (c *websocketConn) LocalAddr() net.Addr { + return nil +} + +func (c *websocketConn) RemoteAddr() net.Addr { + return stringAddr(c.remoteAddr) +} +func (c *websocketConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *websocketConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *websocketConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// stringAddr is a simple net.Addr that returns a string +type stringAddr string + +func (s stringAddr) Network() string { return "tcp" } +func (s stringAddr) String() string { return string(s) } + +// WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments. +func WithWebSocketDialer(tlsEnabled bool) grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + scheme := "wss" + if !tlsEnabled { + scheme = "ws" + } + wsURL := fmt.Sprintf("%s://%s%s", scheme, addr, wsproxy.ProxyPath) + + ws := js.Global().Get("WebSocket").New(wsURL) + + connCtx, connCancel := context.WithCancel(context.Background()) + conn := &websocketConn{ + ws: ws, + remoteAddr: addr, + messages: make(chan []byte, 100), + ctx: connCtx, + cancel: connCancel, + } + + ws.Set("binaryType", "arraybuffer") + + openCh := make(chan struct{}) + errorCh := make(chan error, 1) + + ws.Set("onopen", js.FuncOf(func(this js.Value, args []js.Value) any { + close(openCh) + return nil + })) + + ws.Set("onerror", js.FuncOf(func(this js.Value, args []js.Value) any { + select { + case errorCh <- wsproxy.ErrConnectionFailed: + default: + } + return nil + })) + + ws.Set("onmessage", js.FuncOf(func(this js.Value, args []js.Value) any { + event := args[0] + data := event.Get("data") + + uint8Array := js.Global().Get("Uint8Array").New(data) + length := uint8Array.Get("length").Int() + bytes := make([]byte, length) + js.CopyBytesToGo(bytes, uint8Array) + + select { + case conn.messages <- bytes: + default: + log.Warnf("gRPC WebSocket message dropped for %s - buffer full", addr) + } + return nil + })) + + ws.Set("onclose", js.FuncOf(func(this js.Value, args []js.Value) any { + conn.cancel() + return nil + })) + + select { + case <-openCh: + return conn, nil + case err := <-errorCh: + return nil, err + case <-ctx.Done(): + ws.Call("close") + return nil, ctx.Err() + case <-time.After(dialTimeout): + ws.Call("close") + return nil, wsproxy.ErrConnectionTimeout + } + }) +} diff --git a/util/wsproxy/constants.go b/util/wsproxy/constants.go new file mode 100644 index 000000000..8d117c7d9 --- /dev/null +++ b/util/wsproxy/constants.go @@ -0,0 +1,13 @@ +package wsproxy + +import "errors" + +// ProxyPath is the standard path where the WebSocket proxy is mounted on servers. +const ProxyPath = "/ws-proxy" + +// Common errors +var ( + ErrConnectionTimeout = errors.New("WebSocket connection timeout") + ErrConnectionFailed = errors.New("WebSocket connection failed") + ErrBackendUnavailable = errors.New("backend unavailable") +) diff --git a/util/wsproxy/server/metrics.go b/util/wsproxy/server/metrics.go new file mode 100644 index 000000000..dd3b96dad --- /dev/null +++ b/util/wsproxy/server/metrics.go @@ -0,0 +1,118 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +// MetricsRecorder defines the interface for recording proxy metrics +type MetricsRecorder interface { + // RecordConnection records a new connection + RecordConnection(ctx context.Context) + // RecordDisconnection records a connection closing + RecordDisconnection(ctx context.Context) + // RecordBytesTransferred records bytes transferred in a direction + RecordBytesTransferred(ctx context.Context, direction string, bytes int64) + // RecordError records an error + RecordError(ctx context.Context, errorType string) +} + +// NoOpMetricsRecorder is a no-op implementation that does nothing +type NoOpMetricsRecorder struct{} + +func (n NoOpMetricsRecorder) RecordConnection(ctx context.Context) { + // no-op +} +func (n NoOpMetricsRecorder) RecordDisconnection(ctx context.Context) { + // no-op +} +func (n NoOpMetricsRecorder) RecordBytesTransferred(ctx context.Context, direction string, bytes int64) { + // no-op +} +func (n NoOpMetricsRecorder) RecordError(ctx context.Context, errorType string) { + // no-op +} + +// Recorder implements MetricsRecorder using OpenTelemetry +type Recorder struct { + activeConnections metric.Int64UpDownCounter + bytesTransferred metric.Int64Counter + errors metric.Int64Counter +} + +// NewMetricsRecorder creates a new OpenTelemetry-based metrics recorder +func NewMetricsRecorder(meter metric.Meter) (*Recorder, error) { + activeConnections, err := meter.Int64UpDownCounter( + "wsproxy_active_connections", + metric.WithDescription("Number of active WebSocket proxy connections"), + ) + if err != nil { + return nil, err + } + + bytesTransferred, err := meter.Int64Counter( + "wsproxy_bytes_transferred_total", + metric.WithDescription("Total bytes transferred through the proxy"), + ) + if err != nil { + return nil, err + } + + errors, err := meter.Int64Counter( + "wsproxy_errors_total", + metric.WithDescription("Total number of proxy errors"), + ) + if err != nil { + return nil, err + } + + return &Recorder{ + activeConnections: activeConnections, + bytesTransferred: bytesTransferred, + errors: errors, + }, nil +} + +func (o *Recorder) RecordConnection(ctx context.Context) { + o.activeConnections.Add(ctx, 1) +} + +func (o *Recorder) RecordDisconnection(ctx context.Context) { + o.activeConnections.Add(ctx, -1) +} + +func (o *Recorder) RecordBytesTransferred(ctx context.Context, direction string, bytes int64) { + o.bytesTransferred.Add(ctx, bytes, metric.WithAttributes( + attribute.String("direction", direction), + )) +} + +func (o *Recorder) RecordError(ctx context.Context, errorType string) { + o.errors.Add(ctx, 1, metric.WithAttributes( + attribute.String("error_type", errorType), + )) +} + +// Option defines functional options for the Proxy +type Option func(*Config) + +// WithMetrics sets a custom metrics recorder +func WithMetrics(recorder MetricsRecorder) Option { + return func(c *Config) { + c.MetricsRecorder = recorder + } +} + +// WithOTelMeter creates and sets an OpenTelemetry metrics recorder +func WithOTelMeter(meter metric.Meter) Option { + return func(c *Config) { + if recorder, err := NewMetricsRecorder(meter); err == nil { + c.MetricsRecorder = recorder + } else { + log.Warnf("Failed to create OTel metrics recorder: %v", err) + } + } +} diff --git a/util/wsproxy/server/proxy.go b/util/wsproxy/server/proxy.go new file mode 100644 index 000000000..977440a60 --- /dev/null +++ b/util/wsproxy/server/proxy.go @@ -0,0 +1,227 @@ +package server + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "net/netip" + "sync" + "time" + + "github.com/coder/websocket" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/util/wsproxy" +) + +const ( + dialTimeout = 10 * time.Second + bufferSize = 32 * 1024 +) + +// Config contains the configuration for the WebSocket proxy. +type Config struct { + LocalGRPCAddr netip.AddrPort + Path string + MetricsRecorder MetricsRecorder +} + +// Proxy handles WebSocket to TCP proxying for gRPC connections. +type Proxy struct { + config Config + metrics MetricsRecorder +} + +// New creates a new WebSocket proxy instance with optional configuration +func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy { + config := Config{ + LocalGRPCAddr: localGRPCAddr, + Path: wsproxy.ProxyPath, + MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op + } + + for _, opt := range opts { + opt(&config) + } + + return &Proxy{ + config: config, + metrics: config.MetricsRecorder, + } +} + +// Handler returns an http.Handler that proxies WebSocket connections to the local gRPC server. +func (p *Proxy) Handler() http.Handler { + return http.HandlerFunc(p.handleWebSocket) +} + +func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + p.metrics.RecordConnection(ctx) + defer p.metrics.RecordDisconnection(ctx) + + log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr) + acceptOptions := &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + } + + wsConn, err := websocket.Accept(w, r, acceptOptions) + if err != nil { + p.metrics.RecordError(ctx, "websocket_accept_failed") + log.Errorf("WebSocket upgrade failed from %s: %v", r.RemoteAddr, err) + return + } + defer func() { + if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil { + log.Debugf("Failed to close WebSocket: %v", err) + } + }() + + log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr) + tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout) + if err != nil { + p.metrics.RecordError(ctx, "tcp_dial_failed") + log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err) + if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil { + log.Debugf("Failed to close WebSocket after connection failure: %v", err) + } + return + } + defer func() { + if err := tcpConn.Close(); err != nil { + log.Debugf("Failed to close TCP connection: %v", err) + } + }() + + log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr) + + p.proxyData(ctx, wsConn, tcpConn) +} + +func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) { + proxyCtx, cancel := context.WithCancel(ctx) + defer cancel() + + var wg sync.WaitGroup + wg.Add(2) + + go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn) + go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn) + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + log.Tracef("Proxy data transfer completed, both goroutines terminated") + case <-proxyCtx.Done(): + log.Tracef("Proxy data transfer cancelled, forcing connection closure") + + if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil { + log.Tracef("Error closing WebSocket during cancellation: %v", err) + } + if err := tcpConn.Close(); err != nil { + log.Tracef("Error closing TCP connection during cancellation: %v", err) + } + + select { + case <-done: + log.Tracef("Goroutines terminated after forced connection closure") + case <-time.After(2 * time.Second): + log.Tracef("Goroutines did not terminate within timeout after connection closure") + } + } +} + +func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { + defer wg.Done() + defer cancel() + + for { + msgType, data, err := wsConn.Read(ctx) + if err != nil { + switch { + case ctx.Err() != nil: + log.Debugf("wsToTCP goroutine terminating due to context cancellation") + case websocket.CloseStatus(err) == websocket.StatusNormalClosure: + log.Debugf("WebSocket closed normally") + default: + p.metrics.RecordError(ctx, "websocket_read_error") + log.Errorf("WebSocket read error: %v", err) + } + return + } + + if msgType != websocket.MessageBinary { + log.Warnf("Unexpected WebSocket message type: %v", msgType) + continue + } + + if ctx.Err() != nil { + log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write") + return + } + + if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { + log.Debugf("Failed to set TCP write deadline: %v", err) + } + + n, err := tcpConn.Write(data) + if err != nil { + p.metrics.RecordError(ctx, "tcp_write_error") + log.Errorf("TCP write error: %v", err) + return + } + + p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n)) + } +} + +func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { + defer wg.Done() + defer cancel() + + buf := make([]byte, bufferSize) + for { + if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + log.Debugf("Failed to set TCP read deadline: %v", err) + } + n, err := tcpConn.Read(buf) + + if err != nil { + if ctx.Err() != nil { + log.Tracef("tcpToWS goroutine terminating due to context cancellation") + return + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + continue + } + + if err != io.EOF { + log.Errorf("TCP read error: %v", err) + } + return + } + + if ctx.Err() != nil { + log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write") + return + } + + if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { + p.metrics.RecordError(ctx, "websocket_write_error") + log.Errorf("WebSocket write error: %v", err) + return + } + + p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n)) + } +} From 4d7e59f199ce72e8b5d84f0953cf752181e3c923 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 2 Oct 2025 00:10:47 +0200 Subject: [PATCH 006/120] [client,signal,management] Adjust browser client ws proxy paths (#4565) --- client/grpc/dialer.go | 7 ++++--- client/grpc/dialer_generic.go | 2 +- client/grpc/dialer_js.go | 5 +++-- flow/client/client.go | 3 ++- management/internals/server/server.go | 2 +- shared/management/client/grpc.go | 3 ++- shared/signal/client/grpc.go | 3 ++- signal/cmd/run.go | 2 +- util/wsproxy/client/dialer_js.go | 5 +++-- util/wsproxy/constants.go | 9 ++++++++- 10 files changed, 27 insertions(+), 14 deletions(-) diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 7cb38fbff..54fbb002c 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -25,8 +25,9 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } -// CreateConnection creates a gRPC client connection with the appropriate transport options -func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) { +// CreateConnection creates a gRPC client connection with the appropriate transport options. +// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). +func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { certPool, err := x509.SystemCertPool() @@ -49,7 +50,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc. connCtx, addr, transportOption, - WithCustomDialer(tlsEnabled), + WithCustomDialer(tlsEnabled, component), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go index a0d6cee0b..96f347c64 100644 --- a/client/grpc/dialer_generic.go +++ b/client/grpc/dialer_generic.go @@ -18,7 +18,7 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -func WithCustomDialer(tlsEnabled bool) grpc.DialOption { +func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { if runtime.GOOS == "linux" { currentUser, err := user.Current() diff --git a/client/grpc/dialer_js.go b/client/grpc/dialer_js.go index e132c0098..b89ec3c21 100644 --- a/client/grpc/dialer_js.go +++ b/client/grpc/dialer_js.go @@ -7,6 +7,7 @@ import ( ) // WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments. -func WithCustomDialer(tlsEnabled bool) grpc.DialOption { - return client.WithWebSocketDialer(tlsEnabled) +// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). +func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { + return client.WithWebSocketDialer(tlsEnabled, component) } diff --git a/flow/client/client.go b/flow/client/client.go index 03a4accaf..318fcfe1e 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -23,6 +23,7 @@ import ( nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/util/embeddedroots" + "github.com/netbirdio/netbird/util/wsproxy" ) type GRPCClient struct { @@ -54,7 +55,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl } opts = append(opts, - nbgrpc.WithCustomDialer(tlsEnabled), + nbgrpc.WithCustomDialer(tlsEnabled, wsproxy.FlowComponent), grpc.WithIdleTimeout(interval*2), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/management/internals/server/server.go b/management/internals/server/server.go index ae9ac4a60..94c633fc6 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -259,7 +259,7 @@ func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Hand case request.ProtoMajor == 2 && (strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") || strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")): gRPCHandler.ServeHTTP(writer, request) - case request.URL.Path == wsproxy.ProxyPath: + case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent: wsProxy.Handler().ServeHTTP(writer, request) default: httpHandler.ServeHTTP(writer, request) diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index f30e965be..076f2532b 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/util/wsproxy" ) const ConnectTimeout = 10 * time.Second @@ -52,7 +53,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 5ca0c0282..31f3372c0 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/signal/proto" + "github.com/netbirdio/netbird/util/wsproxy" ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -57,7 +58,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/signal/cmd/run.go b/signal/cmd/run.go index e2a69a75b..696c44723 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -258,7 +258,7 @@ func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.URL.Path == wsproxy.ProxyPath: + case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent: wsProxy.Handler().ServeHTTP(w, r) default: grpcServer.ServeHTTP(w, r) diff --git a/util/wsproxy/client/dialer_js.go b/util/wsproxy/client/dialer_js.go index 2caeed025..bd50f51b5 100644 --- a/util/wsproxy/client/dialer_js.go +++ b/util/wsproxy/client/dialer_js.go @@ -96,13 +96,14 @@ func (s stringAddr) Network() string { return "tcp" } func (s stringAddr) String() string { return string(s) } // WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments. -func WithWebSocketDialer(tlsEnabled bool) grpc.DialOption { +// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). +func WithWebSocketDialer(tlsEnabled bool, component string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { scheme := "wss" if !tlsEnabled { scheme = "ws" } - wsURL := fmt.Sprintf("%s://%s%s", scheme, addr, wsproxy.ProxyPath) + wsURL := fmt.Sprintf("%s://%s%s%s", scheme, addr, wsproxy.ProxyPath, component) ws := js.Global().Get("WebSocket").New(wsURL) diff --git a/util/wsproxy/constants.go b/util/wsproxy/constants.go index 8d117c7d9..a31c0fbc8 100644 --- a/util/wsproxy/constants.go +++ b/util/wsproxy/constants.go @@ -2,9 +2,16 @@ package wsproxy import "errors" -// ProxyPath is the standard path where the WebSocket proxy is mounted on servers. +// ProxyPath is the base path where the WebSocket proxy is mounted on servers. const ProxyPath = "/ws-proxy" +// Component paths that are appended to ProxyPath +const ( + ManagementComponent = "/management" + SignalComponent = "/signal" + FlowComponent = "/flow" +) + // Common errors var ( ErrConnectionTimeout = errors.New("WebSocket connection timeout") From b85045e723b251e30d3f22e14da2bce817b7767e Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 1 Oct 2025 19:52:54 -0300 Subject: [PATCH 007/120] [misc] Update infra scripts with ws proxy for browser client (#4566) * Update infra scripts with ws proxy for browser client * add ws proxy to nginx tmpl --- infrastructure_files/docker-compose.yml.tmpl.traefik | 7 ++++++- infrastructure_files/getting-started-with-zitadel.sh | 2 ++ infrastructure_files/nginx.tmpl.conf | 8 ++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index 08749a4f7..fb01e6867 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -45,6 +45,9 @@ services: - $SIGNAL_VOLUMENAME:/var/lib/netbird labels: - traefik.enable=true + - traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`) + - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal + - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000 - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) - traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c @@ -87,7 +90,9 @@ services: - traefik.http.routers.netbird-api.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/api`) - traefik.http.routers.netbird-api.service=netbird-api - traefik.http.services.netbird-api.loadbalancer.server.port=33073 - + - traefik.http.routers.netbird-wsproxy-mgmt.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/management`) + - traefik.http.routers.netbird-wsproxy-mgmt.service=netbird-wsproxy-mgmt + - traefik.http.services.netbird-wsproxy-mgmt.loadbalancer.server.port=33073 - traefik.http.routers.netbird-management.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/management.ManagementService/`) - traefik.http.routers.netbird-management.service=netbird-management - traefik.http.services.netbird-management.loadbalancer.server.port=33073 diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index cfec1000e..be9662345 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -621,9 +621,11 @@ renderCaddyfile() { # relay reverse_proxy /relay* relay:80 # Signal + reverse_proxy /ws-proxy/signal* signal:10000 reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000 # Management reverse_proxy /api/* management:80 + reverse_proxy /ws-proxy/management* management:80 reverse_proxy /management.ManagementService/* h2c://management:80 # Zitadel reverse_proxy /zitadel.admin.v1.AdminService/* h2c://zitadel:8080 diff --git a/infrastructure_files/nginx.tmpl.conf b/infrastructure_files/nginx.tmpl.conf index f7fa4a9d0..fbd892c29 100644 --- a/infrastructure_files/nginx.tmpl.conf +++ b/infrastructure_files/nginx.tmpl.conf @@ -52,6 +52,10 @@ server { location / { proxy_pass http://dashboard; } + # Proxy Signal wsproxy endpoint + location /ws-proxy/signal { + proxy_pass http://signal; + } # Proxy Signal location /signalexchange.SignalExchange/ { grpc_pass grpc://signal; @@ -64,6 +68,10 @@ server { location /api { proxy_pass http://management; } + # Proxy Management wsproxy endpoint + location /ws-proxy/management { + proxy_pass http://management; + } # Proxy Management grpc endpoint location /management.ManagementService/ { grpc_pass grpc://management; From 9bcd3ebed4c6e9144b939a3ed6555e9d6a0f0be4 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 2 Oct 2025 06:02:10 +0700 Subject: [PATCH 008/120] [management,client] Make DNS ForwarderPort Configurable & Change Well Known Port (#4479) makes the DNS forwarder port configurable in the management and client components, while changing the well-known port from 5454 to 22054. The change includes version-aware port assignment to ensure backward compatibility. - Adds a configurable `ForwarderPort` field to the DNS configuration protocol - Implements version-based port computation that returns the new port (22054) only when all peers support version 0.59.0 or newer - Updates the client to dynamically restart the DNS forwarder when the port changes --- client/internal/dnsfwd/manager.go | 33 +- client/internal/engine.go | 35 +- client/internal/netflow/logger/logger.go | 2 +- .../routemanager/dnsinterceptor/handler.go | 4 +- go.mod | 2 +- management/server/dns.go | 44 ++- management/server/dns_test.go | 132 +++++++- management/server/grpcserver.go | 23 +- management/server/peer.go | 21 +- management/server/peer/peer.go | 11 +- management/server/peer_test.go | 3 +- shared/management/proto/management.pb.go | 304 +++++++++--------- shared/management/proto/management.proto | 1 + 13 files changed, 416 insertions(+), 199 deletions(-) diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index bf2ee839b..5c7a3fbdd 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "sync" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -11,14 +12,18 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" +) + +var ( + // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also + listenPort uint16 = 5353 + listenPortMu sync.RWMutex ) const ( - // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also - ListenPort = 5353 - dnsTTL = 60 //seconds + dnsTTL = 60 //seconds ) // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. @@ -35,12 +40,20 @@ type Manager struct { fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder + port uint16 } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { +func ListenPort() uint16 { + listenPortMu.RLock() + defer listenPortMu.RUnlock() + return listenPort +} + +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, + port: port, } } @@ -54,7 +67,13 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder) + if m.port > 0 { + listenPortMu.Lock() + listenPort = m.port + listenPortMu.Unlock() + } + + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -94,7 +113,7 @@ func (m *Manager) Stop(ctx context.Context) error { func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, - Values: []uint16{ListenPort}, + Values: []uint16{ListenPort()}, } if m.firewall == nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 3fa0b58a8..646e059d4 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -202,6 +202,9 @@ type Engine struct { // WireGuard interface monitor wgIfaceMonitor *WGIfaceMonitor wgIfaceMonitorWg sync.WaitGroup + + // dns forwarder port + dnsFwdPort uint16 } // Peer is an instance of the Connection Peer @@ -244,6 +247,7 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), + dnsFwdPort: dnsfwd.ListenPort(), } sm := profilemanager.NewServiceManager("") @@ -1080,7 +1084,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) - e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) // Ingress forward rules forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) @@ -1839,6 +1843,7 @@ func (e *Engine) GetWgAddr() netip.Addr { func (e *Engine) updateDNSForwarder( enabled bool, fwdEntries []*dnsfwd.ForwarderEntry, + forwarderPort uint16, ) { if e.config.DisableServerRoutes { return @@ -1855,16 +1860,20 @@ func (e *Engine) updateDNSForwarder( } if len(fwdEntries) > 0 { - if e.dnsForwardMgr == nil { - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - + switch { + case e.dnsForwardMgr == nil: + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil } - log.Infof("started domain router service with %d entries", len(fwdEntries)) - } else { + case e.dnsFwdPort != forwarderPort: + log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) + e.restartDnsFwd(fwdEntries, forwarderPort) + e.dnsFwdPort = forwarderPort + + default: e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { @@ -1874,6 +1883,20 @@ func (e *Engine) updateDNSForwarder( } e.dnsForwardMgr = nil } + +} + +func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) { + log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) + // stop and start the forwarder to apply the new port + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { + log.Errorf("failed to start DNS forward: %v", err) + e.dnsForwardMgr = nil + } } func (e *Engine) GetNet() (*netstack.Net, error) { diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index e28fdf2f4..899faf108 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -138,7 +138,7 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection - if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) { + if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { return false } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 9069cdcc5..47c2ffcda 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -24,8 +24,8 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) const dnsTimeout = 8 * time.Second @@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/go.mod b/go.mod index c4b629993..a1560b409 100644 --- a/go.mod +++ b/go.mod @@ -102,6 +102,7 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a + golang.org/x/mod v0.25.0 golang.org/x/net v0.42.0 golang.org/x/oauth2 v0.28.0 golang.org/x/sync v0.16.0 @@ -243,7 +244,6 @@ require ( go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect - golang.org/x/mod v0.25.0 // indirect golang.org/x/text v0.27.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.34.0 // indirect diff --git a/management/server/dns.go b/management/server/dns.go index 6b73dbd0e..534f43ec6 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -6,9 +6,11 @@ import ( "sync" log "github.com/sirupsen/logrus" + "golang.org/x/mod/semver" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" @@ -18,6 +20,13 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +const ( + dnsForwarderPort = 22054 + oldForwarderPort = 5353 +) + +const dnsForwarderPortMinVersion = "v0.59.0" + // DNSConfigCache is a thread-safe cache for DNS configuration components type DNSConfigCache struct { NameServerGroups sync.Map @@ -183,12 +192,45 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID return validateGroups(settings.DisabledManagementGroups, groups) } +// computeForwarderPort checks if all peers in the account have updated to a specific version or newer. +// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. +func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { + if len(peers) == 0 { + return oldForwarderPort + } + + reqVer := semver.Canonical(requiredVersion) + + // Check if all peers have the required version or newer + for _, peer := range peers { + + // Development version is always supported + if peer.Meta.WtVersion == "development" { + continue + } + peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) + if peerVersion == "" { + // If any peer doesn't have version info, return 0 + return oldForwarderPort + } + + // Compare versions + if semver.Compare(peerVersion, reqVer) < 0 { + return oldForwarderPort + } + } + + // All peers have the required version or newer + return dnsForwarderPort +} + // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache -func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { +func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig { protoUpdate := &proto.DNSConfig{ ServiceEnable: update.ServiceEnable, CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), + ForwarderPort: forwardPort, } for _, zone := range update.CustomZones { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index a2a2ce529..83caf74ef 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -21,7 +21,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/shared/management/status" @@ -324,13 +323,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account return nil, err } - account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{ + account.NameServerGroups[dnsNSGroup1] = &nbdns.NameServerGroup{ ID: dnsNSGroup1, Name: "ns-group-1", - NameServers: []dns.NameServer{{ + NameServers: []nbdns.NameServer{{ IP: netip.MustParseAddr(savedPeer1.IP.String()), - NSType: dns.UDPNameServerType, - Port: dns.DefaultDNSPort, + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, }}, Primary: true, Enabled: true, @@ -395,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache) + toProtocolDNSConfig(testData, cache, dnsForwarderPort) } }) @@ -403,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache) + toProtocolDNSConfig(testData, cache, dnsForwarderPort) } }) } @@ -456,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache) + result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache) + result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort) // Third run with config1 again - result3 := toProtocolDNSConfig(config1, &cache) + result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) // Verify that result1 and result3 are identical if !reflect.DeepEqual(result1, result3) { @@ -483,6 +482,107 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } } +func TestComputeForwarderPort(t *testing.T) { + // Test with empty peers list + peers := []*nbpeer.Peer{} + result := computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) + } + + // Test with peers that have old versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.26.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) + } + + // Test with peers that have new versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != dnsForwarderPort { + t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) + } + + // Test with peers that have mixed versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) + } + + // Test with peers that have empty version + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) + } + + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "development", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result == oldForwarderPort { + t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) + } + + // Test with peers that have unknown version string + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "unknown", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) + } +} + func TestDNSAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) @@ -534,10 +634,10 @@ func TestDNSAccountPeersUpdate(t *testing.T) { }() _, err = manager.CreateNameServerGroup( - context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{ + context.Background(), account.Id, "ns-group", "ns-group", []nbdns.NameServer{{ IP: netip.MustParseAddr(peer1.IP.String()), - NSType: dns.UDPNameServerType, - Port: dns.DefaultDNSPort, + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, }}, []string{"groupB"}, true, []string{}, true, userID, false, @@ -567,10 +667,10 @@ func TestDNSAccountPeersUpdate(t *testing.T) { }() _, err = manager.CreateNameServerGroup( - context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ + context.Background(), account.Id, "ns-group-1", "ns-group-1", []nbdns.NameServer{{ IP: netip.MustParseAddr(peer1.IP.String()), - NSType: dns.UDPNameServerType, - Port: dns.DefaultDNSPort, + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, }}, []string{"groupA"}, true, []string{}, true, userID, false, diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 1177eefff..12b59b691 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -715,13 +715,13 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } } -func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse { +func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { response := &proto.SyncResponse{ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), Routes: toProtocolRoutes(networkMap.Routes), - DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache), + DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), }, Checks: toProtocolChecks(ctx, checks), } @@ -732,11 +732,11 @@ func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.P response.NetworkMap.PeerConfig = response.PeerConfig - allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) - allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName) - response.RemotePeers = allPeers - response.NetworkMap.RemotePeers = allPeers - response.RemotePeersIsEmpty = len(allPeers) == 0 + remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) + remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) + response.RemotePeers = remotePeers + response.NetworkMap.RemotePeers = remotePeers + response.RemotePeersIsEmpty = len(remotePeers) == 0 response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) @@ -808,7 +808,14 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "failed to get peer groups %s", err) } - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups) + // Get all peers in the account for forwarder port computation + allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "") + if err != nil { + return fmt.Errorf("get account peers: %w", err) + } + dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion) + + plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/peer.go b/management/server/peer.go index ea4617af0..4cf5d1e46 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -729,7 +729,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy var peer *nbpeer.Peer var peerNotValid bool var isStatusChanged bool - var updated bool + var updated, versionChanged bool var err error var postureChecks []*posture.Checks @@ -769,7 +769,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return err } - updated = peer.UpdateMetaIfNew(sync.Meta) + updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta) if updated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) @@ -788,7 +788,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return nil, nil, nil, err } - if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) { + if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -880,7 +880,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return err } - isPeerUpdated = peer.UpdateMetaIfNew(login.Meta) + isPeerUpdated, _ = peer.UpdateMetaIfNew(login.Meta) if isPeerUpdated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() shouldStorePeer = true @@ -1229,6 +1229,8 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account return } + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) + for _, peer := range account.Peers { if !am.peersUpdateManager.HasChannel(peer.ID) { log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) @@ -1265,7 +1267,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account peerGroups := account.GetPeerGroups(p.ID) start = time.Now() - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups)) + update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) @@ -1376,7 +1378,9 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI } peerGroups := account.GetPeerGroups(peerId) - update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups)) + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) + + update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) } @@ -1549,6 +1553,8 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return nil, err } + dnsFwdPort := computeForwarderPort(peers, dnsForwarderPortMinVersion) + for _, peer := range peers { if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { return nil, fmt.Errorf("failed to remove peer %s from groups", peer.ID) @@ -1592,6 +1598,9 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto RemotePeersIsEmpty: true, FirewallRules: []*proto.FirewallRule{}, FirewallRulesIsEmpty: true, + DNSConfig: &proto.DNSConfig{ + ForwarderPort: dnsFwdPort, + }, }, }, NetworkMap: &types.NetworkMap{}, diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index f89f10dac..a898fd782 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -233,21 +233,24 @@ func (p *Peer) Copy() *Peer { // UpdateMetaIfNew updates peer's system metadata if new information is provided // returns true if meta was updated, false otherwise -func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool { +func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) { if meta.isEmpty() { - return false + return updated, versionChanged } + versionChanged = p.Meta.WtVersion != meta.WtVersion + // Avoid overwriting UIVersion if the update was triggered sole by the CLI client if meta.UIVersion == "" { meta.UIVersion = p.Meta.UIVersion } if p.Meta.isEqual(meta) { - return false + return updated, versionChanged } p.Meta = meta - return true + updated = true + return updated, versionChanged } // GetLastLogin returns the last login time of the peer. diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 734536d7b..42b3244ae 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1161,7 +1161,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort) assert.NotNil(t, response) // assert peer config @@ -1212,6 +1212,7 @@ func TestToSyncResponse(t *testing.T) { assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID) // assert network map DNSConfig assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable) + assert.Equal(t, int64(dnsForwarderPort), response.NetworkMap.DNSConfig.ForwarderPort) assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones)) assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups)) // assert network map DNSConfig.CustomZones diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 8381d6682..0de00ec0c 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -7,12 +7,13 @@ package proto import ( + reflect "reflect" + sync "sync" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" durationpb "google.golang.org/protobuf/types/known/durationpb" timestamppb "google.golang.org/protobuf/types/known/timestamppb" - reflect "reflect" - sync "sync" ) const ( @@ -2491,6 +2492,7 @@ type DNSConfig struct { ServiceEnable bool `protobuf:"varint,1,opt,name=ServiceEnable,proto3" json:"ServiceEnable,omitempty"` NameServerGroups []*NameServerGroup `protobuf:"bytes,2,rep,name=NameServerGroups,proto3" json:"NameServerGroups,omitempty"` CustomZones []*CustomZone `protobuf:"bytes,3,rep,name=CustomZones,proto3" json:"CustomZones,omitempty"` + ForwarderPort int64 `protobuf:"varint,4,opt,name=ForwarderPort,proto3" json:"ForwarderPort,omitempty"` } func (x *DNSConfig) Reset() { @@ -2546,6 +2548,13 @@ func (x *DNSConfig) GetCustomZones() []*CustomZone { return nil } +func (x *DNSConfig) GetForwarderPort() int64 { + if x != nil { + return x.ForwarderPort + } + return 0 +} + // CustomZone represents a dns.CustomZone type CustomZone struct { state protoimpl.MessageState @@ -3721,7 +3730,7 @@ var file_management_proto_rawDesc = []byte{ 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, - 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, + 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xda, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, @@ -3732,157 +3741,160 @@ var file_management_proto_rawDesc = []byte{ 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, - 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, - 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, - 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, - 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, - 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, - 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, - 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, - 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, - 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, - 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, - 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, - 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, - 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, - 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, - 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, - 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, - 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, - 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, - 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, - 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, - 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, - 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, - 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, - 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, - 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, - 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, - 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, - 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, - 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, - 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, - 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, - 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, - 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, - 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, - 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, - 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, - 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, - 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, - 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, - 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, - 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, - 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, - 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, - 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, - 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, - 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, - 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, - 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, - 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, - 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, - 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, - 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, - 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, - 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, - 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, - 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, - 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, - 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, - 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x24, + 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, + 0x50, 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, + 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, + 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, + 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, + 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, + 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, + 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, + 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, + 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, + 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, + 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, + 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, + 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, + 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, + 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, + 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, + 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, + 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, + 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, + 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, + 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, + 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, + 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, + 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, + 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, + 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, + 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, + 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, + 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, + 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, + 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, + 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, + 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, + 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, + 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, + 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, + 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, + 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, + 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, + 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, + 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, + 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, + 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, + 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, + 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, + 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, + 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, + 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, + 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, + 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, - 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, - 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, + 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, + 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, + 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, + 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, - 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, + 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, - 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, - 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, + 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, + 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, + 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index dcdd387b4..ad82d37d9 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -410,6 +410,7 @@ message DNSConfig { bool ServiceEnable = 1; repeated NameServerGroup NameServerGroups = 2; repeated CustomZone CustomZones = 3; + int64 ForwarderPort = 4; } // CustomZone represents a dns.CustomZone From 95794f53ce9454f18f73f50c67223ece2b50e68d Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 2 Oct 2025 17:42:25 +0700 Subject: [PATCH 009/120] [client] fix Windows NRPT Policy Path (#4572) [client] fix Windows NRPT Policy Path --- client/internal/dns/host_windows.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 0d3f033fb..a14a01f40 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -240,15 +240,17 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 for i, domain := range domains { + localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) + gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i) singleDomain := []string{domain} - if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, singleDomain, ip); err != nil { + if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil { return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) } if r.gpo { - if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, singleDomain, ip); err != nil { + if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil { return i, fmt.Errorf("configure gpo DNS policy: %w", err) } } From e7b5537dcc280384470668f461bbb1f7d2f41218 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 2 Oct 2025 13:51:39 +0200 Subject: [PATCH 010/120] Add websocket paths including relay to nginx template (#4573) --- infrastructure_files/nginx.tmpl.conf | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/infrastructure_files/nginx.tmpl.conf b/infrastructure_files/nginx.tmpl.conf index fbd892c29..46cb195e7 100644 --- a/infrastructure_files/nginx.tmpl.conf +++ b/infrastructure_files/nginx.tmpl.conf @@ -20,6 +20,10 @@ upstream management { # insert the grpc+http port of your management container here server 127.0.0.1:8012; } +upstream relay { + # insert the port of your relay container here + server 127.0.0.1:33080; +} server { # HTTP server config @@ -55,6 +59,10 @@ server { # Proxy Signal wsproxy endpoint location /ws-proxy/signal { proxy_pass http://signal; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; + proxy_set_header Host $host; } # Proxy Signal location /signalexchange.SignalExchange/ { @@ -71,6 +79,10 @@ server { # Proxy Management wsproxy endpoint location /ws-proxy/management { proxy_pass http://management; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; + proxy_set_header Host $host; } # Proxy Management grpc endpoint location /management.ManagementService/ { @@ -80,6 +92,14 @@ server { grpc_send_timeout 1d; grpc_socket_keepalive on; } + # Proxy Relay + location /relay { + proxy_pass http://relay; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; + proxy_set_header Host $host; + } ssl_certificate /etc/ssl/certs/ssl-cert-snakeoil.pem; ssl_certificate_key /etc/ssl/certs/ssl-cert-snakeoil.pem; From 34341d95a930a252588933eada70f92c296d756f Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 6 Oct 2025 15:22:02 -0300 Subject: [PATCH 011/120] Adjust signal port for websocket connections (#4594) --- infrastructure_files/docker-compose.yml.tmpl.traefik | 2 +- infrastructure_files/getting-started-with-zitadel.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index fb01e6867..196e26a66 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -47,7 +47,7 @@ services: - traefik.enable=true - traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`) - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal - - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000 + - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80 - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) - traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index be9662345..bc326cd7e 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -621,7 +621,7 @@ renderCaddyfile() { # relay reverse_proxy /relay* relay:80 # Signal - reverse_proxy /ws-proxy/signal* signal:10000 + reverse_proxy /ws-proxy/signal* signal:80 reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000 # Management reverse_proxy /api/* management:80 From 954f40991f3b6c399e6ce1cb20577aacca87d38e Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 6 Oct 2025 21:22:19 +0200 Subject: [PATCH 012/120] [client,management,signal] Handle grpc from ws proxy internally instead of via tcp (#4593) --- client/grpc/dialer.go | 7 +- management/internals/server/server.go | 3 +- signal/cmd/run.go | 9 +- util/wsproxy/server/proxy.go | 140 ++++++++++---------------- 4 files changed, 62 insertions(+), 97 deletions(-) diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 54fbb002c..6aff53b92 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -29,7 +29,8 @@ func Backoff(ctx context.Context) backoff.BackOff { // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) - if tlsEnabled { + // for js, the outer websocket layer takes care of tls + if tlsEnabled && runtime.GOOS != "js" { certPool, err := x509.SystemCertPool() if err != nil || certPool == nil { log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) @@ -37,9 +38,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone } transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ - // for js, outer websocket layer takes care of tls verification via WithCustomDialer - InsecureSkipVerify: runtime.GOOS == "js", - RootCAs: certPool, + RootCAs: certPool, })) } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 94c633fc6..ab1c2ebe7 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "net/http" - "net/netip" "strings" "sync" "time" @@ -252,7 +251,7 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) } func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { - wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter)) + wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter)) return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { switch { diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 696c44723..96873dee7 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -10,7 +10,6 @@ import ( "net/http" // nolint:gosec _ "net/http/pprof" - "net/netip" "time" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -63,10 +62,10 @@ var ( Use: "run", Short: "start NetBird Signal Server daemon", SilenceUsage: true, - PreRun: func(cmd *cobra.Command, args []string) { + PreRunE: func(cmd *cobra.Command, args []string) error { err := util.InitLog(logLevel, logFile) if err != nil { - log.Fatalf("failed initializing log %v", err) + return fmt.Errorf("failed initializing log: %w", err) } flag.Parse() @@ -87,6 +86,8 @@ var ( signalPort = 80 } } + + return nil }, RunE: func(cmd *cobra.Command, args []string) error { flag.Parse() @@ -254,7 +255,7 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h } func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { - wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter)) + wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { diff --git a/util/wsproxy/server/proxy.go b/util/wsproxy/server/proxy.go index 977440a60..8866924df 100644 --- a/util/wsproxy/server/proxy.go +++ b/util/wsproxy/server/proxy.go @@ -2,42 +2,41 @@ package server import ( "context" - "errors" "io" "net" "net/http" - "net/netip" "sync" "time" "github.com/coder/websocket" log "github.com/sirupsen/logrus" + "golang.org/x/net/http2" "github.com/netbirdio/netbird/util/wsproxy" ) const ( - dialTimeout = 10 * time.Second - bufferSize = 32 * 1024 + bufferSize = 32 * 1024 + ioTimeout = 5 * time.Second ) // Config contains the configuration for the WebSocket proxy. type Config struct { - LocalGRPCAddr netip.AddrPort + Handler http.Handler Path string MetricsRecorder MetricsRecorder } -// Proxy handles WebSocket to TCP proxying for gRPC connections. +// Proxy handles WebSocket to gRPC handler proxying. type Proxy struct { config Config metrics MetricsRecorder } // New creates a new WebSocket proxy instance with optional configuration -func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy { +func New(handler http.Handler, opts ...Option) *Proxy { config := Config{ - LocalGRPCAddr: localGRPCAddr, + Handler: handler, Path: wsproxy.ProxyPath, MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op } @@ -63,7 +62,7 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) { p.metrics.RecordConnection(ctx) defer p.metrics.RecordDisconnection(ctx) - log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr) + log.Debugf("WebSocket proxy handling connection from %s, forwarding to internal gRPC handler", r.RemoteAddr) acceptOptions := &websocket.AcceptOptions{ OriginPatterns: []string{"*"}, } @@ -75,71 +74,41 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) { return } defer func() { - if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil { - log.Debugf("Failed to close WebSocket: %v", err) - } + _ = wsConn.Close(websocket.StatusNormalClosure, "") }() - log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr) - tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout) - if err != nil { - p.metrics.RecordError(ctx, "tcp_dial_failed") - log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err) - if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil { - log.Debugf("Failed to close WebSocket after connection failure: %v", err) - } - return - } + clientConn, serverConn := net.Pipe() defer func() { - if err := tcpConn.Close(); err != nil { - log.Debugf("Failed to close TCP connection: %v", err) - } + _ = clientConn.Close() + _ = serverConn.Close() }() - log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr) + log.Debugf("WebSocket proxy established: %s -> gRPC handler", r.RemoteAddr) - p.proxyData(ctx, wsConn, tcpConn) + go func() { + (&http2.Server{}).ServeConn(serverConn, &http2.ServeConnOpts{ + Context: ctx, + Handler: p.config.Handler, + }) + }() + + p.proxyData(ctx, wsConn, clientConn, r.RemoteAddr) } -func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) { +func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) { proxyCtx, cancel := context.WithCancel(ctx) defer cancel() var wg sync.WaitGroup wg.Add(2) - go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn) - go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn) + go p.wsToPipe(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr) + go p.pipeToWS(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr) - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - - select { - case <-done: - log.Tracef("Proxy data transfer completed, both goroutines terminated") - case <-proxyCtx.Done(): - log.Tracef("Proxy data transfer cancelled, forcing connection closure") - - if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil { - log.Tracef("Error closing WebSocket during cancellation: %v", err) - } - if err := tcpConn.Close(); err != nil { - log.Tracef("Error closing TCP connection during cancellation: %v", err) - } - - select { - case <-done: - log.Tracef("Goroutines terminated after forced connection closure") - case <-time.After(2 * time.Second): - log.Tracef("Goroutines did not terminate within timeout after connection closure") - } - } + wg.Wait() } -func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { +func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) { defer wg.Done() defer cancel() @@ -148,80 +117,77 @@ func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync if err != nil { switch { case ctx.Err() != nil: - log.Debugf("wsToTCP goroutine terminating due to context cancellation") - case websocket.CloseStatus(err) == websocket.StatusNormalClosure: - log.Debugf("WebSocket closed normally") + log.Debugf("WebSocket from %s terminating due to context cancellation", clientAddr) + case websocket.CloseStatus(err) != -1: + log.Debugf("WebSocket from %s disconnected", clientAddr) default: p.metrics.RecordError(ctx, "websocket_read_error") - log.Errorf("WebSocket read error: %v", err) + log.Debugf("WebSocket read error from %s: %v", clientAddr, err) } return } if msgType != websocket.MessageBinary { - log.Warnf("Unexpected WebSocket message type: %v", msgType) + log.Warnf("Unexpected WebSocket message type from %s: %v", clientAddr, msgType) continue } if ctx.Err() != nil { - log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write") + log.Tracef("wsToPipe goroutine terminating due to context cancellation before pipe write") return } - if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { - log.Debugf("Failed to set TCP write deadline: %v", err) + if err := pipeConn.SetWriteDeadline(time.Now().Add(ioTimeout)); err != nil { + log.Debugf("Failed to set pipe write deadline: %v", err) } - n, err := tcpConn.Write(data) + n, err := pipeConn.Write(data) if err != nil { - p.metrics.RecordError(ctx, "tcp_write_error") - log.Errorf("TCP write error: %v", err) + p.metrics.RecordError(ctx, "pipe_write_error") + log.Warnf("Pipe write error for %s: %v", clientAddr, err) return } - p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n)) + p.metrics.RecordBytesTransferred(ctx, "ws_to_grpc", int64(n)) } } -func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { +func (p *Proxy) pipeToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) { defer wg.Done() defer cancel() buf := make([]byte, bufferSize) for { - if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { - log.Debugf("Failed to set TCP read deadline: %v", err) + if err := pipeConn.SetReadDeadline(time.Now().Add(ioTimeout)); err != nil { + log.Debugf("Failed to set pipe read deadline: %v", err) } - n, err := tcpConn.Read(buf) + n, err := pipeConn.Read(buf) if err != nil { if ctx.Err() != nil { - log.Tracef("tcpToWS goroutine terminating due to context cancellation") + log.Tracef("pipeToWS goroutine terminating due to context cancellation") return } - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - continue - } - if err != io.EOF { - log.Errorf("TCP read error: %v", err) + log.Debugf("Pipe read error for %s: %v", clientAddr, err) } return } if ctx.Err() != nil { - log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write") + log.Tracef("pipeToWS goroutine terminating due to context cancellation before WebSocket write") return } - if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { - p.metrics.RecordError(ctx, "websocket_write_error") - log.Errorf("WebSocket write error: %v", err) - return - } + if n > 0 { + if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { + p.metrics.RecordError(ctx, "websocket_write_error") + log.Warnf("WebSocket write error for %s: %v", clientAddr, err) + return + } - p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n)) + p.metrics.RecordBytesTransferred(ctx, "grpc_to_ws", int64(n)) + } } } From 88467883fc4a14607393beaac412f573eaaa1d43 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 6 Oct 2025 22:05:48 +0200 Subject: [PATCH 013/120] [management,signal] Remove ws-proxy read deadline (#4598) --- util/wsproxy/server/proxy.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/util/wsproxy/server/proxy.go b/util/wsproxy/server/proxy.go index 8866924df..ffb622200 100644 --- a/util/wsproxy/server/proxy.go +++ b/util/wsproxy/server/proxy.go @@ -158,10 +158,6 @@ func (p *Proxy) pipeToWS(ctx context.Context, cancel context.CancelFunc, wg *syn buf := make([]byte, bufferSize) for { - if err := pipeConn.SetReadDeadline(time.Now().Add(ioTimeout)); err != nil { - log.Debugf("Failed to set pipe read deadline: %v", err) - } - n, err := pipeConn.Read(buf) if err != nil { if ctx.Err() != nil { From 4d33567888fe704fb75f0c0a74ed8a27fd131052 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 8 Oct 2025 03:12:16 +0200 Subject: [PATCH 014/120] [client] Remove endpoint address on peer disconnect, retain status for activity recording (#4228) * When a peer disconnects, remove the endpoint address to avoid sending traffic to a non-existent address, but retain the status for the activity recorder. --- client/iface/configurer/kernel_unix.go | 38 ++++++++++++++++ client/iface/configurer/usp.go | 61 ++++++++++++++++++++++++++ client/iface/device/interface.go | 1 + client/iface/iface.go | 11 +++++ client/internal/engine_test.go | 4 ++ client/internal/iface_common.go | 1 + client/internal/peer/conn.go | 6 +++ client/internal/peer/iface.go | 1 + 8 files changed, 123 insertions(+) diff --git a/client/iface/configurer/kernel_unix.go b/client/iface/configurer/kernel_unix.go index 84afc38f5..96b286175 100644 --- a/client/iface/configurer/kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, return nil } +func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error { + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + // Get the existing peer to preserve its allowed IPs + existingPeer, err := c.getPeer(c.deviceName, peerKey) + if err != nil { + return fmt.Errorf("get peer: %w", err) + } + + removePeerCfg := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil { + return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err) + } + + //Re-add the peer without the endpoint but same AllowedIPs + reAddPeerCfg := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + AllowedIPs: existingPeer.AllowedIPs, + ReplaceAllowedIPs: true, + } + + if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil { + return fmt.Errorf( + `error re-adding peer %s to interface %s with allowed IPs %v: %w`, + peerKey, c.deviceName, existingPeer.AllowedIPs, err, + ) + } + + return nil +} + func (c *KernelConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index f744e0127..bc875b73c 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, return nil } +func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error { + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return fmt.Errorf("parse peer key: %w", err) + } + + ipcStr, err := c.device.IpcGet() + if err != nil { + return fmt.Errorf("get IPC config: %w", err) + } + + // Parse current status to get allowed IPs for the peer + stats, err := parseStatus(c.deviceName, ipcStr) + if err != nil { + return fmt.Errorf("parse IPC config: %w", err) + } + + var allowedIPs []net.IPNet + found := false + for _, peer := range stats.Peers { + if peer.PublicKey == peerKey { + allowedIPs = peer.AllowedIPs + found = true + break + } + } + if !found { + return fmt.Errorf("peer %s not found", peerKey) + } + + // remove the peer from the WireGuard configuration + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil { + return fmt.Errorf("failed to remove peer: %s", ipcErr) + } + + // Build the peer config + peer = wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + ReplaceAllowedIPs: true, + AllowedIPs: allowedIPs, + } + + config = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + + if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil { + return fmt.Errorf("remove endpoint address: %w", err) + } + + return nil +} + func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go index 1f40b0d46..db53d9c3a 100644 --- a/client/iface/device/interface.go +++ b/client/iface/device/interface.go @@ -21,4 +21,5 @@ type WGConfigurer interface { GetStats() (map[string]configurer.WGStats, error) FullStats() (*configurer.Stats, error) LastActivities() map[string]monotime.Time + RemoveEndpointAddress(peerKey string) error } diff --git a/client/iface/iface.go b/client/iface/iface.go index 609572561..158672160 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -148,6 +148,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) } +func (w *WGIface) RemoveEndpointAddress(peerKey string) error { + w.mu.Lock() + defer w.mu.Unlock() + if w.configurer == nil { + return ErrIfaceNotFound + } + + log.Debugf("Removing endpoint address: %s", peerKey) + return w.configurer.RemoveEndpointAddress(peerKey) +} + // RemovePeer removes a Wireguard Peer from the interface iface func (w *WGIface) RemovePeer(peerKey string) error { w.mu.Lock() diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 344104405..2f1098100 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -105,6 +105,10 @@ type MockWGIface struct { LastActivitiesFunc func() map[string]monotime.Time } +func (m *MockWGIface) RemoveEndpointAddress(_ string) error { + return nil +} + func (m *MockWGIface) FullStats() (*configurer.Stats, error) { return nil, fmt.Errorf("not implemented") } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 690fdb7cc..98fe01912 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -28,6 +28,7 @@ type wgIfaceBase interface { UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemoveEndpointAddress(key string) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 8db9e58f4..ded9aa479 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -430,6 +430,9 @@ func (conn *Conn) onICEStateDisconnected() { } else { conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.currentConnPriority = conntype.None + if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil { + conn.Log.Errorf("failed to remove wg endpoint: %v", err) + } } changed := conn.statusICE.Get() != worker.StatusDisconnected @@ -523,6 +526,9 @@ func (conn *Conn) onRelayDisconnected() { if conn.currentConnPriority == conntype.Relay { conn.Log.Debugf("clean up WireGuard config") conn.currentConnPriority = conntype.None + if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil { + conn.Log.Errorf("failed to remove wg endpoint: %v", err) + } } if conn.wgProxyRelay != nil { diff --git a/client/internal/peer/iface.go b/client/internal/peer/iface.go index 0bcc7a68e..678396e61 100644 --- a/client/internal/peer/iface.go +++ b/client/internal/peer/iface.go @@ -18,4 +18,5 @@ type WGIface interface { GetStats() (map[string]configurer.WGStats, error) GetProxy() wgproxy.Proxy Address() wgaddr.Address + RemoveEndpointAddress(key string) error } From 229c65ffa1d783c0f1b56070c0718897f69072b0 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Wed, 8 Oct 2025 17:42:15 +0700 Subject: [PATCH 015/120] Enhance showLoginURL to include connection status check and auto-close functionality (#4525) --- client/ui/client_ui.go | 42 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 25d7380a9..7c2000a9d 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -1354,7 +1354,13 @@ func (s *serviceClient) updateConfig() error { } // showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL. -func (s *serviceClient) showLoginURL() { +// It also starts a background goroutine that periodically checks if the client is already connected +// and closes the window if so. The goroutine can be cancelled by the returned CancelFunc, and it is +// also cancelled when the window is closed. +func (s *serviceClient) showLoginURL() context.CancelFunc { + + // create a cancellable context for the background check goroutine + ctx, cancel := context.WithCancel(s.ctx) resIcon := fyne.NewStaticResource("netbird.png", iconAbout) @@ -1363,6 +1369,8 @@ func (s *serviceClient) showLoginURL() { s.wLoginURL.Resize(fyne.NewSize(400, 200)) s.wLoginURL.SetIcon(resIcon) } + // ensure goroutine is cancelled when the window is closed + s.wLoginURL.SetOnClosed(func() { cancel() }) // add a description label label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.") @@ -1443,7 +1451,39 @@ func (s *serviceClient) showLoginURL() { ) s.wLoginURL.SetContent(container.NewCenter(content)) + // start a goroutine to check connection status and close the window if connected + go func() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + return + } + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + if err != nil { + continue + } + if status.Status == string(internal.StatusConnected) { + if s.wLoginURL != nil { + s.wLoginURL.Close() + } + return + } + } + } + }() + s.wLoginURL.Show() + + // return cancel func so callers can stop the background goroutine if desired + return cancel } func openURL(url string) error { From 768332820e1e6e211325332a1350b0b2a34429b3 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Wed, 8 Oct 2025 21:54:27 +0700 Subject: [PATCH 016/120] [client] Implement DNS query caching in DNSForwarder (#4574) implements DNS query caching in the DNSForwarder to improve performance and provide fallback responses when upstream DNS servers fail. The cache stores successful DNS query results and serves them when upstream resolution fails. - Added a new cache component to store DNS query results by domain and query type - Integrated cache storage after successful DNS resolutions - Enhanced error handling to serve cached responses as fallback when upstream DNS fails --- client/internal/dnsfwd/cache.go | 78 +++++++++++++++++ client/internal/dnsfwd/cache_test.go | 86 +++++++++++++++++++ client/internal/dnsfwd/forwarder.go | 104 +++++++++++++++++++---- client/internal/dnsfwd/forwarder_test.go | 89 +++++++++++++++++++ 4 files changed, 341 insertions(+), 16 deletions(-) create mode 100644 client/internal/dnsfwd/cache.go create mode 100644 client/internal/dnsfwd/cache_test.go diff --git a/client/internal/dnsfwd/cache.go b/client/internal/dnsfwd/cache.go new file mode 100644 index 000000000..43fe2d020 --- /dev/null +++ b/client/internal/dnsfwd/cache.go @@ -0,0 +1,78 @@ +package dnsfwd + +import ( + "net/netip" + "slices" + "strings" + "sync" + + "github.com/miekg/dns" +) + +type cache struct { + mu sync.RWMutex + records map[string]*cacheEntry +} + +type cacheEntry struct { + ip4Addrs []netip.Addr + ip6Addrs []netip.Addr +} + +func newCache() *cache { + return &cache{ + records: make(map[string]*cacheEntry), + } +} + +func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + entry, exists := c.records[normalizeDomain(domain)] + if !exists { + return nil, false + } + + switch reqType { + case dns.TypeA: + return slices.Clone(entry.ip4Addrs), true + case dns.TypeAAAA: + return slices.Clone(entry.ip6Addrs), true + default: + return nil, false + } + +} + +func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) { + c.mu.Lock() + defer c.mu.Unlock() + norm := normalizeDomain(domain) + entry, exists := c.records[norm] + if !exists { + entry = &cacheEntry{} + c.records[norm] = entry + } + + switch reqType { + case dns.TypeA: + entry.ip4Addrs = slices.Clone(addrs) + case dns.TypeAAAA: + entry.ip6Addrs = slices.Clone(addrs) + } +} + +// unset removes cached entries for the given domain and request type. +func (c *cache) unset(domain string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.records, normalizeDomain(domain)) +} + +// normalizeDomain converts an input domain into a canonical form used as cache key: +// lowercase and fully-qualified (with trailing dot). +func normalizeDomain(domain string) string { + // dns.Fqdn ensures trailing dot; ToLower for consistent casing + return dns.Fqdn(strings.ToLower(domain)) +} diff --git a/client/internal/dnsfwd/cache_test.go b/client/internal/dnsfwd/cache_test.go new file mode 100644 index 000000000..c23f0f31d --- /dev/null +++ b/client/internal/dnsfwd/cache_test.go @@ -0,0 +1,86 @@ +package dnsfwd + +import ( + "net/netip" + "testing" +) + +func mustAddr(t *testing.T, s string) netip.Addr { + t.Helper() + a, err := netip.ParseAddr(s) + if err != nil { + t.Fatalf("parse addr %s: %v", s, err) + } + return a +} + +func TestCacheNormalization(t *testing.T) { + c := newCache() + + // Mixed case, without trailing dot + domainInput := "ExAmPlE.CoM" + ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")} + c.set(domainInput, 1 /* dns.TypeA */, ipv4) + + // Lookup with lower, with trailing dot + if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" { + t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok) + } + + // Lookup with different casing again + if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" { + t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok) + } +} + +func TestCacheSeparateTypes(t *testing.T) { + c := newCache() + + domain := "test.local" + ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")} + ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")} + + c.set(domain, 1 /* A */, ipv4) + c.set(domain, 28 /* AAAA */, ipv6) + + got4, ok4 := c.get(domain, 1) + if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] { + t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4) + } + + got6, ok6 := c.get(domain, 28) + if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] { + t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6) + } +} + +func TestCacheCloneOnGetAndSet(t *testing.T) { + c := newCache() + domain := "clone.test" + + src := []netip.Addr{mustAddr(t, "8.8.8.8")} + c.set(domain, 1, src) + + // Mutate source slice; cache should be unaffected + src[0] = mustAddr(t, "9.9.9.9") + + got, ok := c.get(domain, 1) + if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" { + t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok) + } + + // Mutate returned slice; internal cache should remain unchanged + got[0] = mustAddr(t, "4.4.4.4") + got2, ok2 := c.get(domain, 1) + if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" { + t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2) + } +} + +func TestCacheMiss(t *testing.T) { + c := newCache() + if got, ok := c.get("missing.example", 1); ok || got != nil { + t.Fatalf("expected cache miss, got=%v ok=%v", got, ok) + } +} + diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index d912919a1..7a262fa4c 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -46,6 +46,7 @@ type DNSForwarder struct { fwdEntries []*ForwarderEntry firewall firewaller resolver resolver + cache *cache } func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { @@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat firewall: firewall, statusRecorder: statusRecorder, resolver: net.DefaultResolver, + cache: newCache(), } } @@ -103,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() + // remove cache entries for domains that no longer appear + f.removeStaleCacheEntries(f.fwdEntries, entries) + f.fwdEntries = entries log.Debugf("Updated DNS forwarder with %d domains", len(entries)) } +// removeStaleCacheEntries unsets cache items for domains that were present +// in the old list but not present in the new list. +func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) { + if f.cache == nil { + return + } + + newSet := make(map[string]struct{}, len(newEntries)) + for _, e := range newEntries { + if e == nil { + continue + } + newSet[e.Domain.PunycodeString()] = struct{}{} + } + + for _, e := range oldEntries { + if e == nil { + continue + } + pattern := e.Domain.PunycodeString() + if _, ok := newSet[pattern]; !ok { + f.cache.unset(pattern) + } + } +} + func (f *DNSForwarder) Close(ctx context.Context) error { var result *multierror.Error @@ -171,6 +202,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.addIPsToResponse(resp, domain, ips) + f.cache.set(domain, question.Qtype, ips) return resp } @@ -282,29 +314,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns resp.Rcode = dns.RcodeSuccess } -// handleDNSError processes DNS lookup errors and sends an appropriate error response -func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) { +// handleDNSError processes DNS lookup errors and sends an appropriate error response. +func (f *DNSForwarder) handleDNSError( + ctx context.Context, + w dns.ResponseWriter, + question dns.Question, + resp *dns.Msg, + domain string, + err error, +) { + // Default to SERVFAIL; override below when appropriate. + resp.Rcode = dns.RcodeServerFailure + + qType := question.Qtype + qTypeName := dns.TypeToString[qType] + + // Prefer typed DNS errors; fall back to generic logging otherwise. var dnsErr *net.DNSError - - switch { - case errors.As(err, &dnsErr): - resp.Rcode = dns.RcodeServerFailure - if dnsErr.IsNotFound { - f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype) + if !errors.As(err, &dnsErr) { + log.Warnf(errResolveFailed, domain, err) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } + return + } - if dnsErr.Server != "" { - log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err) - } else { - log.Warnf(errResolveFailed, domain, err) + // NotFound: set NXDOMAIN / appropriate code via helper. + if dnsErr.IsNotFound { + f.setResponseCodeForNotFound(ctx, resp, domain, qType) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } - default: - resp.Rcode = dns.RcodeServerFailure + f.cache.set(domain, question.Qtype, nil) + return + } + + // Upstream failed but we might have a cached answer—serve it if present. + if ips, ok := f.cache.get(domain, qType); ok { + if len(ips) > 0 { + log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) + f.addIPsToResponse(resp, domain, ips) + resp.Rcode = dns.RcodeSuccess + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write cached DNS response: %v", writeErr) + } + } else { // send NXDOMAIN / appropriate code if cache is empty + f.setResponseCodeForNotFound(ctx, resp, domain, qType) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) + } + } + return + } + + // No cache. Log with or without the server field for more context. + if dnsErr.Server != "" { + log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err) + } else { log.Warnf(errResolveFailed, domain, err) } - if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write failure DNS response: %v", err) + // Write final failure response. + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } } diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 57085e19a..c1c95a2c1 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) { assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size") } +// Ensures that when the first query succeeds and populates the cache, +// a subsequent upstream failure still returns a successful response from cache. +func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString("example.com") + require.NoError(t, err) + entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}} + forwarder.UpdateDomains(entries) + + ip := netip.MustParseAddr("1.2.3.4") + + // First call resolves successfully and populates cache + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")). + Return([]netip.Addr{ip}, nil).Once() + + // Second call fails upstream; forwarder should serve from cache + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")). + Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once() + + // First query: populate cache + q1 := &dns.Msg{} + q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) + w1 := &test.MockResponseWriter{} + resp1 := forwarder.handleDNSQuery(w1, q1) + require.NotNil(t, resp1) + require.Equal(t, dns.RcodeSuccess, resp1.Rcode) + require.Len(t, resp1.Answer, 1) + + // Second query: serve from cache after upstream failure + q2 := &dns.Msg{} + q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) + var writtenResp *dns.Msg + w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} + _ = forwarder.handleDNSQuery(w2, q2) + + require.NotNil(t, writtenResp, "expected response to be written") + require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) + require.Len(t, writtenResp.Answer, 1) + + mockResolver.AssertExpectations(t) +} + +// Verifies that cache normalization works across casing and trailing dot variations. +func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString("ExAmPlE.CoM") + require.NoError(t, err) + entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}} + forwarder.UpdateDomains(entries) + + ip := netip.MustParseAddr("9.8.7.6") + + // Initial resolution with mixed case to populate cache + mixedQuery := "ExAmPlE.CoM" + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))). + Return([]netip.Addr{ip}, nil).Once() + + q1 := &dns.Msg{} + q1.SetQuestion(mixedQuery+".", dns.TypeA) + w1 := &test.MockResponseWriter{} + resp1 := forwarder.handleDNSQuery(w1, q1) + require.NotNil(t, resp1) + require.Equal(t, dns.RcodeSuccess, resp1.Rcode) + require.Len(t, resp1.Answer, 1) + + // Subsequent query without dot and upper case should hit cache even if upstream fails + // Forwarder lowercases and uses the question name as-is (no trailing dot here) + mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")). + Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once() + + q2 := &dns.Msg{} + q2.SetQuestion("EXAMPLE.COM", dns.TypeA) + var writtenResp *dns.Msg + w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} + _ = forwarder.handleDNSQuery(w2, q2) + + require.NotNil(t, writtenResp) + require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) + require.Len(t, writtenResp.Answer, 1) + + mockResolver.AssertExpectations(t) +} + func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { // Test complex overlapping pattern scenarios mockFirewall := &MockFirewall{} From 9021bb512bf7f88157a4c06aa1d9fba6ac80587f Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 8 Oct 2025 17:14:24 +0200 Subject: [PATCH 017/120] [client] Recreate agent when receive new session id (#4564) When an ICE agent connection was in progress, new offers were being ignored. This was incorrect logic because the remote agent could be restarted at any time. In this change, whenever a new session ID is received, the ongoing handshake is closed and a new one is started. --- client/internal/peer/conn.go | 4 +- client/internal/peer/conn_test.go | 16 +++--- client/internal/peer/guard/env.go | 20 ++++++++ client/internal/peer/guard/ice_monitor.go | 16 ++++-- client/internal/peer/guard/sr_watcher.go | 2 +- client/internal/peer/handshaker.go | 51 ++++++++++++------- client/internal/peer/handshaker_listener.go | 8 +-- .../internal/peer/handshaker_listener_test.go | 2 +- client/internal/peer/worker_ice.go | 47 +++++++++-------- 9 files changed, 107 insertions(+), 59 deletions(-) create mode 100644 client/internal/peer/guard/env.go diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ded9aa479..68afe986a 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) - conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) + conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) if !isForceRelayed() { - conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) + conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) } conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher) diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index c839ab147..6b47f95eb 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) { return } - onNewOffeChan := make(chan struct{}) + onNewOfferChan := make(chan struct{}) - conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { - onNewOffeChan <- struct{}{} + conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) { + onNewOfferChan <- struct{}{} }) conn.OnRemoteOffer(OfferAnswer{ @@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { defer cancel() select { - case <-onNewOffeChan: + case <-onNewOfferChan: // success case <-ctx.Done(): t.Error("expected to receive a new offer notification, but timed out") @@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) { return } - onNewOffeChan := make(chan struct{}) + onNewOfferChan := make(chan struct{}) - conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { - onNewOffeChan <- struct{}{} + conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) { + onNewOfferChan <- struct{}{} }) conn.OnRemoteAnswer(OfferAnswer{ @@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { defer cancel() select { - case <-onNewOffeChan: + case <-onNewOfferChan: // success case <-ctx.Done(): t.Error("expected to receive a new offer notification, but timed out") diff --git a/client/internal/peer/guard/env.go b/client/internal/peer/guard/env.go new file mode 100644 index 000000000..1ea2d21be --- /dev/null +++ b/client/internal/peer/guard/env.go @@ -0,0 +1,20 @@ +package guard + +import ( + "os" + "strconv" + "time" +) + +const ( + envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD" +) + +func GetICEMonitorPeriod() time.Duration { + if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" { + if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + return defaultCandidatesMonitorPeriod +} diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index 09cf9ae63..0f22ee7b0 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -16,8 +16,8 @@ import ( ) const ( - candidatesMonitorPeriod = 5 * time.Minute - candidateGatheringTimeout = 5 * time.Second + defaultCandidatesMonitorPeriod = 5 * time.Minute + candidateGatheringTimeout = 5 * time.Second ) type ICEMonitor struct { @@ -25,16 +25,19 @@ type ICEMonitor struct { iFaceDiscover stdnet.ExternalIFaceDiscover iceConfig icemaker.Config + tickerPeriod time.Duration currentCandidatesAddress []string candidatesMu sync.Mutex } -func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { +func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor { + log.Debugf("prepare ICE monitor with period: %s", period) cm := &ICEMonitor{ ReconnectCh: make(chan struct{}, 1), iFaceDiscover: iFaceDiscover, iceConfig: config, + tickerPeriod: period, } return cm } @@ -46,7 +49,12 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) { return } - ticker := time.NewTicker(candidatesMonitorPeriod) + // Initial check to populate the candidates for later comparison + if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil { + log.Warnf("Failed to check initial ICE candidates: %v", err) + } + + ticker := time.NewTicker(cm.tickerPeriod) defer ticker.Stop() for { diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 90e45426f..686430752 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -51,7 +51,7 @@ func (w *SRWatcher) Start() { ctx, cancel := context.WithCancel(context.Background()) w.cancelIceMonitor = cancel - iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig) + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) go iceMonitor.Start(ctx, w.onICEChanged) w.signalClient.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected) diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 42eaea683..aff26f847 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -44,13 +44,19 @@ type OfferAnswer struct { } type Handshaker struct { - mu sync.Mutex - log *log.Entry - config ConnConfig - signaler *Signaler - ice *WorkerICE - relay *WorkerRelay - onNewOfferListeners []*OfferListener + mu sync.Mutex + log *log.Entry + config ConnConfig + signaler *Signaler + ice *WorkerICE + relay *WorkerRelay + // relayListener is not blocking because the listener is using a goroutine to process the messages + // and it will only keep the latest message if multiple offers are received in a short time + // this is to avoid blocking the handshaker if the listener is doing some heavy processing + // and also to avoid processing old offers if multiple offers are received in a short time + // the listener will always process the latest offer + relayListener *AsyncOfferListener + iceListener func(remoteOfferAnswer *OfferAnswer) // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan OfferAnswer @@ -70,28 +76,39 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W } } -func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { - l := NewOfferListener(offer) - h.onNewOfferListeners = append(h.onNewOfferListeners, l) +func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { + h.relayListener = NewAsyncOfferListener(offer) +} + +func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) { + h.iceListener = offer } func (h *Handshaker) Listen(ctx context.Context) { for { select { case remoteOfferAnswer := <-h.remoteOffersCh: - // received confirmation from the remote peer -> ready to proceed + h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + if h.relayListener != nil { + h.relayListener.Notify(&remoteOfferAnswer) + } + + if h.iceListener != nil { + h.iceListener(&remoteOfferAnswer) + } + if err := h.sendAnswer(); err != nil { h.log.Errorf("failed to send remote offer confirmation: %s", err) continue } - for _, listener := range h.onNewOfferListeners { - listener.Notify(&remoteOfferAnswer) - } - h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) case remoteOfferAnswer := <-h.remoteAnswerCh: h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) - for _, listener := range h.onNewOfferListeners { - listener.Notify(&remoteOfferAnswer) + if h.relayListener != nil { + h.relayListener.Notify(&remoteOfferAnswer) + } + + if h.iceListener != nil { + h.iceListener(&remoteOfferAnswer) } case <-ctx.Done(): h.log.Infof("stop listening for remote offers and answers") diff --git a/client/internal/peer/handshaker_listener.go b/client/internal/peer/handshaker_listener.go index e2d3f3f38..772e2777f 100644 --- a/client/internal/peer/handshaker_listener.go +++ b/client/internal/peer/handshaker_listener.go @@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string { return oa.SessionID.String() } -type OfferListener struct { +type AsyncOfferListener struct { fn callbackFunc running bool latest *OfferAnswer mu sync.Mutex } -func NewOfferListener(fn callbackFunc) *OfferListener { - return &OfferListener{ +func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener { + return &AsyncOfferListener{ fn: fn, } } -func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) { +func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) { o.mu.Lock() defer o.mu.Unlock() diff --git a/client/internal/peer/handshaker_listener_test.go b/client/internal/peer/handshaker_listener_test.go index 8363741a5..1a7156d10 100644 --- a/client/internal/peer/handshaker_listener_test.go +++ b/client/internal/peer/handshaker_listener_test.go @@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) { runChan <- struct{}{} } - hl := NewOfferListener(longRunningFn) + hl := NewAsyncOfferListener(longRunningFn) hl.Notify(dummyOfferAnswer) hl.Notify(dummyOfferAnswer) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index eb886a4d3..3675f0157 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -92,23 +92,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn * func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString()) w.muxAgent.Lock() + defer w.muxAgent.Unlock() - if w.agentConnecting { - w.log.Debugf("agent connection is in progress, skipping the offer") - w.muxAgent.Unlock() - return - } - - if w.agent != nil { + if w.agent != nil || w.agentConnecting { // backward compatibility with old clients that do not send session ID if remoteOfferAnswer.SessionID == nil { w.log.Debugf("agent already exists, skipping the offer") - w.muxAgent.Unlock() return } if w.remoteSessionID == *remoteOfferAnswer.SessionID { w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString()) - w.muxAgent.Unlock() return } w.log.Debugf("agent already exists, recreate the connection") @@ -116,6 +109,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { if err := w.agent.Close(); err != nil { w.log.Warnf("failed to close ICE agent: %s", err) } + + sessionID, err := NewICESessionID() + if err != nil { + w.log.Errorf("failed to create new session ID: %s", err) + } + w.sessionID = sessionID w.agent = nil } @@ -126,18 +125,23 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { preferredCandidateTypes = icemaker.CandidateTypes() } - w.log.Debugf("recreate ICE agent") + if remoteOfferAnswer.SessionID != nil { + w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID) + } dialerCtx, dialerCancel := context.WithCancel(w.ctx) agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes) if err != nil { w.log.Errorf("failed to recreate ICE Agent: %s", err) - w.muxAgent.Unlock() return } w.agent = agent w.agentDialerCancel = dialerCancel w.agentConnecting = true - w.muxAgent.Unlock() + if remoteOfferAnswer.SessionID != nil { + w.remoteSessionID = *remoteOfferAnswer.SessionID + } else { + w.remoteSessionID = "" + } go w.connect(dialerCtx, agent, remoteOfferAnswer) } @@ -293,9 +297,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent w.muxAgent.Lock() w.agentConnecting = false w.lastSuccess = time.Now() - if remoteOfferAnswer.SessionID != nil { - w.remoteSessionID = *remoteOfferAnswer.SessionID - } w.muxAgent.Unlock() // todo: the potential problem is a race between the onConnectionStateChange @@ -309,16 +310,17 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C } w.muxAgent.Lock() - // todo review does it make sense to generate new session ID all the time when w.agent==agent - sessionID, err := NewICESessionID() - if err != nil { - w.log.Errorf("failed to create new session ID: %s", err) - } - w.sessionID = sessionID if w.agent == agent { + // consider to remove from here and move to the OnNewOffer + sessionID, err := NewICESessionID() + if err != nil { + w.log.Errorf("failed to create new session ID: %s", err) + } + w.sessionID = sessionID w.agent = nil w.agentConnecting = false + w.remoteSessionID = "" } w.muxAgent.Unlock() } @@ -395,11 +397,12 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to // notify the conn.onICEStateDisconnected changes to update the current used priority + w.closeAgent(agent, dialerCancel) + if w.lastKnownState == ice.ConnectionStateConnected { w.lastKnownState = ice.ConnectionStateDisconnected w.conn.onICEStateDisconnected() } - w.closeAgent(agent, dialerCancel) default: return } From 654aa9581d15a7e2b2332f7f6ee395eeac40449c Mon Sep 17 00:00:00 2001 From: Ashley Date: Wed, 8 Oct 2025 20:27:32 +0100 Subject: [PATCH 018/120] [client,gui] Update url_windows.go to offer arm64 executable download (#4586) --- version/url_windows.go | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/version/url_windows.go b/version/url_windows.go index f2055b109..14fdb7ae6 100644 --- a/version/url_windows.go +++ b/version/url_windows.go @@ -1,9 +1,13 @@ package version -import "golang.org/x/sys/windows/registry" +import ( + "golang.org/x/sys/windows/registry" + "runtime" +) const ( urlWinExe = "https://pkgs.netbird.io/windows/x64" + urlWinExeArm = "https://pkgs.netbird.io/windows/arm64" ) var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird" @@ -11,9 +15,14 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne // DownloadUrl return with the proper download link func DownloadUrl() string { _, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE) - if err == nil { - return urlWinExe - } else { + if err != nil { return downloadURL } + + url := urlWinExe + if runtime.GOARCH == "arm64" { + url = urlWinExeArm + } + + return url } From 4e03f708a4cf9461f149b0227cbefae9abbad29e Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 9 Oct 2025 17:39:02 +0300 Subject: [PATCH 019/120] fix dns forwarder port update (#4613) fix dns forwarder port update (#4613) --- client/internal/dnsfwd/manager.go | 16 +++++++--------- client/internal/engine.go | 8 ++++++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 5c7a3fbdd..a3a4ba40f 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -40,7 +40,6 @@ type Manager struct { fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder - port uint16 } func ListenPort() uint16 { @@ -49,11 +48,16 @@ func ListenPort() uint16 { return listenPort } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager { +func SetListenPort(port uint16) { + listenPortMu.Lock() + listenPort = port + listenPortMu.Unlock() +} + +func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, - port: port, } } @@ -67,12 +71,6 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - if m.port > 0 { - listenPortMu.Lock() - listenPort = m.port - listenPortMu.Unlock() - } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 646e059d4..bebf04f6c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1849,6 +1849,10 @@ func (e *Engine) updateDNSForwarder( return } + if forwarderPort > 0 { + dnsfwd.SetListenPort(forwarderPort) + } + if !enabled { if e.dnsForwardMgr == nil { return @@ -1862,7 +1866,7 @@ func (e *Engine) updateDNSForwarder( if len(fwdEntries) > 0 { switch { case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil @@ -1892,7 +1896,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { log.Errorf("failed to stop DNS forward: %v", err) } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil From d35a845dbd336206f2b306384f1faca15cd4bd03 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 9 Oct 2025 22:18:00 +0300 Subject: [PATCH 020/120] [management] sync all other peers on peer add/remove (#4614) --- management/server/peer.go | 27 ++------------------------- management/server/peer_test.go | 4 ++-- 2 files changed, 4 insertions(+), 27 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 4cf5d1e46..218edf299 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -350,7 +350,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } var peer *nbpeer.Peer - var updateAccountPeers bool var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -363,11 +362,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID) - if err != nil { - return err - } - eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) if err != nil { return fmt.Errorf("failed to delete peer: %w", err) @@ -387,7 +381,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if updateAccountPeers && userID != activity.SystemInitiator { + if userID != activity.SystemInitiator { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -684,11 +678,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err) } - updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID) - if err != nil { - updateAccountPeers = true - } - if newPeer == nil { return nil, nil, nil, fmt.Errorf("new peer is nil") } @@ -701,9 +690,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - if updateAccountPeers { - am.BufferUpdateAccountPeers(ctx, accountID) - } + am.BufferUpdateAccountPeers(ctx, accountID) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } @@ -1527,16 +1514,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) } -// IsPeerInActiveGroup checks if the given peer is part of a group that is used -// in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { - peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID) - if err != nil { - return false, err - } - return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction -} - // deletePeers deletes all specified peers and sends updates to the remote peers. // Returns a slice of functions to save events after successful peer deletion. func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 42b3244ae..fd795b926 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1790,7 +1790,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("adding peer to unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) // + peerShouldReceiveUpdate(t, updMsg) // close(done) }() @@ -1815,7 +1815,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("deleting peer with unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() From bedd3cabc994d6bb9903a95fc6753331c987c25c Mon Sep 17 00:00:00 2001 From: Kostya Leschenko Date: Fri, 10 Oct 2025 16:24:24 +0300 Subject: [PATCH 021/120] [client] Explicitly disable DNSOverTLS for systemd-resolved (#4579) --- client/internal/dns/systemd_linux.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 0e8a53a63..d9854c033 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -31,6 +31,7 @@ const ( systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC" + systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS" systemdDbusResolvConfModeForeign = "foreign" dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject" @@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana log.Warnf("failed to set DNSSEC to 'no': %v", err) } + // We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off + if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil { + log.Warnf("failed to set DNSOverTLS to 'no': %v", err) + } + var ( searchDomains []string matchDomains []string From 5151f19d29514436f65fe40faebf4c89da06133d Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 10 Oct 2025 16:15:51 +0200 Subject: [PATCH 022/120] [management] pass temporary flag to validator (#4599) --- go.mod | 2 +- go.sum | 4 ++-- management/server/integrated_validator.go | 2 +- .../server/integrations/integrated_validator/interface.go | 4 ++-- management/server/peer.go | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index a1560b409..31b45e881 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 13838b82d..6b0b298a7 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 21f11bfce..251c04273 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -136,7 +136,7 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID return validatedPeers, nil } -func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer { +func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { return peer } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index ce632d567..be05c2527 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -3,16 +3,16 @@ package integrated_validator import ( "context" - "github.com/netbirdio/netbird/shared/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" ) // IntegratedValidator interface exists to avoid the circle dependencies type IntegratedValidator interface { ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) - PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer + PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error diff --git a/management/server/peer.go b/management/server/peer.go index 218edf299..30b7073ef 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -578,7 +578,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe } } - newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) + newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary) network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) if err != nil { From 0d2e67983ad2de4fc301514deff3409879819520 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 10 Oct 2025 14:16:48 -0300 Subject: [PATCH 023/120] [misc] Add service definition for netbird-signal (#4620) --- infrastructure_files/docker-compose.yml.tmpl.traefik | 1 + 1 file changed, 1 insertion(+) diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index 196e26a66..0010974c5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -49,6 +49,7 @@ services: - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80 - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) + - traefik.http.routers.netbird-signal.service=netbird-signal - traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c From 000e99e7f3f5c18741ee3412ba113c32c6ce3aa7 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 13 Oct 2025 17:50:16 +0200 Subject: [PATCH 024/120] [client] Force TLS1.2 for RDP with Win11/Server2025 for CredSSP compatibility (#4617) --- client/wasm/internal/rdp/cert_validation.go | 15 ++- client/wasm/internal/rdp/rdcleanpath.go | 93 ++++++++++++-- .../wasm/internal/rdp/rdcleanpath_handlers.go | 119 +++++++++--------- 3 files changed, 152 insertions(+), 75 deletions(-) diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go index 4a23a4bc8..1678c3996 100644 --- a/client/wasm/internal/rdp/cert_validation.go +++ b/client/wasm/internal/rdp/cert_validation.go @@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert } } -func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config { - return &tls.Config{ +func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config { + config := &tls.Config{ InsecureSkipVerify: true, // We'll validate manually after handshake VerifyConnection: func(cs tls.ConnectionState) error { var certChain [][]byte @@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl return nil }, } + + // CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3 + if requiresCredSSP { + config.MinVersion = tls.VersionTLS12 + config.MaxVersion = tls.VersionTLS12 + } else { + config.MinVersion = tls.VersionTLS12 + config.MaxVersion = tls.VersionTLS13 + } + + return config } diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go index 8062a05cc..16bf63bb9 100644 --- a/client/wasm/internal/rdp/rdcleanpath.go +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -6,11 +6,13 @@ import ( "context" "crypto/tls" "encoding/asn1" + "errors" "fmt" "io" "net" "sync" "syscall/js" + "time" log "github.com/sirupsen/logrus" ) @@ -19,18 +21,34 @@ const ( RDCleanPathVersion = 3390 RDCleanPathProxyHost = "rdcleanpath.proxy.local" RDCleanPathProxyScheme = "ws" + + rdpDialTimeout = 15 * time.Second + + GeneralErrorCode = 1 + WSAETimedOut = 10060 + WSAEConnRefused = 10061 + WSAEConnAborted = 10053 + WSAEConnReset = 10054 + WSAEGenericError = 10050 ) type RDCleanPathPDU struct { - Version int64 `asn1:"tag:0,explicit"` - Error []byte `asn1:"tag:1,explicit,optional"` - Destination string `asn1:"utf8,tag:2,explicit,optional"` - ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` - ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` - PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` - X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` - ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` - ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` + Version int64 `asn1:"tag:0,explicit"` + Error RDCleanPathErr `asn1:"tag:1,explicit,optional"` + Destination string `asn1:"utf8,tag:2,explicit,optional"` + ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` + ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` + PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` + X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` + ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` + ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` +} + +type RDCleanPathErr struct { + ErrorCode int16 `asn1:"tag:0,explicit"` + HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"` + WSALastError int16 `asn1:"tag:2,explicit,optional"` + TLSAlertCode int8 `asn1:"tag:3,explicit,optional"` } type RDCleanPathProxy struct { @@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] destination := conn.destination log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination) - rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout) + defer cancel() + + rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination) if err != nil { log.Errorf("Failed to connect to %s: %v", destination, err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } conn.rdpConn = rdpConn @@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] _, err = rdpConn.Write(firstPacket) if err != nil { log.Errorf("Failed to write first packet: %v", err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] n, err := rdpConn.Read(response) if err != nil { log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { conn.wsHandlers.Call("send", uint8Array.Get("buffer")) } } + +func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) { + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal error PDU: %v", err) + return + } + p.sendToWebSocket(conn, data) +} + +func errorToWSACode(err error) int16 { + if err == nil { + return WSAEGenericError + } + var netErr *net.OpError + if errors.As(err, &netErr) && netErr.Timeout() { + return WSAETimedOut + } + if errors.Is(err, context.DeadlineExceeded) { + return WSAETimedOut + } + if errors.Is(err, context.Canceled) { + return WSAEConnAborted + } + if errors.Is(err, io.EOF) { + return WSAEConnReset + } + return WSAEGenericError +} + +func newWSAError(err error) RDCleanPathPDU { + return RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: RDCleanPathErr{ + ErrorCode: GeneralErrorCode, + WSALastError: errorToWSACode(err), + }, + } +} + +func newHTTPError(statusCode int16) RDCleanPathPDU { + return RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: RDCleanPathErr{ + ErrorCode: GeneralErrorCode, + HTTPStatusCode: statusCode, + }, + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath_handlers.go b/client/wasm/internal/rdp/rdcleanpath_handlers.go index 010efa5ea..97bb46338 100644 --- a/client/wasm/internal/rdp/rdcleanpath_handlers.go +++ b/client/wasm/internal/rdp/rdcleanpath_handlers.go @@ -3,6 +3,7 @@ package rdp import ( + "context" "crypto/tls" "encoding/asn1" "io" @@ -11,11 +12,17 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + // MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP) + protocolSSL = 0x00000001 + protocolHybridEx = 0x00000008 +) + func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination) if pdu.Version != RDCleanPathVersion { - p.sendRDCleanPathError(conn, "Unsupported version") + p.sendRDCleanPathError(conn, newHTTPError(400)) return } @@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl destination = pdu.Destination } - rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout) + defer cancel() + + rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination) if err != nil { log.Errorf("Failed to connect to %s: %v", destination, err) - p.sendRDCleanPathError(conn, "Connection failed") + p.sendRDCleanPathError(conn, newWSAError(err)) p.cleanupConnection(conn) return } @@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl p.setupTLSConnection(conn, pdu) } +// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required. +// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags. +// Returns (requiresTLS12, selectedProtocol, detectionSuccessful). +func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) { + const minResponseLength = 19 + + if len(x224Response) < minResponseLength { + return false, 0, false + } + + // Per X.224 specification: + // x224Response[0] == 0x03: Length of X.224 header (3 bytes) + // x224Response[5] == 0xD0: X.224 Data TPDU code + if x224Response[0] != 0x03 || x224Response[5] != 0xD0 { + return false, 0, false + } + + if x224Response[11] == 0x02 { + flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 | + uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24 + + hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0 + return hasNLA, flags, true + } + + return false, 0, false +} + func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) { var x224Response []byte if len(pdu.X224ConnectionPDU) > 0 { @@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) if err != nil { log.Errorf("Failed to write X.224 PDU: %v", err) - p.sendRDCleanPathError(conn, "Failed to forward X.224") + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean n, err := conn.rdpConn.Read(response) if err != nil { log.Errorf("Failed to read X.224 response: %v", err) - p.sendRDCleanPathError(conn, "Failed to read X.224 response") + p.sendRDCleanPathError(conn, newWSAError(err)) return } x224Response = response[:n] log.Debugf("Received X.224 Connection Confirm (%d bytes)", n) } - tlsConfig := p.getTLSConfigWithValidation(conn) + requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response) + if detected { + if requiresCredSSP { + log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol) + } else { + log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol) + } + } else { + log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3") + } + + tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP) tlsConn := tls.Client(conn.rdpConn, tlsConfig) conn.tlsConn = tlsConn if err := tlsConn.Handshake(); err != nil { log.Errorf("TLS handshake failed: %v", err) - p.sendRDCleanPathError(conn, "TLS handshake failed") + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean p.cleanupConnection(conn) } -func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) { - if len(pdu.X224ConnectionPDU) > 0 { - log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) - _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) - if err != nil { - log.Errorf("Failed to write X.224 PDU: %v", err) - p.sendRDCleanPathError(conn, "Failed to forward X.224") - return - } - - response := make([]byte, 1024) - n, err := conn.rdpConn.Read(response) - if err != nil { - log.Errorf("Failed to read X.224 response: %v", err) - p.sendRDCleanPathError(conn, "Failed to read X.224 response") - return - } - - responsePDU := RDCleanPathPDU{ - Version: RDCleanPathVersion, - X224ConnectionPDU: response[:n], - ServerAddr: conn.destination, - } - - p.sendRDCleanPathPDU(conn, responsePDU) - } else { - responsePDU := RDCleanPathPDU{ - Version: RDCleanPathVersion, - ServerAddr: conn.destination, - } - p.sendRDCleanPathPDU(conn, responsePDU) - } - - go p.forwardConnToWS(conn, conn.rdpConn, "TCP") - go p.forwardWSToConn(conn, conn.rdpConn, "TCP") - - <-conn.ctx.Done() - log.Debug("TCP connection context done, cleaning up") - p.cleanupConnection(conn) -} - func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { data, err := asn1.Marshal(pdu) if err != nil { @@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean p.sendToWebSocket(conn, data) } -func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) { - pdu := RDCleanPathPDU{ - Version: RDCleanPathVersion, - Error: []byte(errorMsg), - } - - data, err := asn1.Marshal(pdu) - if err != nil { - log.Errorf("Failed to marshal error PDU: %v", err) - return - } - - p.sendToWebSocket(conn, data) -} - func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) { msgChan := make(chan []byte) errChan := make(chan error) From bb37dc89cef3d8338277503ff5e0bbfecb76ffca Mon Sep 17 00:00:00 2001 From: John Conley <8932043+jfrconley@users.noreply.github.com> Date: Thu, 16 Oct 2025 01:46:29 -0700 Subject: [PATCH 025/120] [management] feat: Basic PocketID IDP integration (#4529) --- management/server/idp/auth0_test.go | 8 +- management/server/idp/idp.go | 6 + management/server/idp/pocketid.go | 384 +++++++++++++++++++++++++ management/server/idp/pocketid_test.go | 138 +++++++++ 4 files changed, 533 insertions(+), 3 deletions(-) create mode 100644 management/server/idp/pocketid.go create mode 100644 management/server/idp/pocketid_test.go diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index 66c16870b..bc352f117 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -26,9 +26,11 @@ type mockHTTPClient struct { } func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { - body, err := io.ReadAll(req.Body) - if err == nil { - c.reqBody = string(body) + if req.Body != nil { + body, err := io.ReadAll(req.Body) + if err == nil { + c.reqBody = string(body) + } } return &http.Response{ StatusCode: c.code, diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 51f99b3b7..f06e57196 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -201,6 +201,12 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr APIToken: config.ExtraConfig["ApiToken"], } return NewJumpCloudManager(jumpcloudConfig, appMetrics) + case "pocketid": + pocketidConfig := PocketIdClientConfig{ + APIToken: config.ExtraConfig["ApiToken"], + ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"], + } + return NewPocketIdManager(pocketidConfig, appMetrics) default: return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) } diff --git a/management/server/idp/pocketid.go b/management/server/idp/pocketid.go new file mode 100644 index 000000000..38a5cc67f --- /dev/null +++ b/management/server/idp/pocketid.go @@ -0,0 +1,384 @@ +package idp + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "slices" + "strings" + "time" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + +type PocketIdManager struct { + managementEndpoint string + apiToken string + httpClient ManagerHTTPClient + credentials ManagerCredentials + helper ManagerHelper + appMetrics telemetry.AppMetrics +} + +type pocketIdCustomClaimDto struct { + Key string `json:"key"` + Value string `json:"value"` +} + +type pocketIdUserDto struct { + CustomClaims []pocketIdCustomClaimDto `json:"customClaims"` + Disabled bool `json:"disabled"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + FirstName string `json:"firstName"` + ID string `json:"id"` + IsAdmin bool `json:"isAdmin"` + LastName string `json:"lastName"` + LdapID string `json:"ldapId"` + Locale string `json:"locale"` + UserGroups []pocketIdUserGroupDto `json:"userGroups"` + Username string `json:"username"` +} + +type pocketIdUserCreateDto struct { + Disabled bool `json:"disabled,omitempty"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + FirstName string `json:"firstName"` + IsAdmin bool `json:"isAdmin,omitempty"` + LastName string `json:"lastName,omitempty"` + Locale string `json:"locale,omitempty"` + Username string `json:"username"` +} + +type pocketIdPaginatedUserDto struct { + Data []pocketIdUserDto `json:"data"` + Pagination pocketIdPaginationDto `json:"pagination"` +} + +type pocketIdPaginationDto struct { + CurrentPage int `json:"currentPage"` + ItemsPerPage int `json:"itemsPerPage"` + TotalItems int `json:"totalItems"` + TotalPages int `json:"totalPages"` +} + +func (p *pocketIdUserDto) userData() *UserData { + return &UserData{ + Email: p.Email, + Name: p.DisplayName, + ID: p.ID, + AppMetadata: AppMetadata{}, + } +} + +type pocketIdUserGroupDto struct { + CreatedAt string `json:"createdAt"` + CustomClaims []pocketIdCustomClaimDto `json:"customClaims"` + FriendlyName string `json:"friendlyName"` + ID string `json:"id"` + LdapID string `json:"ldapId"` + Name string `json:"name"` +} + +func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.MaxIdleConns = 5 + + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: httpTransport, + } + helper := JsonParser{} + + if config.ManagementEndpoint == "" { + return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing") + } + + if config.APIToken == "" { + return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing") + } + + credentials := &PocketIdCredentials{ + clientConfig: config, + httpClient: httpClient, + helper: helper, + appMetrics: appMetrics, + } + + return &PocketIdManager{ + managementEndpoint: config.ManagementEndpoint, + apiToken: config.APIToken, + httpClient: httpClient, + credentials: credentials, + helper: helper, + appMetrics: appMetrics, + }, nil +} + +func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) { + var MethodsWithBody = []string{http.MethodPost, http.MethodPut} + if !slices.Contains(MethodsWithBody, method) && body != "" { + return nil, fmt.Errorf("Body provided to unsupported method: %s", method) + } + + reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource) + if query != nil { + reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode()) + } + var req *http.Request + var err error + if body != "" { + req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body)) + } else { + req, err = http.NewRequestWithContext(ctx, method, reqURL, nil) + } + if err != nil { + return nil, err + } + + req.Header.Add("X-API-KEY", p.apiToken) + + if body != "" { + req.Header.Add("content-type", "application/json") + req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength)) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountRequestError() + } + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// getAllUsersPaginated fetches all users from PocketID API using pagination +func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) { + var allUsers []pocketIdUserDto + currentPage := 1 + + for { + params := url.Values{} + // Copy existing search parameters + for key, values := range searchParams { + params[key] = values + } + + params.Set("pagination[limit]", "100") + params.Set("pagination[page]", fmt.Sprintf("%d", currentPage)) + + body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "") + if err != nil { + return nil, err + } + + var profiles pocketIdPaginatedUserDto + err = p.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + allUsers = append(allUsers, profiles.Data...) + + // Check if we've reached the last page + if currentPage >= profiles.Pagination.TotalPages { + break + } + + currentPage++ + } + + return allUsers, nil +} + +func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { + return nil +} + +func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) { + body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "") + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetUserDataByID() + } + + var user pocketIdUserDto + err = p.helper.Unmarshal(body, &user) + if err != nil { + return nil, err + } + + userData := user.userData() + userData.AppMetadata = appMetadata + + return userData, nil +} + +func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) { + // Get all users using pagination + allUsers, err := p.getAllUsersPaginated(ctx, url.Values{}) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetAccount() + } + + users := make([]*UserData, 0) + for _, profile := range allUsers { + userData := profile.userData() + userData.AppMetadata.WTAccountID = accountId + + users = append(users, userData) + } + return users, nil +} + +func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + // Get all users using pagination + allUsers, err := p.getAllUsersPaginated(ctx, url.Values{}) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetAllAccounts() + } + + indexedUsers := make(map[string][]*UserData) + for _, profile := range allUsers { + userData := profile.userData() + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + } + + return indexedUsers, nil +} + +func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { + firstLast := strings.Split(name, " ") + + createUser := pocketIdUserCreateDto{ + Disabled: false, + DisplayName: name, + Email: email, + FirstName: firstLast[0], + LastName: firstLast[1], + Username: firstLast[0] + "." + firstLast[1], + } + payload, err := p.helper.Marshal(createUser) + if err != nil { + return nil, err + } + + body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload)) + if err != nil { + return nil, err + } + var newUser pocketIdUserDto + err = p.helper.Unmarshal(body, &newUser) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountCreateUser() + } + var pending bool = true + ret := &UserData{ + Email: email, + Name: name, + ID: newUser.ID, + AppMetadata: AppMetadata{ + WTAccountID: accountID, + WTPendingInvite: &pending, + WTInvitedBy: invitedByEmail, + }, + } + return ret, nil +} + +func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { + params := url.Values{ + // This value a + "search": []string{email}, + } + body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "") + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetUserByEmail() + } + + var profiles struct{ data []pocketIdUserDto } + err = p.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + users := make([]*UserData, 0) + for _, profile := range profiles.data { + users = append(users, profile.userData()) + } + return users, nil +} + +func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error { + _, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "") + if err != nil { + return err + } + return nil +} + +func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error { + _, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "") + if err != nil { + return err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountDeleteUser() + } + + return nil +} + +var _ Manager = (*PocketIdManager)(nil) + +type PocketIdClientConfig struct { + APIToken string + ManagementEndpoint string +} + +type PocketIdCredentials struct { + clientConfig PocketIdClientConfig + helper ManagerHelper + httpClient ManagerHTTPClient + appMetrics telemetry.AppMetrics +} + +var _ ManagerCredentials = (*PocketIdCredentials)(nil) + +func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) { + return JWTToken{}, nil +} diff --git a/management/server/idp/pocketid_test.go b/management/server/idp/pocketid_test.go new file mode 100644 index 000000000..49075a0d3 --- /dev/null +++ b/management/server/idp/pocketid_test.go @@ -0,0 +1,138 @@ +package idp + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + + +func TestNewPocketIdManager(t *testing.T) { + type test struct { + name string + inputConfig PocketIdClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } + + defaultTestConfig := PocketIdClientConfig{ + APIToken: "api_token", + ManagementEndpoint: "http://localhost", + } + + tests := []test{ + { + name: "Good Configuration", + inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + }, + { + name: "Missing ManagementEndpoint", + inputConfig: PocketIdClientConfig{ + APIToken: defaultTestConfig.APIToken, + ManagementEndpoint: "", + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + { + name: "Missing APIToken", + inputConfig: PocketIdClientConfig{ + APIToken: "", + ManagementEndpoint: defaultTestConfig.ManagementEndpoint, + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{}) + tc.assertErrFunc(t, err, tc.assertErrFuncMessage) + }) + } +} + +func TestPocketID_GetUserDataByID(t *testing.T) { + client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + md := AppMetadata{WTAccountID: "acc1"} + got, err := mgr.GetUserDataByID(context.Background(), "u1", md) + require.NoError(t, err) + assert.Equal(t, "u1", got.ID) + assert.Equal(t, "user1@example.com", got.Email) + assert.Equal(t, "User One", got.Name) + assert.Equal(t, "acc1", got.AppMetadata.WTAccountID) +} + +func TestPocketID_GetAccount_WithPagination(t *testing.T) { + // Single page response with two users + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + users, err := mgr.GetAccount(context.Background(), "accX") + require.NoError(t, err) + require.Len(t, users, 2) + assert.Equal(t, "u1", users[0].ID) + assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID) + assert.Equal(t, "u2", users[1].ID) +} + +func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) { + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + accounts, err := mgr.GetAllAccounts(context.Background()) + require.NoError(t, err) + require.Len(t, accounts[UnsetAccountID], 2) +} + +func TestPocketID_CreateUser(t *testing.T) { + client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com") + require.NoError(t, err) + assert.Equal(t, "newid", ud.ID) + assert.Equal(t, "new@example.com", ud.Email) + assert.Equal(t, "New User", ud.Name) + assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID) + if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) { + assert.True(t, *ud.AppMetadata.WTPendingInvite) + } + assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy) +} + +func TestPocketID_InviteAndDeleteUser(t *testing.T) { + // Same mock for both calls; returns OK with empty JSON + client := &mockHTTPClient{code: 200, resBody: `{}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + err = mgr.InviteUserByID(context.Background(), "u1") + require.NoError(t, err) + + err = mgr.DeleteUser(context.Background(), "u1") + require.NoError(t, err) +} From 277aa2b7cc82368d6ec0c6f73acfdfa676ad2c9b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 16 Oct 2025 15:13:41 +0200 Subject: [PATCH 026/120] [client] Fix missing flag values in profiles (#4650) --- client/server/server.go | 7 + client/server/setconfig_test.go | 298 ++++++++++++++++++++++++++++++++ 2 files changed, 305 insertions(+) create mode 100644 client/server/setconfig_test.go diff --git a/client/server/server.go b/client/server/server.go index e6de608c5..89f50a1ef 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -353,6 +353,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.CustomDNSAddress = []byte{} } + config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist + + if msg.DnsRouteInterval != nil { + interval := msg.DnsRouteInterval.AsDuration() + config.DNSRouteInterval = &interval + } + config.RosenpassEnabled = msg.RosenpassEnabled config.RosenpassPermissive = msg.RosenpassPermissive config.DisableAutoConnect = msg.DisableAutoConnect diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go new file mode 100644 index 000000000..1260bcc78 --- /dev/null +++ b/client/server/setconfig_test.go @@ -0,0 +1,298 @@ +package server + +import ( + "context" + "os/user" + "path/filepath" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" +) + +// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config. +// This test uses reflection to detect when new fields are added but not handled in SetConfig. +func TestSetConfig_AllFieldsSaved(t *testing.T) { + tempDir := t.TempDir() + origDefaultProfileDir := profilemanager.DefaultConfigPathDir + origDefaultConfigPath := profilemanager.DefaultConfigPath + origActiveProfileStatePath := profilemanager.ActiveProfileStatePath + profilemanager.ConfigDirOverride = tempDir + profilemanager.DefaultConfigPathDir = tempDir + profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json" + profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json") + t.Cleanup(func() { + profilemanager.DefaultConfigPathDir = origDefaultProfileDir + profilemanager.ActiveProfileStatePath = origActiveProfileStatePath + profilemanager.DefaultConfigPath = origDefaultConfigPath + profilemanager.ConfigDirOverride = "" + }) + + currUser, err := user.Current() + require.NoError(t, err) + + profName := "test-profile" + + ic := profilemanager.ConfigInput{ + ConfigPath: filepath.Join(tempDir, profName+".json"), + ManagementURL: "https://api.netbird.io:443", + } + _, err = profilemanager.UpdateOrCreateConfig(ic) + require.NoError(t, err) + + pm := profilemanager.ServiceManager{} + err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + }) + require.NoError(t, err) + + ctx := context.Background() + s := New(ctx, "console", "", false, false) + + rosenpassEnabled := true + rosenpassPermissive := true + serverSSHAllowed := true + interfaceName := "utun100" + wireguardPort := int64(51820) + preSharedKey := "test-psk" + disableAutoConnect := true + networkMonitor := true + disableClientRoutes := true + disableServerRoutes := true + disableDNS := true + disableFirewall := true + blockLANAccess := true + disableNotifications := true + lazyConnectionEnabled := true + blockInbound := true + mtu := int64(1280) + + req := &proto.SetConfigRequest{ + ProfileName: profName, + Username: currUser.Username, + ManagementUrl: "https://new-api.netbird.io:443", + AdminURL: "https://new-admin.netbird.io", + RosenpassEnabled: &rosenpassEnabled, + RosenpassPermissive: &rosenpassPermissive, + ServerSSHAllowed: &serverSSHAllowed, + InterfaceName: &interfaceName, + WireguardPort: &wireguardPort, + OptionalPreSharedKey: &preSharedKey, + DisableAutoConnect: &disableAutoConnect, + NetworkMonitor: &networkMonitor, + DisableClientRoutes: &disableClientRoutes, + DisableServerRoutes: &disableServerRoutes, + DisableDns: &disableDNS, + DisableFirewall: &disableFirewall, + BlockLanAccess: &blockLANAccess, + DisableNotifications: &disableNotifications, + LazyConnectionEnabled: &lazyConnectionEnabled, + BlockInbound: &blockInbound, + NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"}, + CleanNATExternalIPs: false, + CustomDNSAddress: []byte("1.1.1.1:53"), + ExtraIFaceBlacklist: []string{"eth1", "eth2"}, + DnsLabels: []string{"label1", "label2"}, + CleanDNSLabels: false, + DnsRouteInterval: durationpb.New(2 * time.Minute), + Mtu: &mtu, + } + + _, err = s.SetConfig(ctx, req) + require.NoError(t, err) + + profState := profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + } + cfgPath, err := profState.FilePath() + require.NoError(t, err) + + cfg, err := profilemanager.GetConfig(cfgPath) + require.NoError(t, err) + + require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String()) + require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String()) + require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled) + require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive) + require.NotNil(t, cfg.ServerSSHAllowed) + require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed) + require.Equal(t, interfaceName, cfg.WgIface) + require.Equal(t, int(wireguardPort), cfg.WgPort) + require.Equal(t, preSharedKey, cfg.PreSharedKey) + require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect) + require.NotNil(t, cfg.NetworkMonitor) + require.Equal(t, networkMonitor, *cfg.NetworkMonitor) + require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes) + require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes) + require.Equal(t, disableDNS, cfg.DisableDNS) + require.Equal(t, disableFirewall, cfg.DisableFirewall) + require.Equal(t, blockLANAccess, cfg.BlockLANAccess) + require.NotNil(t, cfg.DisableNotifications) + require.Equal(t, disableNotifications, *cfg.DisableNotifications) + require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled) + require.Equal(t, blockInbound, cfg.BlockInbound) + require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs) + require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress) + // IFaceBlackList contains defaults + extras + require.Contains(t, cfg.IFaceBlackList, "eth1") + require.Contains(t, cfg.IFaceBlackList, "eth2") + require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList()) + require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval) + require.Equal(t, uint16(mtu), cfg.MTU) + + verifyAllFieldsCovered(t, req) +} + +// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest. +// If a new field is added to SetConfigRequest, this function will fail the test, +// forcing the developer to update both the SetConfig handler and this test. +func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { + t.Helper() + + metadataFields := map[string]bool{ + "state": true, // protobuf internal + "sizeCache": true, // protobuf internal + "unknownFields": true, // protobuf internal + "Username": true, // metadata + "ProfileName": true, // metadata + "CleanNATExternalIPs": true, // control flag for clearing + "CleanDNSLabels": true, // control flag for clearing + } + + expectedFields := map[string]bool{ + "ManagementUrl": true, + "AdminURL": true, + "RosenpassEnabled": true, + "RosenpassPermissive": true, + "ServerSSHAllowed": true, + "InterfaceName": true, + "WireguardPort": true, + "OptionalPreSharedKey": true, + "DisableAutoConnect": true, + "NetworkMonitor": true, + "DisableClientRoutes": true, + "DisableServerRoutes": true, + "DisableDns": true, + "DisableFirewall": true, + "BlockLanAccess": true, + "DisableNotifications": true, + "LazyConnectionEnabled": true, + "BlockInbound": true, + "NatExternalIPs": true, + "CustomDNSAddress": true, + "ExtraIFaceBlacklist": true, + "DnsLabels": true, + "DnsRouteInterval": true, + "Mtu": true, + } + + val := reflect.ValueOf(req).Elem() + typ := val.Type() + + var unexpectedFields []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + if metadataFields[fieldName] { + continue + } + + if !expectedFields[fieldName] { + unexpectedFields = append(unexpectedFields, fieldName) + } + } + + if len(unexpectedFields) > 0 { + t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields) + } +} + +// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest. +// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq. +func TestCLIFlags_MappedToSetConfig(t *testing.T) { + // Map of CLI flag names to their corresponding SetConfigRequest field names. + // This map must be updated when adding new config-related CLI flags. + flagToField := map[string]string{ + "management-url": "ManagementUrl", + "admin-url": "AdminURL", + "enable-rosenpass": "RosenpassEnabled", + "rosenpass-permissive": "RosenpassPermissive", + "allow-server-ssh": "ServerSSHAllowed", + "interface-name": "InterfaceName", + "wireguard-port": "WireguardPort", + "preshared-key": "OptionalPreSharedKey", + "disable-auto-connect": "DisableAutoConnect", + "network-monitor": "NetworkMonitor", + "disable-client-routes": "DisableClientRoutes", + "disable-server-routes": "DisableServerRoutes", + "disable-dns": "DisableDns", + "disable-firewall": "DisableFirewall", + "block-lan-access": "BlockLanAccess", + "block-inbound": "BlockInbound", + "enable-lazy-connection": "LazyConnectionEnabled", + "external-ip-map": "NatExternalIPs", + "dns-resolver-address": "CustomDNSAddress", + "extra-iface-blacklist": "ExtraIFaceBlacklist", + "extra-dns-labels": "DnsLabels", + "dns-router-interval": "DnsRouteInterval", + "mtu": "Mtu", + } + + // SetConfigRequest fields that don't have CLI flags (settable only via UI or other means). + fieldsWithoutCLIFlags := map[string]bool{ + "DisableNotifications": true, // Only settable via UI + } + + // Get all SetConfigRequest fields to verify our map is complete. + req := &proto.SetConfigRequest{} + val := reflect.ValueOf(req).Elem() + typ := val.Type() + + var unmappedFields []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + // Skip protobuf internal fields and metadata fields. + if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" { + continue + } + if fieldName == "Username" || fieldName == "ProfileName" { + continue + } + if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" { + continue + } + + // Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag. + mappedToCLI := false + for _, mappedField := range flagToField { + if mappedField == fieldName { + mappedToCLI = true + break + } + } + + hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName] + + if !mappedToCLI && !hasNoCLIFlag { + unmappedFields = append(unmappedFields, fieldName) + } + } + + if len(unmappedFields) > 0 { + t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+ + "Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+ + "add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields) + } + + t.Log("All SetConfigRequest fields are properly documented") +} From 8252ff41db5fed928a1666e84061ee30d19addf9 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 16 Oct 2025 15:58:29 +0200 Subject: [PATCH 027/120] [client] Add bind activity listener to bypass udp sockets (#4646) --- client/iface/device.go | 1 + client/iface/device/device_android.go | 5 + client/iface/device/device_darwin.go | 5 + client/iface/device/device_ios.go | 5 + client/iface/device/device_kernel_unix.go | 5 + client/iface/device/device_netstack.go | 6 + client/iface/device/device_usp_unix.go | 5 + client/iface/device/device_windows.go | 5 + client/iface/device/endpoint_manager.go | 13 + client/iface/device_android.go | 1 + client/iface/iface.go | 11 + .../internal/lazyconn/activity/lazy_conn.go | 82 +++++ .../lazyconn/activity/listener_bind.go | 127 ++++++++ .../lazyconn/activity/listener_bind_test.go | 291 ++++++++++++++++++ .../lazyconn/activity/listener_test.go | 41 --- .../activity/{listener.go => listener_udp.go} | 35 ++- .../lazyconn/activity/listener_udp_test.go | 110 +++++++ client/internal/lazyconn/activity/manager.go | 45 ++- .../lazyconn/activity/manager_test.go | 38 ++- client/internal/lazyconn/wgiface.go | 2 + 20 files changed, 760 insertions(+), 73 deletions(-) create mode 100644 client/iface/device/endpoint_manager.go create mode 100644 client/internal/lazyconn/activity/lazy_conn.go create mode 100644 client/internal/lazyconn/activity/listener_bind.go create mode 100644 client/internal/lazyconn/activity/listener_bind_test.go delete mode 100644 client/internal/lazyconn/activity/listener_test.go rename client/internal/lazyconn/activity/{listener.go => listener_udp.go} (64%) create mode 100644 client/internal/lazyconn/activity/listener_udp_test.go diff --git a/client/iface/device.go b/client/iface/device.go index 921f0ea98..c0c829825 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -23,4 +23,5 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index a731684cc..48346fc0f 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net { return nil } +// GetICEBind returns the ICEBind instance +func (t *WGTunDevice) GetICEBind() EndpointManager { + return t.iceBind +} + func routesToString(routes []string) string { return strings.Join(routes, ";") } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 390efe088..acd5f6f11 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 96e4c8bcf..f96edf992 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index cdac43a53..2a836f846 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error { func (t *TunKernelDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns nil for kernel mode devices +func (t *TunKernelDevice) GetICEBind() EndpointManager { + return nil +} diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index e37321b68..40d8fdac8 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -21,6 +21,7 @@ type Bind interface { conn.Bind GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) ActivityRecorder() *bind.ActivityRecorder + EndpointManager } type TunNetstackDevice struct { @@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device { func (t *TunNetstackDevice) GetNet() *netstack.Net { return t.net } + +// GetICEBind returns the bind instance +func (t *TunNetstackDevice) GetICEBind() EndpointManager { + return t.bind +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 4cdd70a32..24654fc03 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error { func (t *USPDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *USPDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index f1023bc0a..96350df8a 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/endpoint_manager.go b/client/iface/device/endpoint_manager.go new file mode 100644 index 000000000..b53888baa --- /dev/null +++ b/client/iface/device/endpoint_manager.go @@ -0,0 +1,13 @@ +package device + +import ( + "net" + "net/netip" +) + +// EndpointManager manages fake IP to connection mappings for userspace bind implementations. +// Implemented by bind.ICEBind and bind.RelayBindJS. +type EndpointManager interface { + SetEndpoint(fakeIP netip.Addr, conn net.Conn) + RemoveEndpoint(fakeIP netip.Addr) +} diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 4649b8b97..cdfcea48d 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -21,4 +21,5 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/iface.go b/client/iface/iface.go index 158672160..07235a995 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy { return w.wgProxyFactory.GetProxy() } +// GetBind returns the EndpointManager userspace bind mode. +func (w *WGIface) GetBind() device.EndpointManager { + w.mu.Lock() + defer w.mu.Unlock() + + if w.tun == nil { + return nil + } + return w.tun.GetICEBind() +} + // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind func (w *WGIface) IsUserspaceBind() bool { return w.userspaceBind diff --git a/client/internal/lazyconn/activity/lazy_conn.go b/client/internal/lazyconn/activity/lazy_conn.go new file mode 100644 index 000000000..2564a9905 --- /dev/null +++ b/client/internal/lazyconn/activity/lazy_conn.go @@ -0,0 +1,82 @@ +package activity + +import ( + "context" + "io" + "net" + "time" +) + +// lazyConn detects activity when WireGuard attempts to send packets. +// It does not deliver packets, only signals that activity occurred. +type lazyConn struct { + activityCh chan struct{} + ctx context.Context + cancel context.CancelFunc +} + +// newLazyConn creates a new lazyConn for activity detection. +func newLazyConn() *lazyConn { + ctx, cancel := context.WithCancel(context.Background()) + return &lazyConn{ + activityCh: make(chan struct{}, 1), + ctx: ctx, + cancel: cancel, + } +} + +// Read blocks until the connection is closed. +func (c *lazyConn) Read(_ []byte) (n int, err error) { + <-c.ctx.Done() + return 0, io.EOF +} + +// Write signals activity detection when ICEBind routes packets to this endpoint. +func (c *lazyConn) Write(b []byte) (n int, err error) { + if c.ctx.Err() != nil { + return 0, io.EOF + } + + select { + case c.activityCh <- struct{}{}: + default: + } + + return len(b), nil +} + +// ActivityChan returns the channel that signals when activity is detected. +func (c *lazyConn) ActivityChan() <-chan struct{} { + return c.activityCh +} + +// Close closes the connection. +func (c *lazyConn) Close() error { + c.cancel() + return nil +} + +// LocalAddr returns the local address. +func (c *lazyConn) LocalAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// RemoteAddr returns the remote address. +func (c *lazyConn) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// SetDeadline sets the read and write deadlines. +func (c *lazyConn) SetDeadline(_ time.Time) error { + return nil +} + +// SetReadDeadline sets the deadline for future Read calls. +func (c *lazyConn) SetReadDeadline(_ time.Time) error { + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (c *lazyConn) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/client/internal/lazyconn/activity/listener_bind.go b/client/internal/lazyconn/activity/listener_bind.go new file mode 100644 index 000000000..792d04215 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind.go @@ -0,0 +1,127 @@ +package activity + +import ( + "fmt" + "net" + "net/netip" + "sync" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +type bindProvider interface { + GetBind() device.EndpointManager +} + +const ( + // lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers. + // The actual routing is done via fakeIP in ICEBind, not by this port. + lazyBindPort = 17473 +) + +// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode. +type BindListener struct { + wgIface WgInterface + peerCfg lazyconn.PeerConfig + done sync.WaitGroup + + lazyConn *lazyConn + bind device.EndpointManager + fakeIP netip.Addr +} + +// NewBindListener creates a listener that passes data directly through bind using LazyConn. +// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range. +func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) { + fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs) + if err != nil { + return nil, fmt.Errorf("derive fake IP: %w", err) + } + + d := &BindListener{ + wgIface: wgIface, + peerCfg: cfg, + bind: bind, + fakeIP: fakeIP, + } + + if err := d.setupLazyConn(); err != nil { + return nil, fmt.Errorf("setup lazy connection: %v", err) + } + + d.done.Add(1) + return d, nil +} + +// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. +// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). +// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. +func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { + if len(allowedIPs) == 0 { + return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") + } + + ourNetwork := wgIface.Address().Network + + var peerIP netip.Addr + for _, allowedIP := range allowedIPs { + ip := allowedIP.Addr() + if !ip.Is4() { + continue + } + if ourNetwork.Contains(ip) { + peerIP = ip + break + } + } + + if !peerIP.IsValid() { + return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") + } + + octets := peerIP.As4() + fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) + return fakeIP, nil +} + +func (d *BindListener) setupLazyConn() error { + d.lazyConn = newLazyConn() + d.bind.SetEndpoint(d.fakeIP, d.lazyConn) + + endpoint := &net.UDPAddr{ + IP: d.fakeIP.AsSlice(), + Port: lazyBindPort, + } + return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil) +} + +// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed. +func (d *BindListener) ReadPackets() { + select { + case <-d.lazyConn.ActivityChan(): + d.peerCfg.Log.Infof("activity detected via LazyConn") + case <-d.lazyConn.ctx.Done(): + d.peerCfg.Log.Infof("exit from activity listener") + } + + d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey) + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { + d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) + } + + _ = d.lazyConn.Close() + d.bind.RemoveEndpoint(d.fakeIP) + d.done.Done() +} + +// Close stops the listener and cleans up resources. +func (d *BindListener) Close() { + d.peerCfg.Log.Infof("closing activity listener (LazyConn)") + + if err := d.lazyConn.Close(); err != nil { + d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err) + } + + d.done.Wait() +} diff --git a/client/internal/lazyconn/activity/listener_bind_test.go b/client/internal/lazyconn/activity/listener_bind_test.go new file mode 100644 index 000000000..f86dd3877 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind_test.go @@ -0,0 +1,291 @@ +package activity + +import ( + "net" + "net/netip" + "runtime" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/internal/lazyconn" + peerid "github.com/netbirdio/netbird/client/internal/peer/id" +) + +func isBindListenerPlatform() bool { + return runtime.GOOS == "windows" || runtime.GOOS == "js" +} + +// mockEndpointManager implements device.EndpointManager for testing +type mockEndpointManager struct { + endpoints map[netip.Addr]net.Conn +} + +func newMockEndpointManager() *mockEndpointManager { + return &mockEndpointManager{ + endpoints: make(map[netip.Addr]net.Conn), + } +} + +func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { + m.endpoints[fakeIP] = conn +} + +func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) { + delete(m.endpoints, fakeIP) +} + +func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn { + return m.endpoints[fakeIP] +} + +// MockWGIfaceBind mocks WgInterface with bind support +type MockWGIfaceBind struct { + endpointMgr *mockEndpointManager +} + +func (m *MockWGIfaceBind) RemovePeer(string) error { + return nil +} + +func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { + return nil +} + +func (m *MockWGIfaceBind) IsUserspaceBind() bool { + return true +} + +func (m *MockWGIfaceBind) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +func (m *MockWGIfaceBind) GetBind() device.EndpointManager { + return m.endpointMgr +} + +func TestBindListener_Creation(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + expectedFakeIP := netip.MustParseAddr("127.2.0.2") + conn := mockEndpointMgr.GetEndpoint(expectedFakeIP) + require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager") + + _, ok := conn.(*lazyConn) + assert.True(t, ok, "Registered endpoint should be a lazyConn") + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestBindListener_ActivityDetection(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + fakeIP := listener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection") +} + +func TestBindListener_Close(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + fakeIP := listener.fakeIP + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close") +} + +func TestManager_BindMode(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + err := mgr.MonitorPeerActivity(cfg) + require.NoError(t, err) + + listener, exists := mgr.GetPeerListener(cfg.PeerConnID) + require.True(t, exists, "Peer listener should be found") + + bindListener, ok := listener.(*BindListener) + require.True(t, ok, "Listener should be BindListener, got %T", listener) + + fakeIP := bindListener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case peerConnID := <-mgr.OnActivityChan: + assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notification") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity") +} + +func TestManager_BindMode_MultiplePeers(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer1 := &MocPeer{PeerID: "testPeer1"} + peer2 := &MocPeer{PeerID: "testPeer2"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg1 := lazyconn.PeerConfig{ + PublicKey: peer1.PeerID, + PeerConnID: peer1.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + cfg2 := lazyconn.PeerConfig{ + PublicKey: peer2.PeerID, + PeerConnID: peer2.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")}, + Log: log.WithField("peer", "testPeer2"), + } + + err := mgr.MonitorPeerActivity(cfg1) + require.NoError(t, err) + + err = mgr.MonitorPeerActivity(cfg2) + require.NoError(t, err) + + listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID) + require.True(t, exists, "Peer1 listener should be found") + bindListener1 := listener1.(*BindListener) + + listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID) + require.True(t, exists, "Peer2 listener should be found") + bindListener2 := listener2.(*BindListener) + + conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP) + require.NotNil(t, conn1, "Peer1 endpoint should be registered") + _, err = conn1.Write([]byte{0x01}) + require.NoError(t, err) + + conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP) + require.NotNil(t, conn2, "Peer2 endpoint should be registered") + _, err = conn2.Write([]byte{0x02}) + require.NoError(t, err) + + receivedPeers := make(map[peerid.ConnID]bool) + for i := 0; i < 2; i++ { + select { + case peerConnID := <-mgr.OnActivityChan: + receivedPeers[peerConnID] = true + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notifications") + } + } + + assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received") + assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received") +} diff --git a/client/internal/lazyconn/activity/listener_test.go b/client/internal/lazyconn/activity/listener_test.go deleted file mode 100644 index 98d7838d2..000000000 --- a/client/internal/lazyconn/activity/listener_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package activity - -import ( - "testing" - "time" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/lazyconn" -) - -func TestNewListener(t *testing.T) { - peer := &MocPeer{ - PeerID: "examplePublicKey1", - } - - cfg := lazyconn.PeerConfig{ - PublicKey: peer.PeerID, - PeerConnID: peer.ConnID(), - Log: log.WithField("peer", "examplePublicKey1"), - } - - l, err := NewListener(MocWGIface{}, cfg) - if err != nil { - t.Fatalf("failed to create listener: %v", err) - } - - chanClosed := make(chan struct{}) - go func() { - defer close(chanClosed) - l.ReadPackets() - }() - - time.Sleep(1 * time.Second) - l.Close() - - select { - case <-chanClosed: - case <-time.After(time.Second): - } -} diff --git a/client/internal/lazyconn/activity/listener.go b/client/internal/lazyconn/activity/listener_udp.go similarity index 64% rename from client/internal/lazyconn/activity/listener.go rename to client/internal/lazyconn/activity/listener_udp.go index 817ff00c3..e0b09be6c 100644 --- a/client/internal/lazyconn/activity/listener.go +++ b/client/internal/lazyconn/activity/listener_udp.go @@ -11,26 +11,27 @@ import ( "github.com/netbirdio/netbird/client/internal/lazyconn" ) -// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking -type Listener struct { +// UDPListener uses UDP sockets for activity detection in kernel mode. +type UDPListener struct { wgIface WgInterface peerCfg lazyconn.PeerConfig conn *net.UDPConn endpoint *net.UDPAddr done sync.Mutex - isClosed atomic.Bool // use to avoid error log when closing the listener + isClosed atomic.Bool } -func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) { - d := &Listener{ +// NewUDPListener creates a listener that detects activity via UDP socket reads. +func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) { + d := &UDPListener{ wgIface: wgIface, peerCfg: cfg, } conn, err := d.newConn() if err != nil { - return nil, fmt.Errorf("failed to creating activity listener: %v", err) + return nil, fmt.Errorf("create UDP connection: %v", err) } d.conn = conn d.endpoint = conn.LocalAddr().(*net.UDPAddr) @@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error if err := d.createEndpoint(); err != nil { return nil, err } + d.done.Lock() - cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String()) + cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String()) return d, nil } -func (d *Listener) ReadPackets() { +// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed. +func (d *UDPListener) ReadPackets() { for { n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1)) if err != nil { @@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() { } d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String()) - if err := d.removeEndpoint(); err != nil { + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) } - _ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection" + // Ignore close error as it may return "use of closed network connection" if already closed. + _ = d.conn.Close() d.done.Unlock() } -func (d *Listener) Close() { +// Close stops the listener and cleans up resources. +func (d *UDPListener) Close() { d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String()) d.isClosed.Store(true) @@ -82,16 +87,12 @@ func (d *Listener) Close() { d.done.Lock() } -func (d *Listener) removeEndpoint() error { - return d.wgIface.RemovePeer(d.peerCfg.PublicKey) -} - -func (d *Listener) createEndpoint() error { +func (d *UDPListener) createEndpoint() error { d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String()) return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil) } -func (d *Listener) newConn() (*net.UDPConn, error) { +func (d *UDPListener) newConn() (*net.UDPConn, error) { addr := &net.UDPAddr{ Port: 0, IP: listenIP, diff --git a/client/internal/lazyconn/activity/listener_udp_test.go b/client/internal/lazyconn/activity/listener_udp_test.go new file mode 100644 index 000000000..d2adb9bf4 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_udp_test.go @@ -0,0 +1,110 @@ +package activity + +import ( + "net" + "net/netip" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +func TestUDPListener_Creation(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + require.NotNil(t, listener.conn) + require.NotNil(t, listener.endpoint) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestUDPListener_ActivityDetection(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + conn, err := net.Dial("udp", listener.conn.LocalAddr().String()) + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } +} + +func TestUDPListener_Close(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed") +} diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index 915fb9cb8..db283ec9a 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -1,21 +1,32 @@ package activity import ( + "errors" "net" "net/netip" + "runtime" "sync" "time" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) +// listener defines the contract for activity detection listeners. +type listener interface { + ReadPackets() + Close() +} + type WgInterface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + IsUserspaceBind() bool + Address() wgaddr.Address } type Manager struct { @@ -23,7 +34,7 @@ type Manager struct { wgIface WgInterface - peers map[peerid.ConnID]*Listener + peers map[peerid.ConnID]listener done chan struct{} mu sync.Mutex @@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager { m := &Manager{ OnActivityChan: make(chan peerid.ConnID, 1), wgIface: wgIface, - peers: make(map[peerid.ConnID]*Listener), + peers: make(map[peerid.ConnID]listener), done: make(chan struct{}), } return m @@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error { return nil } - listener, err := NewListener(m.wgIface, peerCfg) + listener, err := m.createListener(peerCfg) if err != nil { return err } - m.peers[peerCfg.PeerConnID] = listener + m.peers[peerCfg.PeerConnID] = listener go m.waitForTraffic(listener, peerCfg.PeerConnID) return nil } +func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) { + if !m.wgIface.IsUserspaceBind() { + return NewUDPListener(m.wgIface, peerCfg) + } + + // BindListener is only used on Windows and JS platforms: + // - JS: Cannot listen to UDP sockets + // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default + // gateway points to, preventing them from reaching the loopback interface. + // BindListener bypasses this by passing data directly through the bind. + if runtime.GOOS != "windows" && runtime.GOOS != "js" { + return NewUDPListener(m.wgIface, peerCfg) + } + + provider, ok := m.wgIface.(bindProvider) + if !ok { + return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider") + } + + return NewBindListener(m.wgIface, provider.GetBind(), peerCfg) +} + func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) { m.mu.Lock() defer m.mu.Unlock() @@ -82,8 +115,8 @@ func (m *Manager) Close() { } } -func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) { - listener.ReadPackets() +func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) { + l.ReadPackets() m.mu.Lock() if _, ok := m.peers[peerConnID]; !ok { diff --git a/client/internal/lazyconn/activity/manager_test.go b/client/internal/lazyconn/activity/manager_test.go index ae6c31da4..0768d9219 100644 --- a/client/internal/lazyconn/activity/manager_test.go +++ b/client/internal/lazyconn/activity/manager_test.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) @@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error { func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { return nil - } -// Add this method to the Manager struct -func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) { +func (m MocWGIface) IsUserspaceBind() bool { + return false +} + +func (m MocWGIface) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +// GetPeerListener is a test helper to access listeners +func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) { m.mu.Lock() defer m.mu.Unlock() - listener, exists := m.peers[peerConnID] - return listener, exists + l, exists := m.peers[peerConnID] + return l, exists } func TestManager_MonitorPeerActivity(t *testing.T) { @@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) { t.Fatalf("peer listener not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + // Get the UDP listener's address for triggering + udpListener, ok := listener.(*UDPListener) + if !ok { + t.Fatalf("expected UDPListener") + } + if err := trigger(udpListener.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) { t.Fatalf("failed to monitor peer activity: %v", err) } - addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String() + listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID) + udpListener, _ := listener.(*UDPListener) + addr := udpListener.conn.LocalAddr().String() mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID) @@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer1 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener1, _ := listener.(*UDPListener) + if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer2 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener2, _ := listener.(*UDPListener) + if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } diff --git a/client/internal/lazyconn/wgiface.go b/client/internal/lazyconn/wgiface.go index 0351904f7..0626c1815 100644 --- a/client/internal/lazyconn/wgiface.go +++ b/client/internal/lazyconn/wgiface.go @@ -7,6 +7,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/monotime" ) @@ -14,5 +15,6 @@ type WGIface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error IsUserspaceBind() bool + Address() wgaddr.Address LastActivities() map[string]monotime.Time } From 3abae0bd170c95b79378ab1d59ded7fb83ca985e Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 16 Oct 2025 16:16:51 +0200 Subject: [PATCH 028/120] [client] Set default wg port for new profiles (#4651) --- client/internal/profilemanager/config.go | 1 + client/internal/profilemanager/config_test.go | 92 +++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 4e6b422f6..f03822089 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -195,6 +195,7 @@ func createNewConfig(input ConfigInput) (*Config, error) { config := &Config{ // defaults to false only for new (post 0.26) configurations ServerSSHAllowed: util.False(), + WgPort: iface.DefaultWgPort, } if _, err := config.apply(input); err != nil { diff --git a/client/internal/profilemanager/config_test.go b/client/internal/profilemanager/config_test.go index 45e37bf0e..90bde7707 100644 --- a/client/internal/profilemanager/config_test.go +++ b/client/internal/profilemanager/config_test.go @@ -5,11 +5,14 @@ import ( "errors" "os" "path/filepath" + "runtime" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/util" ) @@ -141,6 +144,95 @@ func TestHiddenPreSharedKey(t *testing.T) { } } +func TestNewProfileDefaults(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + }) + require.NoError(t, err, "should create new config") + + assert.Equal(t, DefaultManagementURL, config.ManagementURL.String(), "ManagementURL should have default") + assert.Equal(t, DefaultAdminURL, config.AdminURL.String(), "AdminURL should have default") + assert.NotEmpty(t, config.PrivateKey, "PrivateKey should be generated") + assert.NotEmpty(t, config.SSHKey, "SSHKey should be generated") + assert.Equal(t, iface.WgInterfaceDefault, config.WgIface, "WgIface should have default") + assert.Equal(t, iface.DefaultWgPort, config.WgPort, "WgPort should default to 51820") + assert.Equal(t, uint16(iface.DefaultMTU), config.MTU, "MTU should have default") + assert.Equal(t, dynamic.DefaultInterval, config.DNSRouteInterval, "DNSRouteInterval should have default") + assert.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set") + assert.NotNil(t, config.DisableNotifications, "DisableNotifications should be set") + assert.NotEmpty(t, config.IFaceBlackList, "IFaceBlackList should have defaults") + + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + assert.NotNil(t, config.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS") + assert.True(t, *config.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS") + } +} + +func TestWireguardPortZeroExplicit(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + // Create a new profile with explicit port 0 (random port) + explicitZero := 0 + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + WireguardPort: &explicitZero, + }) + require.NoError(t, err, "should create config with explicit port 0") + + assert.Equal(t, 0, config.WgPort, "WgPort should be 0 when explicitly set by user") + + // Verify it persists + readConfig, err := GetConfig(configPath) + require.NoError(t, err) + assert.Equal(t, 0, readConfig.WgPort, "WgPort should remain 0 after reading from file") +} + +func TestWireguardPortDefaultVsExplicit(t *testing.T) { + tests := []struct { + name string + wireguardPort *int + expectedPort int + description string + }{ + { + name: "no port specified uses default", + wireguardPort: nil, + expectedPort: iface.DefaultWgPort, + description: "When user doesn't specify port, default to 51820", + }, + { + name: "explicit zero for random port", + wireguardPort: func() *int { v := 0; return &v }(), + expectedPort: 0, + description: "When user explicitly sets 0, use 0 for random port", + }, + { + name: "explicit custom port", + wireguardPort: func() *int { v := 52000; return &v }(), + expectedPort: 52000, + description: "When user sets custom port, use that port", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + WireguardPort: tt.wireguardPort, + }) + require.NoError(t, err, tt.description) + assert.Equal(t, tt.expectedPort, config.WgPort, tt.description) + }) + } +} + func TestUpdateOldManagementURL(t *testing.T) { tests := []struct { name string From af95aabb0375f51872a4ff69dd12f412d716955e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 16 Oct 2025 17:15:39 +0200 Subject: [PATCH 029/120] Handle the case when the service has already been down and the status recorder is not available (#4652) --- client/internal/debug/wgshow.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client/internal/debug/wgshow.go b/client/internal/debug/wgshow.go index e4b4c2368..8233ca510 100644 --- a/client/internal/debug/wgshow.go +++ b/client/internal/debug/wgshow.go @@ -14,6 +14,9 @@ type WGIface interface { } func (g *BundleGenerator) addWgShow() error { + if g.statusRecorder == nil { + return fmt.Errorf("no status recorder available for wg show") + } result, err := g.statusRecorder.PeersStatus() if err != nil { return err From 3cdb10cde765621086d15e97da150fee884356f6 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 17 Oct 2025 11:09:39 +0200 Subject: [PATCH 030/120] [client] Remove rule squashing (#4653) --- client/firewall/iptables/acl_linux.go | 1 - client/internal/acl/manager.go | 157 +-------- client/internal/acl/manager_test.go | 486 -------------------------- 3 files changed, 3 insertions(+), 641 deletions(-) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index ed8a7403b..d78372c9e 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi return "" } - // Include action in the ipset name to prevent squashing rules with different actions actionSuffix := "" if action == firewall.ActionDrop { actionSuffix = "-drop" diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 5ca950297..965decc73 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -29,11 +29,6 @@ type Manager interface { ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) } -type protoMatch struct { - ips map[string]int - policyID []byte -} - // DefaultManager uses firewall manager to handle type DefaultManager struct { firewall firewall.Manager @@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout } func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { - rules, squashedProtocols := d.squashAcceptRules(networkMap) + rules := networkMap.FirewallRules enableSSH := networkMap.PeerConfig != nil && networkMap.PeerConfig.SshConfig != nil && networkMap.PeerConfig.SshConfig.SshEnabled - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - enableSSH = enableSSH && !ok - } - if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { - enableSSH = enableSSH && !ok - } - // if TCP protocol rules not squashed and SSH enabled - // we add default firewall rule which accepts connection to any peer - // in the network by SSH (TCP 22 port). + // If SSH enabled, add default firewall rule which accepts connection to any peer + // in the network by SSH (TCP port defined by ssh.DefaultSSHPort). if enableSSH { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", @@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID( return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) } -// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type -// to all peers in the network map to one rule which just accepts that type of the traffic. -// -// NOTE: It will not squash two rules for same protocol if one covers all peers in the network, -// but other has port definitions or has drop policy. -func (d *DefaultManager) squashAcceptRules( - networkMap *mgmProto.NetworkMap, -) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { - totalIPs := 0 - for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { - for range p.AllowedIps { - totalIPs++ - } - } - - in := map[mgmProto.RuleProtocol]*protoMatch{} - out := map[mgmProto.RuleProtocol]*protoMatch{} - - // trace which type of protocols was squashed - squashedRules := []*mgmProto.FirewallRule{} - squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} - - // this function we use to do calculation, can we squash the rules by protocol or not. - // We summ amount of Peers IP for given protocol we found in original rules list. - // But we zeroed the IP's for protocol if: - // 1. Any of the rule has DROP action type. - // 2. Any of rule contains Port. - // - // We zeroed this to notify squash function that this protocol can't be squashed. - addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { - hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP || - r.Port != "" || !portInfoEmpty(r.PortInfo) - - if hasPortRestrictions { - // Don't squash rules with port restrictions - protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} - return - } - - if _, ok := protocols[r.Protocol]; !ok { - protocols[r.Protocol] = &protoMatch{ - ips: map[string]int{}, - // store the first encountered PolicyID for this protocol - policyID: r.PolicyID, - } - } - - // special case, when we receive this all network IP address - // it means that rules for that protocol was already optimized on the - // management side - if r.PeerIP == "0.0.0.0" { - squashedRules = append(squashedRules, r) - squashedProtocols[r.Protocol] = struct{}{} - return - } - - ipset := protocols[r.Protocol].ips - - if _, ok := ipset[r.PeerIP]; ok { - return - } - ipset[r.PeerIP] = i - } - - for i, r := range networkMap.FirewallRules { - // calculate squash for different directions - if r.Direction == mgmProto.RuleDirection_IN { - addRuleToCalculationMap(i, r, in) - } else { - addRuleToCalculationMap(i, r, out) - } - } - - // order of squashing by protocol is important - // only for their first element ALL, it must be done first - protocolOrders := []mgmProto.RuleProtocol{ - mgmProto.RuleProtocol_ALL, - mgmProto.RuleProtocol_ICMP, - mgmProto.RuleProtocol_TCP, - mgmProto.RuleProtocol_UDP, - } - - squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) { - for _, protocol := range protocolOrders { - match, ok := matches[protocol] - if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 { - // don't squash if : - // 1. Rules not cover all peers in the network - // 2. Rules cover only one peer in the network. - continue - } - - // add special rule 0.0.0.0 which allows all IP's in our firewall implementations - squashedRules = append(squashedRules, &mgmProto.FirewallRule{ - PeerIP: "0.0.0.0", - Direction: direction, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: protocol, - PolicyID: match.policyID, - }) - squashedProtocols[protocol] = struct{}{} - - if protocol == mgmProto.RuleProtocol_ALL { - // if we have ALL traffic type squashed rule - // it allows all other type of traffic, so we can stop processing - break - } - } - } - - squash(in, mgmProto.RuleDirection_IN) - squash(out, mgmProto.RuleDirection_OUT) - - // if all protocol was squashed everything is allow and we can ignore all other rules - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - return squashedRules, squashedProtocols - } - - if len(squashedRules) == 0 { - return networkMap.FirewallRules, squashedProtocols - } - - var rules []*mgmProto.FirewallRule - // filter out rules which was squashed from final list - // if we also have other not squashed rules. - for i, r := range networkMap.FirewallRules { - if _, ok := squashedProtocols[r.Protocol]; ok { - if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } - } - rules = append(rules, r) - } - - return append(rules, squashedRules...), squashedProtocols -} - // getRuleGroupingSelector takes all rule properties except IP address to build selector func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string { return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 664476ef4..daf4979ce 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) { }) } -func TestDefaultManagerSquashRules(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, 2, len(rules)) - - r := rules[0] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) - - r = rules[1] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) -} - -func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, len(networkMap.FirewallRules), len(rules)) -} - -func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) { - tests := []struct { - name string - rules []*mgmProto.FirewallRule - expectedCount int - description string - }{ - { - name: "should not squash rules with port ranges", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with port ranges should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with specific ports", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with specific ports should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with legacy port field", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - }, - expectedCount: 4, - description: "Rules with legacy port field should not be squashed", - }, - { - name: "should not squash rules with DROP action", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "Rules with DROP action should not be squashed", - }, - { - name: "should squash rules without port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 1, - description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule", - }, - { - name: "mixed rules should not squash protocol with port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "TCP should not be squashed because one rule has port restrictions", - }, - { - name: "should squash UDP but not TCP when TCP has port restrictions", - rules: []*mgmProto.FirewallRule{ - // TCP rules with port restrictions - should NOT be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - // UDP rules without port restrictions - SHOULD be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0) - description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: tt.rules, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - - assert.Equal(t, tt.expectedCount, len(rules), tt.description) - - // For squashed rules, verify we get the expected 0.0.0.0 rule - if tt.expectedCount == 1 { - assert.Equal(t, "0.0.0.0", rules[0].PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action) - } - }) - } -} - func TestPortInfoEmpty(t *testing.T) { tests := []struct { name string From 429d7d65855bec44481e572c83e1429ad0c50ebd Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 17 Oct 2025 11:10:16 +0200 Subject: [PATCH 031/120] [client] Support BROWSER env for login (#4654) --- client/cmd/login.go | 11 ++++++++++- client/ui/client_ui.go | 7 +++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index 3ac211805..40b55f858 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "os/exec" "os/user" "runtime" "strings" @@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro cmd.Println("") if !noBrowser { - if err := open.Run(verificationURIComplete); err != nil { + if err := openBrowser(verificationURIComplete); err != nil { cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } } +// openBrowser opens the URL in a browser, respecting the BROWSER environment variable. +func openBrowser(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + return open.Run(url) +} + // isUnixRunningDesktop checks if a Linux OS is running desktop environment func isUnixRunningDesktop() bool { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 7c2000a9d..0043f228e 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -31,7 +31,6 @@ import ( "fyne.io/systray" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -633,7 +632,7 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { } func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { - err := open.Run(loginResp.VerificationURIComplete) + err := openURL(loginResp.VerificationURIComplete) if err != nil { log.Errorf("opening the verification uri in the browser failed: %v", err) return err @@ -1487,6 +1486,10 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { } func openURL(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + var err error switch runtime.GOOS { case "windows": From f5301230bfef5d8b3abaf60634646793fa7b63ac Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 17 Oct 2025 13:31:15 +0200 Subject: [PATCH 032/120] [client] Fix status showing P2P without connection (#4661) --- client/status/status.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/client/status/status.go b/client/status/status.go index db5b7dc0b..5e4fcd8dc 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -205,15 +205,18 @@ func mapPeers( localICEEndpoint := "" remoteICEEndpoint := "" relayServerAddress := "" - connType := "P2P" + connType := "-" lastHandshake := time.Time{} transferReceived := int64(0) transferSent := int64(0) isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() - if pbPeerState.Relayed { - connType = "Relayed" + if isPeerConnected { + connType = "P2P" + if pbPeerState.Relayed { + connType = "Relayed" + } } if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) { From 0f9bfeff7cf36c81086c8dc91136c9194dbaf819 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 17 Oct 2025 14:47:11 -0300 Subject: [PATCH 033/120] [client] Security upgrade alpine from 3.22.0 to 3.22.2 #4618 --- client/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/Dockerfile b/client/Dockerfile index b2f627409..5cd459357 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -4,7 +4,7 @@ # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest -FROM alpine:3.22.0 +FROM alpine:3.22.2 # iproute2: busybox doesn't display ip rules properly RUN apk add --no-cache \ bash \ From cd9a867ad02034570daa415fce6fe1dd47c1ccb8 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Fri, 17 Oct 2025 19:48:26 +0200 Subject: [PATCH 034/120] [client] Delete TURNConfig section from script (#4639) --- infrastructure_files/getting-started-with-zitadel.sh | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index bc326cd7e..09c5225ad 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -682,17 +682,6 @@ renderManagementJson() { "URI": "stun:$NETBIRD_DOMAIN:3478" } ], - "TURNConfig": { - "Turns": [ - { - "Proto": "udp", - "URI": "turn:$NETBIRD_DOMAIN:3478", - "Username": "$TURN_USER", - "Password": "$TURN_PASSWORD" - } - ], - "TimeBasedCredentials": false - }, "Relay": { "Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"], "CredentialsTTL": "24h", From 2fe2af38d29084f9bf226a83c8aecf7440c633bd Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:14:39 +0200 Subject: [PATCH 035/120] [client] Clean up match domain reg entries between config changes (#4676) --- client/internal/dns/host_windows.go | 12 ++- client/internal/dns/host_windows_test.go | 102 ++++++++++++++++++ .../internal/winregistry/volatile_windows.go | 59 ++++++++++ 3 files changed, 168 insertions(+), 5 deletions(-) create mode 100644 client/internal/dns/host_windows_test.go create mode 100644 client/internal/winregistry/volatile_windows.go diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index a14a01f40..74111d335 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -17,6 +17,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/winregistry" ) var ( @@ -197,6 +198,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) } + if err := r.removeDNSMatchPolicies(); err != nil { + log.Errorf("cleanup old dns match policies: %s", err) + } + if len(matchDomains) != 0 { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) if err != nil { @@ -204,9 +209,6 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager } r.nrptEntryCount = count } else { - if err := r.removeDNSMatchPolicies(); err != nil { - return fmt.Errorf("remove dns match policies: %w", err) - } r.nrptEntryCount = 0 } @@ -273,9 +275,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s return fmt.Errorf("remove existing dns policy: %w", err) } - regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) + regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) if err != nil { - return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) + return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) } defer closer(regKey) diff --git a/client/internal/dns/host_windows_test.go b/client/internal/dns/host_windows_test.go new file mode 100644 index 000000000..19496bf5a --- /dev/null +++ b/client/internal/dns/host_windows_test.go @@ -0,0 +1,102 @@ +package dns + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows/registry" +) + +// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up +// when the number of match domains decreases between configuration changes. +func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) { + if testing.Short() { + t.Skip("skipping registry integration test in short mode") + } + + defer cleanupRegistryKeys(t) + cleanupRegistryKeys(t) + + testIP := netip.MustParseAddr("100.64.0.1") + + // Create a test interface registry key so updateSearchDomains doesn't fail + testGUID := "{12345678-1234-1234-1234-123456789ABC}" + interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID + testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE) + require.NoError(t, err, "Should create test interface registry key") + testKey.Close() + defer func() { + _ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath) + }() + + cfg := ®istryConfigurator{ + guid: testGUID, + gpo: false, + } + + config5 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + {Domain: "domain3.com", MatchOnly: true}, + {Domain: "domain4.com", MatchOnly: true}, + {Domain: "domain5.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config5, nil) + require.NoError(t, err) + + // Verify all 5 entries exist + for i := 0; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after first config", i) + } + + config2 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config2, nil) + require.NoError(t, err) + + // Verify first 2 entries exist + for i := 0; i < 2; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after second config", i) + } + + // Verify entries 2-4 are cleaned up + for i := 2; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i) + } +} + +func registryKeyExists(path string) (bool, error) { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + if err == registry.ErrNotExist { + return false, nil + } + return false, err + } + k.Close() + return true, nil +} + +func cleanupRegistryKeys(*testing.T) { + cfg := ®istryConfigurator{nrptEntryCount: 10} + _ = cfg.removeDNSMatchPolicies() +} diff --git a/client/internal/winregistry/volatile_windows.go b/client/internal/winregistry/volatile_windows.go new file mode 100644 index 000000000..a8e350fe7 --- /dev/null +++ b/client/internal/winregistry/volatile_windows.go @@ -0,0 +1,59 @@ +package winregistry + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows/registry" +) + +var ( + advapi = syscall.NewLazyDLL("advapi32.dll") + regCreateKeyExW = advapi.NewProc("RegCreateKeyExW") +) + +const ( + // Registry key options + regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted + regOptionVolatile = 0x1 // Key is not preserved when system is rebooted + + // Registry disposition values + regCreatedNewKey = 0x1 + regOpenedExistingKey = 0x2 +) + +// CreateVolatileKey creates a volatile registry key named path under open key root. +// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed. +// The access parameter specifies the access rights for the key to be created. +// +// Volatile keys are stored in memory and are automatically deleted when the system is shut down. +// This provides automatic cleanup without requiring manual registry maintenance. +func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) { + pathPtr, err := syscall.UTF16PtrFromString(path) + if err != nil { + return 0, false, err + } + + var ( + handle syscall.Handle + disposition uint32 + ) + + ret, _, _ := regCreateKeyExW.Call( + uintptr(root), + uintptr(unsafe.Pointer(pathPtr)), + 0, // reserved + 0, // class + uintptr(regOptionVolatile), // options - volatile key + uintptr(access), // desired access + 0, // security attributes + uintptr(unsafe.Pointer(&handle)), + uintptr(unsafe.Pointer(&disposition)), + ) + + if ret != 0 { + return 0, false, syscall.Errno(ret) + } + + return registry.Key(handle), disposition == regOpenedExistingKey, nil +} From 96f71ff1e15b50eff87a7052432537dad2718b20 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 21 Oct 2025 19:23:11 +0200 Subject: [PATCH 036/120] [misc] Update tag name extraction in install.sh (#4677) --- release_files/install.sh | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/release_files/install.sh b/release_files/install.sh index 5d5349ec4..6a2c5f458 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then NETBIRD_RELEASE=latest fi +TAG_NAME="" + get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then @@ -38,17 +40,19 @@ get_release() { local TAG="tags/${RELEASE}" local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi + OUTPUT="" if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}") else - curl -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -s "${URL}") fi + TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1) + echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+' } download_release_binary() { VERSION=$(get_release "$NETBIRD_RELEASE") + echo "Using the following tag name for binary installation: ${TAG_NAME}" BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" From d80d47a469ab7d1d740d28ec836d65dc7776beec Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 22 Oct 2025 12:46:22 +0300 Subject: [PATCH 037/120] [management] Add peer disapproval reason (#4468) --- go.mod | 2 +- go.sum | 4 +- management/server/account/manager.go | 2 +- .../http/handlers/peers/peers_handler.go | 38 ++++++++++++------- management/server/integrated_validator.go | 24 +++++++++--- .../integrated_validator/interface.go | 1 + management/server/mock_server/account_mock.go | 6 +-- shared/management/http/api/openapi.yml | 3 ++ shared/management/http/api/types.gen.go | 6 +++ 9 files changed, 61 insertions(+), 25 deletions(-) diff --git a/go.mod b/go.mod index 31b45e881..79dd92e6b 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f + github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 6b0b298a7..f0065e081 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 h1:aXHS63QWf0Z5fDN19Swl6npdJjGMyXthAvvgW7rbKJQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/server/account/manager.go b/management/server/account/manager.go index a1ed9498b..fe9fb25c6 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -109,7 +109,7 @@ type Manager interface { GetIdpManager() idp.Manager UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 4b33495de..df89c616c 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } _, valid := validPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid)) + reason := invalidPeers[peer.ID] + + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { @@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(ctx).Errorf("failed to get validated peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } _, valid := validPeers[peer.ID] + reason := invalidPeers[peer.ID] - util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { @@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { - log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - h.setApprovalRequiredFlag(respBody, validPeersMap) + h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap) util.WriteJSONObject(r.Context(), w, respBody) } -func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) { for _, peer := range respBody { - _, ok := approvedPeersMap[peer.Id] + _, ok := validPeersMap[peer.Id] if !ok { peer.ApprovalRequired = true + + reason := invalidPeersMap[peer.Id] + peer.DisapprovalReason = &reason } } } @@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } } - validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core } - return &api.Peer{ + apiPeer := &api.Peer{ CreatedAt: peer.CreatedAt, Id: peer.ID, Name: peer.Name, @@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, } + + if !approved { + apiPeer.DisapprovalReason = &reason + } + + return apiPeer } func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch { diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 251c04273..e9a1c8701 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { var err error var groups []*types.Group var peers []*nbpeer.Peer @@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") if err != nil { - return nil, err + return nil, nil, err } settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } - return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + if err != nil { + return nil, nil, err + } + + invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra) + if err != nil { + return nil, nil, err + } + + return validPeers, invalidPeers, nil } type MockIntegratedValidator struct { @@ -136,6 +146,10 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID return validatedPeers, nil } +func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) { + return make(map[string]string), nil +} + func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { return peer } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index be05c2527..26c338cb6 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -15,6 +15,7 @@ type IntegratedValidator interface { PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) + GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error SetPeerInvalidationListener(fn func(accountID string, peerIDs []string)) Stop(ctx context.Context) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d160e7269..e87043f26 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -189,17 +189,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { account, err := am.GetAccountFunc(ctx, accountID) if err != nil { - return nil, err + return nil, nil, err } approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} } - return approvedPeers, nil + return approvedPeers, nil, nil } // GetGroup mock implementation of GetGroup from server.AccountManager interface diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 93578b1ae..4a5454002 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -463,6 +463,9 @@ components: description: (Cloud only) Indicates whether peer needs approval type: boolean example: true + disapproval_reason: + description: (Cloud only) Reason why the peer requires approval + type: string country_code: $ref: '#/components/schemas/CountryCode' city_name: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 3dbb32ef6..9611d26d6 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1037,6 +1037,9 @@ type Peer struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1124,6 +1127,9 @@ type PeerBatch struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` From 6654e2dbf7c3666a432d7ff457eb3cc5dfdc59aa Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 23 Oct 2025 17:07:52 +0200 Subject: [PATCH 038/120] [client] Fix active profile name in debug bundle (#4689) --- client/cmd/debug.go | 8 +++++++- client/internal/debug/debug.go | 4 +++- client/ui/debug.go | 19 ++++++++++++++++--- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 18f3547ca..d53c5f06b 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -307,8 +307,14 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string { if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), + nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName), ) } return statusOutputString diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index ec920c5f3..442f54e71 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. -state.json: Anonymized client state dump containing netbird states. +state.json: Anonymized client state dump containing netbird states for the active profile. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. block.prof: Block profiling information. @@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error { return nil } + log.Debugf("Adding state file from: %s", path) + data, err := os.ReadFile(path) if err != nil { if errors.Is(err, fs.ErrNotExist) { diff --git a/client/ui/debug.go b/client/ui/debug.go index 76afc7753..bf9839dda 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -18,6 +18,7 @@ import ( "github.com/skratchdot/open-golang/open" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" uptypes "github.com/netbirdio/netbird/upload-server/types" @@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData( return "", err } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("Failed to get post-up status: %v", err) @@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) @@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", @@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa return nil, fmt.Errorf("get client: %v", err) } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("failed to get status for debug bundle: %v", err) @@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName) statusOutput = nbstatus.ParseToFullDetailSummary(overview) } From 709e24eb6f7b4beff4da86617bbcf518434b8764 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 24 Oct 2025 15:40:20 +0300 Subject: [PATCH 039/120] [signal] Fix HTTP/WebSocket proxy not using custom certificates (#4644) This pull request fixes a bug where the HTTP/WebSocket proxy server was not using custom TLS certificates when provided via --cert-file and --cert-key flags. Previously, only the gRPC server had TLS enabled with custom certificates, while the HTTP/WebSocket proxy ran without TLS. --- infrastructure_files/configure.sh | 5 ++++- infrastructure_files/docker-compose.yml.tmpl | 8 +++++++ signal/cmd/run.go | 23 ++++++++++---------- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index e3fcbfdde..92252d0b3 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" echo "" - export NETBIRD_SIGNAL_PROTOCOL="https" unset NETBIRD_LETSENCRYPT_DOMAIN unset NETBIRD_MGMT_API_CERT_FILE unset NETBIRD_MGMT_API_CERT_KEY_FILE fi +if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then + export NETBIRD_SIGNAL_PROTOCOL="https" +fi + # Check if management identity provider is set if [ -n "$NETBIRD_MGMT_IDP" ]; then EXTRA_CONFIG={} diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index b24e853b4..2bc49d3e5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -40,13 +40,21 @@ services: signal: <<: *default image: netbirdio/signal:$NETBIRD_SIGNAL_TAG + depends_on: + - dashboard volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro ports: - $NETBIRD_SIGNAL_PORT:80 # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE", + "--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE", + "--log-file", "console" + ] # Relay relay: diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 96873dee7..bf8f8e327 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -94,7 +94,7 @@ var ( startPprof() - opts, certManager, err := getTLSConfigurations() + opts, certManager, tlsConfig, err := getTLSConfigurations() if err != nil { return err } @@ -132,7 +132,7 @@ var ( // Start the main server - always serve HTTP with WebSocket proxy support // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager - if certManager == nil { + if tlsConfig == nil { // Without TLS, serve plain HTTP httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { @@ -140,9 +140,10 @@ var ( } log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) serveHTTP(httpListener, grpcRootHandler) - } else if signalPort != 443 { - // With TLS but not on port 443, serve HTTPS - httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + } else if certManager == nil || signalPort != 443 { + // Serve HTTPS if not already handled by startServerWithCertManager + // (custom certificates or Let's Encrypt with custom port) + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig) if err != nil { return err } @@ -202,7 +203,7 @@ func startPprof() { }() } -func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { +func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) { var ( err error certManager *autocert.Manager @@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { log.Infof("running without TLS") - return nil, nil, nil + return nil, nil, nil, nil } if signalLetsencryptDomain != "" { certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) if err != nil { - return nil, certManager, err + return nil, certManager, nil, err } tlsConfig = certManager.TLSConfig() log.Infof("setting up TLS with LetsEncrypt.") } else { if signalCertFile == "" || signalCertKey == "" { log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") - return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") + return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") } tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) if err != nil { log.Errorf("cannot load TLS credentials: %v", err) - return nil, certManager, err + return nil, certManager, nil, err } log.Infof("setting up TLS with custom certificates.") } transportCredentials := credentials.NewTLS(tlsConfig) - return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err + return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err } func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) { From b9ef214ea54fc524b9685f883e288f65720fbd32 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 27 Oct 2025 18:35:32 +0100 Subject: [PATCH 040/120] [client] Fix macOS state-based dns cleanup (#4701) --- client/internal/dns/host_darwin.go | 41 ++++--- client/internal/dns/host_darwin_test.go | 111 ++++++++++++++++++ client/internal/dns/host_windows.go | 26 ++-- .../internal/dns/unclean_shutdown_darwin.go | 5 + 4 files changed, 151 insertions(+), 32 deletions(-) create mode 100644 client/internal/dns/host_darwin_test.go diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b06ba73ab..71badf0d4 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -13,6 +13,7 @@ import ( "strings" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool { } func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - var err error - - if err := stateManager.UpdateState(&ShutdownState{}); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } - var ( searchDomains []string matchDomains []string ) - err = s.recordSystemDNSSettings(true) - if err != nil { + if err := s.recordSystemDNSSettings(true); err != nil { log.Errorf("unable to update record of System's DNS config: %s", err.Error()) } if config.RouteAll { searchDomains = append(searchDomains, "\"\"") - err = s.addLocalDNS() - if err != nil { - log.Infof("failed to enable split DNS") + if err := s.addLocalDNS(); err != nil { + log.Warnf("failed to add local DNS: %v", err) } + s.updateState(stateManager) } for _, dConf := range config.Domains { @@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * } matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + var err error if len(matchDomains) != 0 { err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) } else { @@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add match domains: %w", err) } + s.updateState(stateManager) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) if len(searchDomains) != 0 { @@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add search domains: %w", err) } + s.updateState(stateManager) if err := s.flushDNSCache(); err != nil { log.Errorf("failed to flush DNS cache: %v", err) @@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * return nil } +func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (s *systemConfigurator) string() string { return "scutil" } @@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if err := s.recordSystemDNSSettings(true); err != nil { - log.Errorf("Unable to get system DNS configuration") return fmt.Errorf("recordSystemDNSSettings(): %w", err) } } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { - err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) - if err != nil { - return fmt.Errorf("couldn't add local network DNS conf: %w", err) - } - } else { + if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { log.Info("Not enabling local DNS server") + return nil + } + + if err := s.addSearchDomains( + localKey, + strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, + ); err != nil { + return fmt.Errorf("add search domains: %w", err) } return nil diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go new file mode 100644 index 000000000..c4efd17b0 --- /dev/null +++ b/client/internal/dns/host_darwin_test.go @@ -0,0 +1,111 @@ +//go:build !ios + +package dns + +import ( + "context" + "net/netip" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { + if testing.Short() { + t.Skip("skipping scutil integration test in short mode") + } + + tmpDir := t.TempDir() + stateFile := filepath.Join(tmpDir, "state.json") + + sm := statemanager.New(stateFile) + sm.RegisterState(&ShutdownState{}) + sm.Start() + defer func() { + require.NoError(t, sm.Stop(context.Background())) + }() + + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + config := HostDNSConfig{ + ServerIP: netip.MustParseAddr("100.64.0.1"), + ServerPort: 53, + RouteAll: true, + Domains: []DomainConfig{ + {Domain: "example.com", MatchOnly: true}, + }, + } + + err := configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + + require.NoError(t, sm.PersistState(context.Background())) + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + + defer func() { + for _, key := range []string{searchKey, matchKey, localKey} { + _ = removeTestDNSKey(key) + } + }() + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + if exists { + t.Logf("Key %s exists before cleanup", key) + } + } + + sm2 := statemanager.New(stateFile) + sm2.RegisterState(&ShutdownState{}) + err = sm2.LoadState(&ShutdownState{}) + require.NoError(t, err) + + state := sm2.GetState(&ShutdownState{}) + if state == nil { + t.Skip("State not saved, skipping cleanup test") + } + + shutdownState, ok := state.(*ShutdownState) + require.True(t, ok) + + err = shutdownState.Cleanup() + require.NoError(t, err) + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + assert.False(t, exists, "Key %s should NOT exist after cleanup", key) + } +} + +func checkDNSKeyExists(key string) (bool, error) { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("show " + key + "\nquit\n") + output, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(output), "No such key") { + return false, nil + } + return false, err + } + return !strings.Contains(string(output), "No such key"), nil +} + +func removeTestDNSKey(key string) error { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n") + _, err := cmd.CombinedOutput() + return err +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 74111d335..01b7edc48 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -179,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) var searchDomains, matchDomains []string for _, dConf := range config.Domains { @@ -212,13 +206,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager r.nrptEntryCount = 0 } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) if err := r.updateSearchDomains(searchDomains); err != nil { return fmt.Errorf("update search domains: %w", err) @@ -229,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return nil } +func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{ + Guid: r.guid, + GPO: r.gpo, + NRPTEntryCount: r.nrptEntryCount, + }); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { return fmt.Errorf("adding dns setup for all failed: %w", err) diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index 9bbdd2b56..f51b5cf8d 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -7,6 +7,7 @@ import ( ) type ShutdownState struct { + CreatedKeys []string } func (s *ShutdownState) Name() string { @@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error { return fmt.Errorf("create host manager: %w", err) } + for _, key := range s.CreatedKeys { + manager.createdKeys[key] = struct{}{} + } + if err := manager.restoreUncleanShutdownDNS(); err != nil { return fmt.Errorf("restore unclean shutdown dns: %w", err) } From eddea145216f822e6b2842d968b95df351f62533 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 27 Oct 2025 18:54:00 +0100 Subject: [PATCH 041/120] [client] Clean up bsd routes independently of the state file (#4688) --- client/internal/routemanager/manager.go | 2 +- .../routemanager/systemops/flush_nonbsd.go | 8 ++ .../internal/routemanager/systemops/state.go | 8 +- .../routemanager/systemops/systemops.go | 2 +- .../systemops/systemops_bsd_test.go | 4 +- .../systemops/systemops_generic_test.go | 6 +- .../routemanager/systemops/systemops_unix.go | 78 ++++++++++++++++++- client/internal/statemanager/manager.go | 2 +- client/server/state.go | 9 +++ 9 files changed, 106 insertions(+), 13 deletions(-) create mode 100644 client/internal/routemanager/systemops/flush_nonbsd.go diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 04513bbe4..d590dba0d 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -106,7 +106,7 @@ type DefaultManager struct { func NewManager(config ManagerConfig) *DefaultManager { mCTX, cancel := context.WithCancel(config.Context) notifier := notifier.NewNotifier() - sysOps := systemops.NewSysOps(config.WGInterface, notifier) + sysOps := systemops.New(config.WGInterface, notifier) if runtime.GOOS == "windows" && config.WGInterface != nil { nbnet.SetVPNInterfaceName(config.WGInterface.Name()) diff --git a/client/internal/routemanager/systemops/flush_nonbsd.go b/client/internal/routemanager/systemops/flush_nonbsd.go new file mode 100644 index 000000000..f1c45d6cf --- /dev/null +++ b/client/internal/routemanager/systemops/flush_nonbsd.go @@ -0,0 +1,8 @@ +//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd) + +package systemops + +// FlushMarkedRoutes is a no-op on non-BSD platforms. +func (r *SysOps) FlushMarkedRoutes() error { + return nil +} diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 8e158711e..e0d045b07 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - sysops := NewSysOps(nil, nil) - sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData((*ExclusionCounter)(s)) + sysOps := New(nil, nil) + sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable) + sysOps.refCounter.LoadData((*ExclusionCounter)(s)) - return sysops.refCounter.Flush() + return sysOps.refCounter.Flush() } func (s *ShutdownState) MarshalJSON() ([]byte, error) { diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 8da138117..c0ca21d22 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -83,7 +83,7 @@ type SysOps struct { localSubnetsCacheTime time.Time } -func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { +func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { return &SysOps{ wgInterface: wgInterface, notifier: notifier, diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index 0d892c162..ec4fc406e 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) { _, intf = setupDummyInterface(t) nexthop = Nexthop{netip.Addr{}, intf} - r := NewSysOps(nil, nil) + r := New(nil, nil) var wg sync.WaitGroup for i := 0; i < 1024; i++ { @@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin nexthop := Nexthop{netip.Addr{}, netIntf} - r := NewSysOps(nil, nil) + r := New(nil, nil) err = r.addToRouteTable(prefix, nexthop) require.NoError(t, err, "Failed to add route to table") diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 32ea38a7a..d9b109beb 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) { assert.NoError(t, wgInterface.Close()) }) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index d43c2d5bf..7089178fb 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -7,19 +7,39 @@ import ( "fmt" "net" "net/netip" + "os" "strconv" "syscall" "time" "unsafe" "github.com/cenkalti/backoff/v4" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/net/route" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) +const ( + envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" +) + +var routeProtoFlag int + +func init() { + switch os.Getenv(envRouteProtoFlag) { + case "2": + routeProtoFlag = unix.RTF_PROTO2 + case "3": + routeProtoFlag = unix.RTF_PROTO3 + default: + routeProtoFlag = unix.RTF_PROTO1 + } +} + func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } @@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout return r.cleanupRefCounter(stateManager) } +// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. +func (r *SysOps) FlushMarkedRoutes() error { + rib, err := retryFetchRIB() + if err != nil { + return fmt.Errorf("fetch routing table: %w", err) + } + + msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return fmt.Errorf("parse routing table: %w", err) + } + + var merr *multierror.Error + flushedCount := 0 + + for _, msg := range msgs { + rtMsg, ok := msg.(*route.RouteMessage) + if !ok { + continue + } + + if rtMsg.Flags&routeProtoFlag == 0 { + continue + } + + routeInfo, err := MsgToRoute(rtMsg) + if err != nil { + log.Debugf("Skipping route flush: %v", err) + continue + } + + if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() { + continue + } + + nexthop := Nexthop{ + IP: routeInfo.Gw, + Intf: routeInfo.Interface, + } + + if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err)) + continue + } + + flushedCount++ + log.Debugf("Flushed marked route: %s", routeInfo.Dst) + } + + if flushedCount > 0 { + log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount) + } + + return nberrors.FormatErrorOrNil(merr) +} + func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { return r.routeSocket(unix.RTM_ADD, prefix, nexthop) } @@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func( func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { msg = &route.RouteMessage{ Type: action, - Flags: unix.RTF_UP, + Flags: unix.RTF_UP | routeProtoFlag, Version: unix.RTM_VERSION, Seq: r.getSeq(), } diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 29f962ad2..2c9e46290 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { - log.Debug("state file does not exist") + log.Debugf("state file %s does not exist", m.filePath) return nil, nil // nolint:nilnil } return nil, fmt.Errorf("read state file: %w", err) diff --git a/client/server/state.go b/client/server/state.go index 107f55154..1cf85cd37 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -10,7 +10,9 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/proto" ) @@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error { merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) } + // clean up any remaining routes independently of the state file + if !nbnet.AdvancedRouting() { + if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) } From 7f08983207a8a39e4610e4b9945cf06b44cb293d Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 27 Oct 2025 22:16:17 +0300 Subject: [PATCH 042/120] Include expired and routing peers in DNS record filtering (#4708) --- management/server/types/account.go | 8 ++++-- management/server/types/account_test.go | 34 ++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/management/server/types/account.go b/management/server/types/account.go index f830023c7..50bdc6ab3 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -301,7 +301,7 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) zones = append(zones, nbdns.CustomZone{ Domain: peersCustomZone.Domain, Records: records, @@ -1682,7 +1682,7 @@ func peerSupportsPortRanges(peerVer string) bool { } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. -func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord { filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) peerIPs := make(map[string]struct{}) @@ -1693,6 +1693,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p peerIPs[peerToConnect.IP.String()] = struct{}{} } + for _, expiredPeer := range expiredPeers { + peerIPs[expiredPeer.IP.String()] = struct{}{} + } + for _, record := range customZone.Records { if _, exists := peerIPs[record.RData]; exists { filteredRecords = append(filteredRecords, record) diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index cd221b590..32538933a 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { peer *nbpeer.Peer customZone nbdns.CustomZone peersToConnect []*nbpeer.Peer + expiredPeers []*nbpeer.Peer expectedRecords []nbdns.SimpleRecord }{ { @@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: []*nbpeer.Peer{}, + expiredPeers: []*nbpeer.Peer{}, peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, @@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { } return peers }(), - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: func() []nbdns.SimpleRecord { var records []nbdns.SimpleRecord for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { @@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, }, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, @@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, }, }, + { + name: "expired peers are included in DNS entries", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1")}, + }, + expiredPeers: []*nbpeer.Peer{ + {ID: "expired-peer", IP: net.ParseIP("10.0.0.99")}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers) assert.Equal(t, len(tt.expectedRecords), len(result)) assert.ElementsMatch(t, tt.expectedRecords, result) }) From 4545ab9a529fd34dc05790866a38141ce9ba7165 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 27 Oct 2025 22:59:35 +0100 Subject: [PATCH 043/120] [management] rewire account manager to permissions manager (#4673) --- go.mod | 2 +- go.sum | 4 ++-- management/internals/server/modules.go | 8 +++++++- .../http/testing/testing_tools/channel/channel.go | 3 ++- management/server/permissions/manager.go | 6 ++++++ management/server/permissions/manager_mock.go | 13 +++++++++++++ 6 files changed, 31 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 79dd92e6b..06dc921f1 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index f0065e081..ce68ed99e 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 h1:aXHS63QWf0Z5fDN19Swl6npdJjGMyXthAvvgW7rbKJQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index daec4ef6f..209a20065 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { func (s *BaseServer) PermissionsManager() permissions.Manager { return Create(s, func() permissions.Manager { - return integrations.InitPermissionsManager(s.Store()) + manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter()) + + s.AfterInit(func(s *BaseServer) { + manager.SetAccountManager(s.AccountManager()) + }) + + return manager }) } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 741f03f18..bdf56db6e 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -7,9 +7,10 @@ import ( "time" "github.com/golang-jwt/jwt/v5" - "github.com/netbirdio/management-integrations/integrations" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 891fa59bb..e6bdd2025 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -22,6 +23,7 @@ type Manager interface { ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) + SetAccountManager(accountManager account.Manager) } type managerImpl struct { @@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR return permissions, nil } + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + // no-op +} diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go index fa115d628..ec9f263f9 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/server/permissions/manager_mock.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + account "github.com/netbirdio/netbird/management/server/account" modules "github.com/netbirdio/netbird/management/server/permissions/modules" operations "github.com/netbirdio/netbird/management/server/permissions/operations" roles "github.com/netbirdio/netbird/management/server/permissions/roles" @@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role) } +// SetAccountManager mocks base method. +func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. +func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + // ValidateAccountAccess mocks base method. func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { m.ctrl.T.Helper() From 404cab90ba9c340b848e7f7a457812a20910ed50 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 28 Oct 2025 15:12:53 +0100 Subject: [PATCH 044/120] [client] Redirect dns forwarder port 5353 to new listening port 22054 (#4707) - Port dnat changes from https://github.com/netbirdio/netbird/pull/4015 (nftables/iptables/userspace) - For userspace: rewrite the original port to the target port - Remember original destination port in conntrack - Rewrite the source port back to the original port for replies - Redirect incoming port 5353 to 22054 (tcp/udp) - Revert port changes based on the network map received from management - Adjust tracer to show NAT stages --- client/firewall/iptables/manager_linux.go | 16 + client/firewall/iptables/router_linux.go | 48 +++ client/firewall/manager/firewall.go | 10 +- client/firewall/nftables/manager_linux.go | 16 + client/firewall/nftables/router_linux.go | 97 ++++++ client/firewall/uspfilter/conntrack/common.go | 2 + client/firewall/uspfilter/conntrack/tcp.go | 48 ++- .../firewall/uspfilter/conntrack/tcp_test.go | 8 +- client/firewall/uspfilter/conntrack/udp.go | 34 +- client/firewall/uspfilter/filter.go | 45 ++- client/firewall/uspfilter/log/log.go | 34 +- client/firewall/uspfilter/nat.go | 305 ++++++++++++++---- client/firewall/uspfilter/nat_bench_test.go | 124 +++++++ client/firewall/uspfilter/nat_test.go | 109 +++++++ client/firewall/uspfilter/tracer.go | 233 ++++++++++++- client/firewall/uspfilter/tracer_test.go | 30 ++ client/internal/dnsfwd/manager.go | 69 ++-- client/internal/engine.go | 40 +-- client/internal/netflow/logger/logger.go | 5 +- .../routemanager/dnsinterceptor/handler.go | 4 +- dns/dns.go | 4 + management/server/dns.go | 12 +- management/server/dns_test.go | 24 +- management/server/peer_test.go | 2 +- shared/management/proto/management.proto | 2 +- 25 files changed, 1125 insertions(+), 196 deletions(-) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 81f7a9125..16b50211e 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -260,6 +260,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 081991235..80aea7cf8 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -880,6 +880,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nberrors.FormatErrorOrNil(merr) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + dnatRule := []string{ + "-i", r.wgIface.Name(), + "-p", strings.ToLower(string(protocol)), + "--dport", strconv.Itoa(int(sourcePort)), + "-d", localAddr.String(), + "-m", "addrtype", "--dst-type", "LOCAL", + "-j", "DNAT", + "--to-destination", ":" + strconv.Itoa(int(targetPort)), + } + + ruleInfo := ruleInfo{ + table: tableNat, + chain: chainRTRDR, + rule: dnatRule, + } + + if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil { + return fmt.Errorf("add inbound DNAT rule: %w", err) + } + r.rules[ruleID] = ruleInfo.rule + + r.updateState() + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if dnatRule, exists := r.rules[ruleID]; exists { + if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil { + return fmt.Errorf("delete inbound DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + r.updateState() + return nil +} + func applyPort(flag string, port *firewall.Port) []string { if port == nil { return nil diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 3b3164823..7ee33118b 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -151,14 +151,20 @@ type Manager interface { DisableRouting() error - // AddDNATRule adds a DNAT rule + // AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network. AddDNATRule(ForwardRule) (Rule, error) - // DeleteDNATRule deletes a DNAT rule + // DeleteDNATRule deletes the outbound DNAT rule. DeleteDNATRule(Rule) error // UpdateSet updates the set with the given prefixes UpdateSet(hash Set, prefixes []netip.Prefix) error + + // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services + AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + + // RemoveInboundDNAT removes inbound DNAT rule + RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 560f224f5..aa90d3b9a 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -376,6 +376,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + func (m *Manager) createWorkTable() (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index e918d0524..648a6aedf 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -1350,6 +1350,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nil } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + protoNum, err := protoToInt(protocol) + if err != nil { + return fmt.Errorf("convert protocol to number: %w", err) + } + + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 2, + Data: []byte{protoNum}, + }, + &expr.Payload{ + DestRegister: 3, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 3, + Data: binaryutil.BigEndian.PutUint16(sourcePort), + }, + } + + exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + + exprs = append(exprs, + &expr.Immediate{ + Register: 1, + Data: localAddr.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(targetPort), + }, + &expr.NAT{ + Type: expr.NATTypeDestNAT, + Family: uint32(nftables.TableFamilyIPv4), + RegAddrMin: 1, + RegProtoMin: 2, + RegProtoMax: 0, + }, + ) + + dnatRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingRdr], + Exprs: exprs, + UserData: []byte(ruleID), + } + r.conn.AddRule(dnatRule) + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("add inbound DNAT rule: %w", err) + } + + r.rules[ruleID] = dnatRule + + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if rule, exists := r.rules[ruleID]; exists { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err) + } + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush delete inbound DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + return nil +} + // applyNetwork generates nftables expressions for networks (CIDR) or sets func (r *router) applyNetwork( network firewall.Network, diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index bcf6d894b..7be0dd78f 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -22,6 +22,8 @@ type BaseConnTrack struct { PacketsRx atomic.Uint64 BytesTx atomic.Uint64 BytesRx atomic.Uint64 + + DNATOrigPort atomic.Uint32 } // these small methods will be inlined by the compiler diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index a2355e5c7..8d64412e0 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui if exists { t.updateState(key, conn, flags, direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } -// TrackOutbound records an outbound TCP connection -func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size) +// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed +func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 { + if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound processes an inbound TCP packet and updates connection state -func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) +func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort) } // track is the common implementation for tracking both inbound and outbound connections -func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) +func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) if exists || flags&TCPSyn == 0 { return } @@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla conn.tombstone.Store(false) conn.state.Store(int32(TCPStateNew)) + conn.DNATOrigPort.Store(uint32(origPort)) - t.logger.Trace2("New %s TCP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s TCP connection: %s", direction, key) + } t.updateState(key, conn, flags, direction, size) t.mutex.Lock() @@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() { } } +// GetConnection safely retrieves a connection state +func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn, exists := t.connections[key] + return conn, exists +} + // Close stops the cleanup routine and releases resources func (t *TCPTracker) Close() { t.tickerCancel() diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index d01a8db4f..bb440f70a 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { serverPort := uint16(80) // 1. Client sends SYN (we receive it as inbound) - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0) key := ConnKey{ SrcIP: clientIP, @@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) // 3. Client sends ACK to complete handshake - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") // 4. Test data transfer // Client sends data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0) // Server sends ACK for data tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) @@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) // Client sends ACK for data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) // Verify state and counters require.Equal(t, TCPStateEstablished, conn.GetState()) diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index e7f49c46f..a3b6a418b 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -// TrackOutbound records an outbound UDP connection -func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size) +// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed +func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 { + _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size) + if exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound records an inbound UDP connection -func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size) +func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort) } -func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort if exists { conn.UpdateLastSeen() conn.UpdateCounters(direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } // track is the common implementation for tracking both inbound and outbound connections -func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) +func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) if exists { return } @@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d SourcePort: srcPort, DestPort: dstPort, } + conn.DNATOrigPort.Store(uint32(origPort)) conn.UpdateLastSeen() conn.UpdateCounters(direction, size) @@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace2("New %s UDP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s UDP connection: %s", direction, key) + } t.sendEvent(nftypes.TypeStart, conn, ruleID) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 7eef49e31..fbc39b740 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -109,6 +109,10 @@ type Manager struct { dnatMappings map[netip.Addr]netip.Addr dnatMutex sync.RWMutex dnatBiMap *biDNATMap + + portDNATEnabled atomic.Bool + portDNATRules []portDNATRule + portDNATMutex sync.RWMutex } // decoder for packages @@ -122,6 +126,8 @@ type decoder struct { icmp6 layers.ICMPv6 decoded []gopacket.LayerType parser *gopacket.DecodingLayerParser + + dnatOrigPort uint16 } // Create userspace firewall manager constructor @@ -196,6 +202,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, dnatMappings: make(map[netip.Addr]netip.Addr), + portDNATRules: []portDNATRule{}, } m.routingEnabled.Store(false) @@ -630,7 +637,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return true } - m.trackOutbound(d, srcIP, dstIP, size) + m.trackOutbound(d, srcIP, dstIP, packetData, size) m.translateOutboundDNAT(packetData, d) return false @@ -674,14 +681,26 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } -func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { +func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + if origPort == 0 { + break + } + if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite UDP port: %v", err) + } case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + if origPort == 0 { + break + } + if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite TCP port: %v", err) + } case layers.LayerTypeICMPv4: m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) } @@ -691,13 +710,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size) + m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort) case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) + m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) case layers.LayerTypeICMPv4: m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) } + + d.dnatOrigPort = 0 } // udpHooksDrop checks if any UDP hooks should drop the packet @@ -759,10 +780,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { return false } + // TODO: optimize port DNAT by caching matched rules in conntrack + if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { + // Re-decode after port DNAT translation to update port information + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) + return true + } + srcIP, dstIP = m.extractIPs(d) + } + if translated := m.translateInboundReverse(packetData, d); translated { // Re-decode after translation to get original addresses if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err) + m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) return true } srcIP, dstIP = m.extractIPs(d) diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go index 5614e2ec3..139f702f2 100644 --- a/client/firewall/uspfilter/log/log.go +++ b/client/firewall/uspfilter/log/log.go @@ -50,6 +50,8 @@ type logMessage struct { arg4 any arg5 any arg6 any + arg7 any + arg8 any } // Logger is a high-performance, non-blocking logger @@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) { log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) } - func (l *Logger) Error(format string) { if l.level.Load() >= uint32(LevelError) { select { @@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) { } } +func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) { + if l.level.Load() >= uint32(LevelDebug) { + select { + case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + default: + } + } +} + func (l *Logger) Trace1(format string, arg1 any) { if l.level.Load() >= uint32(LevelTrace) { select { @@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { } } +// Trace8 logs a trace message with 8 arguments (8 placeholder in format string) +func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) { + if l.level.Load() >= uint32(LevelTrace) { + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: + default: + } + } +} + func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { *buf = (*buf)[:0] *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") @@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { argCount++ if msg.arg6 != nil { argCount++ + if msg.arg7 != nil { + argCount++ + if msg.arg8 != nil { + argCount++ + } + } } } } @@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5) case 6: formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6) + case 7: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7) + case 8: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8) } *buf = append(*buf, formatted...) @@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error { case <-done: return nil } -} \ No newline at end of file +} diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 27b752531..13567872e 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -5,7 +5,9 @@ import ( "errors" "fmt" "net/netip" + "slices" + "github.com/google/gopacket" "github.com/google/gopacket/layers" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -13,6 +15,21 @@ import ( var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") +var ( + errInvalidIPHeaderLength = errors.New("invalid IP header length") +) + +const ( + // Port offsets in TCP/UDP headers + sourcePortOffset = 0 + destinationPortOffset = 2 + + // IP address offsets in IPv4 header + sourceIPOffset = 12 + destinationIPOffset = 16 +) + +// ipv4Checksum calculates IPv4 header checksum. func ipv4Checksum(header []byte) uint16 { if len(header) < 20 { return 0 @@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 { return ^uint16(sum) } +// icmpChecksum calculates ICMP checksum. func icmpChecksum(data []byte) uint16 { var sum1, sum2, sum3, sum4 uint32 i := 0 @@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 { return ^uint16(sum) } +// biDNATMap maintains bidirectional DNAT mappings. type biDNATMap struct { forward map[netip.Addr]netip.Addr reverse map[netip.Addr]netip.Addr } +// portDNATRule represents a port-specific DNAT rule. +type portDNATRule struct { + protocol gopacket.LayerType + origPort uint16 + targetPort uint16 + targetIP netip.Addr +} + +// newBiDNATMap creates a new bidirectional DNAT mapping structure. func newBiDNATMap() *biDNATMap { return &biDNATMap{ forward: make(map[netip.Addr]netip.Addr), @@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap { } } +// set adds a bidirectional DNAT mapping between original and translated addresses. func (b *biDNATMap) set(original, translated netip.Addr) { b.forward[original] = translated b.reverse[translated] = original } +// delete removes a bidirectional DNAT mapping for the given original address. func (b *biDNATMap) delete(original netip.Addr) { if translated, exists := b.forward[original]; exists { delete(b.forward, original) @@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) { } } +// getTranslated returns the translated address for a given original address. func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { translated, exists := b.forward[original] return translated, exists } +// getOriginal returns the original address for a given translated address. func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { original, exists := b.reverse[translated] return original, exists } +// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation. func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { - if !originalAddr.IsValid() || !translatedAddr.IsValid() { - return fmt.Errorf("invalid IP addresses") + if !originalAddr.IsValid() { + return fmt.Errorf("invalid original IP address") + } + if !translatedAddr.IsValid() { + return fmt.Errorf("invalid translated IP address") } if m.localipmanager.IsLocalIP(translatedAddr) { @@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr m.dnatMutex.Lock() defer m.dnatMutex.Unlock() - // Initialize both maps together if either is nil if m.dnatMappings == nil || m.dnatBiMap == nil { m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatBiMap = newBiDNATMap() @@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr return nil } -// RemoveInternalDNATMapping removes a 1:1 IP address mapping +// RemoveInternalDNATMapping removes a 1:1 IP address mapping. func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { m.dnatMutex.Lock() defer m.dnatMutex.Unlock() @@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { return nil } -// getDNATTranslation returns the translated address if a mapping exists +// getDNATTranslation returns the translated address if a mapping exists. func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return addr, false @@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { return translated, exists } -// findReverseDNATMapping finds original address for return traffic +// findReverseDNATMapping finds original address for return traffic. func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return translatedAddr, false @@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, return original, exists } -// translateOutboundDNAT applies DNAT translation to outbound packets +// translateOutboundDNAT applies DNAT translation to outbound packets. func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) translatedIP, exists := m.getDNATTranslation(dstIP) @@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { - m.logger.Error1("Failed to rewrite packet destination: %v", err) + if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet destination: %v", err) return false } @@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return true } -// translateInboundReverse applies reverse DNAT to inbound return traffic +// translateInboundReverse applies reverse DNAT to inbound return traffic. func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) originalIP, exists := m.findReverseDNATMapping(srcIP) @@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { - m.logger.Error1("Failed to rewrite packet source: %v", err) + if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet source: %v", err) return false } @@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return true } -// rewritePacketDestination replaces destination IP in the packet -func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { +// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. +func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { + if !newIP.Is4() { return ErrIPv4Only } - var oldDst [4]byte - copy(oldDst[:], packetData[16:20]) - newDst := newIP.As4() + var oldIP [4]byte + copy(oldIP[:], packetData[ipOffset:ipOffset+4]) + newIPBytes := newIP.As4() - copy(packetData[16:20], newDst[:]) + copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf("invalid IP header length") + return errInvalidIPHeaderLength } binary.BigEndian.PutUint16(packetData[10:12], 0) @@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP if len(d.decoded) > 1 { switch d.decoded[1] { case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) + m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) - case layers.LayerTypeICMPv4: - m.updateICMPChecksum(packetData, ipHeaderLen) - } - } - - return nil -} - -// rewritePacketSource replaces the source IP address in the packet -func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { - return ErrIPv4Only - } - - var oldSrc [4]byte - copy(oldSrc[:], packetData[12:16]) - newSrc := newIP.As4() - - copy(packetData[12:16], newSrc[:]) - - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf("invalid IP header length") - } - - binary.BigEndian.PutUint16(packetData[10:12], 0) - ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) - binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) - - if len(d.decoded) > 1 { - switch d.decoded[1] { - case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) - case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) + m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeICMPv4: m.updateICMPChecksum(packetData, ipHeaderLen) } @@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip return nil } +// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624. func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { tcpStart := ipHeaderLen if len(packetData) < tcpStart+18 { @@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } +// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624. func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { udpStart := ipHeaderLen if len(packetData) < udpStart+8 { @@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } +// updateICMPChecksum recalculates ICMP checksum after packet modification. func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { icmpStart := ipHeaderLen if len(packetData) < icmpStart+8 { @@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { binary.BigEndian.PutUint16(icmpData[2:4], checksum) } -// incrementalUpdate performs incremental checksum update per RFC 1624 +// incrementalUpdate performs incremental checksum update per RFC 1624. func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) @@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { return ^uint16(sum) } -// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding) +// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network. func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { if m.nativeFirewall == nil { return nil, errNatNotSupported @@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) return m.nativeFirewall.AddDNATRule(rule) } -// DeleteDNATRule deletes a DNAT rule (delegates to native firewall) +// DeleteDNATRule deletes outbound DNAT rule. func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { if m.nativeFirewall == nil { return errNatNotSupported } return m.nativeFirewall.DeleteDNATRule(rule) } + +// addPortRedirection adds a port redirection rule. +func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { + m.portDNATMutex.Lock() + defer m.portDNATMutex.Unlock() + + rule := portDNATRule{ + protocol: protocol, + origPort: sourcePort, + targetPort: targetPort, + targetIP: targetIP, + } + + m.portDNATRules = append(m.portDNATRules, rule) + m.portDNATEnabled.Store(true) + + return nil +} + +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + var layerType gopacket.LayerType + switch protocol { + case firewall.ProtocolTCP: + layerType = layers.LayerTypeTCP + case firewall.ProtocolUDP: + layerType = layers.LayerTypeUDP + default: + return fmt.Errorf("unsupported protocol: %s", protocol) + } + + return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort) +} + +// removePortRedirection removes a port redirection rule. +func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { + m.portDNATMutex.Lock() + defer m.portDNATMutex.Unlock() + + m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool { + return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0 + }) + + if len(m.portDNATRules) == 0 { + m.portDNATEnabled.Store(false) + } + + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + var layerType gopacket.LayerType + switch protocol { + case firewall.ProtocolTCP: + layerType = layers.LayerTypeTCP + case firewall.ProtocolUDP: + layerType = layers.LayerTypeUDP + default: + return fmt.Errorf("unsupported protocol: %s", protocol) + } + + return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) +} + +// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. +func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { + if !m.portDNATEnabled.Load() { + return false + } + + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort := uint16(d.tcp.DstPort) + return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort) + case layers.LayerTypeUDP: + dstPort := uint16(d.udp.DstPort) + return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort) + default: + return false + } +} + +type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error + +func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool { + m.portDNATMutex.RLock() + defer m.portDNATMutex.RUnlock() + + for _, rule := range m.portDNATRules { + if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 { + continue + } + + if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 { + return false + } + + if rule.origPort != port { + continue + } + + if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil { + m.logger.Error1("failed to rewrite port: %v", err) + return false + } + d.dnatOrigPort = rule.origPort + return true + } + return false +} + +// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. +func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return errInvalidIPHeaderLength + } + + tcpStart := ipHeaderLen + if len(packetData) < tcpStart+4 { + return fmt.Errorf("packet too short for TCP header") + } + + portStart := tcpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) + + if len(packetData) >= tcpStart+18 { + checksumOffset := tcpStart + 16 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) + + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } + + return nil +} + +// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. +func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return errInvalidIPHeaderLength + } + + udpStart := ipHeaderLen + if len(packetData) < udpStart+8 { + return fmt.Errorf("packet too short for UDP header") + } + + portStart := udpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) + + checksumOffset := udpStart + 6 + if len(packetData) >= udpStart+8 { + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + if oldChecksum != 0 { + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) + + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } + } + + return nil +} diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index 16dba682e..d726474cf 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -414,3 +414,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) { } }) } + +// BenchmarkPortDNAT measures the performance of port DNAT operations +func BenchmarkPortDNAT(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + setupDNAT bool + useMatchPort bool + description string + }{ + { + name: "tcp_inbound_dnat_match", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: true, + description: "TCP inbound port DNAT translation (22 → 22022)", + }, + { + name: "tcp_inbound_dnat_nomatch", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: false, + description: "TCP inbound with DNAT configured but no port match", + }, + { + name: "tcp_inbound_no_dnat", + proto: layers.IPProtocolTCP, + setupDNAT: false, + useMatchPort: false, + description: "TCP inbound without DNAT (baseline)", + }, + { + name: "udp_inbound_dnat_match", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: true, + description: "UDP inbound port DNAT translation (5353 → 22054)", + }, + { + name: "udp_inbound_dnat_nomatch", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: false, + description: "UDP inbound with DNAT configured but no port match", + }, + { + name: "udp_inbound_no_dnat", + proto: layers.IPProtocolUDP, + setupDNAT: false, + useMatchPort: false, + description: "UDP inbound without DNAT (baseline)", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + // Set logger to error level to reduce noise during benchmarking + manager.SetLogLevel(log.ErrorLevel) + defer func() { + // Restore to info level after benchmark + manager.SetLogLevel(log.InfoLevel) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + var origPort, targetPort, testPort uint16 + if sc.proto == layers.IPProtocolTCP { + origPort, targetPort = 22, 22022 + } else { + origPort, targetPort = 5353, 22054 + } + + if sc.useMatchPort { + testPort = origPort + } else { + testPort = 443 // Different port + } + + // Setup port DNAT mapping if needed + if sc.setupDNAT { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort) + require.NoError(b, err) + } + + // Pre-establish inbound connection for outbound reverse test + if sc.setupDNAT && sc.useMatchPort { + inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort) + manager.filterInbound(inboundPacket, 0) + } + + b.ResetTimer() + b.ReportAllocs() + + // Benchmark inbound DNAT translation + b.Run("inbound", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh packet each time + packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort) + manager.filterInbound(packet, 0) + } + }) + + // Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches) + if sc.setupDNAT && sc.useMatchPort { + b.Run("outbound_reverse", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh return packet (from target port) + packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321) + manager.filterOutbound(packet, 0) + } + }) + } + }) + } +} diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 710abd445..2a285484c 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/device" ) @@ -143,3 +144,111 @@ func TestDNATMappingManagement(t *testing.T) { err = manager.RemoveInternalDNATMapping(originalIP) require.Error(t, err, "Should error when removing non-existent mapping") } + +func TestInboundPortDNAT(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + testCases := []struct { + name string + protocol layers.IPProtocol + sourcePort uint16 + targetPort uint16 + }{ + {"TCP SSH", layers.IPProtocolTCP, 22, 22022}, + {"UDP DNS", layers.IPProtocolUDP, 5353, 22054}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + + inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort) + d := parsePacket(t, inboundPacket) + + translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr) + require.True(t, translated, "Inbound packet should be translated") + + d = parsePacket(t, inboundPacket) + var dstPort uint16 + switch tc.protocol { + case layers.IPProtocolTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.IPProtocolUDP: + dstPort = uint16(d.udp.DstPort) + } + + require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port") + + err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + }) + } +} + +func TestInboundPortDNATNegative(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022) + require.NoError(t, err) + + testCases := []struct { + name string + protocol layers.IPProtocol + srcIP netip.Addr + dstIP netip.Addr + srcPort uint16 + dstPort uint16 + }{ + {"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80}, + {"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22}, + {"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22}, + {"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort) + d := parsePacket(t, packet) + + translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP) + require.False(t, translated, "Packet should NOT be translated for %s", tc.name) + + d = parsePacket(t, packet) + if tc.protocol == layers.IPProtocolTCP { + require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged") + } else if tc.protocol == layers.IPProtocolUDP { + require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged") + } + }) + } +} + +func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol { + switch proto { + case layers.IPProtocolTCP: + return firewall.ProtocolTCP + case layers.IPProtocolUDP: + return firewall.ProtocolUDP + default: + return firewall.ProtocolALL + } +} diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index c75c0249d..c46a6581d 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -16,25 +16,33 @@ type PacketStage int const ( StageReceived PacketStage = iota + StageInboundPortDNAT + StageInbound1to1NAT StageConntrack StagePeerACL StageRouting StageRouteACL StageForwarding StageCompleted + StageOutbound1to1NAT + StageOutboundPortReverse ) const msgProcessingCompleted = "Processing completed" func (s PacketStage) String() string { return map[PacketStage]string{ - StageReceived: "Received", - StageConntrack: "Connection Tracking", - StagePeerACL: "Peer ACL", - StageRouting: "Routing", - StageRouteACL: "Route ACL", - StageForwarding: "Forwarding", - StageCompleted: "Completed", + StageReceived: "Received", + StageInboundPortDNAT: "Inbound Port DNAT", + StageInbound1to1NAT: "Inbound 1:1 NAT", + StageConntrack: "Connection Tracking", + StagePeerACL: "Peer ACL", + StageRouting: "Routing", + StageRouteACL: "Route ACL", + StageForwarding: "Forwarding", + StageCompleted: "Completed", + StageOutbound1to1NAT: "Outbound 1:1 NAT", + StageOutboundPortReverse: "Outbound DNAT Reverse", }[s] } @@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa } func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { + if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) { + return trace + } + if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { return trace } @@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str } func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { - // will create or update the connection state + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageCompleted, "Packet dropped - decode error", false) + return trace + } + + m.handleOutboundDNAT(trace, packetData, d) + dropped := m.filterOutbound(packetData, 0) if dropped { trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) @@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr } return trace } + +func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { + portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) + if portDNATApplied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + trace.DestinationPort = m.getDestPort(d) + } + + nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) + if nat1to1Applied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + } + + return false +} + +func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true) + return false + } + + protocol := d.decoded[1] + if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP { + trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + var originalPort uint16 + if protocol == layers.LayerTypeTCP { + originalPort = uint16(d.tcp.DstPort) + } else { + originalPort = uint16(d.udp.DstPort) + } + + translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP) + if translated { + ipHeaderLen := int((packetData[0] & 0x0F) * 4) + translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3]) + + protoStr := "TCP" + if protocol == layers.LayerTypeUDP { + protoStr = "UDP" + } + msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort) + trace.AddResult(StageInboundPortDNAT, msg, true) + return true + } + + trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true) + return false +} + +func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + + translated := m.translateInboundReverse(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatBiMap.getOriginal(srcIP) + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP) + trace.AddResult(StageInbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) { + m.traceOutbound1to1NAT(trace, packetData, d) + m.traceOutboundPortReverse(trace, packetData, d) +} + +func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + translated := m.translateOutboundDNAT(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatMappings[dstIP] + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP) + trace.AddResult(StageOutbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + var origPort uint16 + transport := d.decoded[1] + switch transport { + case layers.LayerTypeTCP: + srcPort := uint16(d.tcp.SrcPort) + dstPort := uint16(d.tcp.DstPort) + conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + case layers.LayerTypeUDP: + srcPort := uint16(d.udp.SrcPort) + dstPort := uint16(d.udp.DstPort) + conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + default: + trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true) + return false + } + + trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true) + return false +} + +func (m *Manager) getDestPort(d *decoder) uint16 { + if len(d.decoded) < 2 { + return 0 + } + switch d.decoded[1] { + case layers.LayerTypeTCP: + return uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + return uint16(d.udp.DstPort) + default: + return 0 + } +} diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 46c115787..ee1bb8a23 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -104,6 +104,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -126,6 +128,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -153,6 +157,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -179,6 +185,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -204,6 +212,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -228,6 +238,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -246,6 +258,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -264,6 +278,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageCompleted, @@ -287,6 +303,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageCompleted, }, @@ -301,6 +319,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageOutbound1to1NAT, + StageOutboundPortReverse, StageCompleted, }, expectedAllow: true, @@ -319,6 +339,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -340,6 +362,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -362,6 +386,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -382,6 +408,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -406,6 +434,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageRouting, StagePeerACL, StageCompleted, diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index a3a4ba40f..a1c0dff98 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "net" - "sync" + "net/netip" + "os" + "strconv" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -12,18 +14,14 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) -var ( - // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also - listenPort uint16 = 5353 - listenPortMu sync.RWMutex -) - const ( - dnsTTL = 60 //seconds + dnsTTL = 60 + envServerPort = "NB_DNS_FORWARDER_PORT" ) // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. @@ -36,28 +34,30 @@ type ForwarderEntry struct { type Manager struct { firewall firewall.Manager statusRecorder *peer.Status + localAddr netip.Addr + serverPort uint16 fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder } -func ListenPort() uint16 { - listenPortMu.RLock() - defer listenPortMu.RUnlock() - return listenPort -} +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager { + serverPort := nbdns.ForwarderServerPort + if envPort := os.Getenv(envServerPort); envPort != "" { + if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 { + serverPort = uint16(port) + log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort) + } else { + log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort) + } + } -func SetListenPort(port uint16) { - listenPortMu.Lock() - listenPort = port - listenPortMu.Unlock() -} - -func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, + localAddr: localAddr, + serverPort: serverPort, } } @@ -71,7 +71,21 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) + if m.localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS UDP DNAT rule: %v", err) + } else { + log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + } + + if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS TCP DNAT rule: %v", err) + } else { + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + } + } + + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -96,6 +110,17 @@ func (m *Manager) Stop(ctx context.Context) error { } var mErr *multierror.Error + + if m.localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err)) + } + + if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) + } + } + if err := m.dropDNSFirewall(); err != nil { mErr = multierror.Append(mErr, err) } @@ -111,7 +136,7 @@ func (m *Manager) Stop(ctx context.Context) error { func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, - Values: []uint16{ListenPort()}, + Values: []uint16{m.serverPort}, } if m.firewall == nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index bebf04f6c..8f75c0646 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -202,9 +202,6 @@ type Engine struct { // WireGuard interface monitor wgIfaceMonitor *WGIfaceMonitor wgIfaceMonitorWg sync.WaitGroup - - // dns forwarder port - dnsFwdPort uint16 } // Peer is an instance of the Connection Peer @@ -247,7 +244,6 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - dnsFwdPort: dnsfwd.ListenPort(), } sm := profilemanager.NewServiceManager("") @@ -1084,7 +1080,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) - e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) // Ingress forward rules forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) @@ -1843,16 +1839,11 @@ func (e *Engine) GetWgAddr() netip.Addr { func (e *Engine) updateDNSForwarder( enabled bool, fwdEntries []*dnsfwd.ForwarderEntry, - forwarderPort uint16, ) { if e.config.DisableServerRoutes { return } - if forwarderPort > 0 { - dnsfwd.SetListenPort(forwarderPort) - } - if !enabled { if e.dnsForwardMgr == nil { return @@ -1864,20 +1855,17 @@ func (e *Engine) updateDNSForwarder( } if len(fwdEntries) > 0 { - switch { - case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) + if e.dnsForwardMgr == nil { + localAddr := e.wgInterface.Address().IP + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr) + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil } - log.Infof("started domain router service with %d entries", len(fwdEntries)) - case e.dnsFwdPort != forwarderPort: - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - e.restartDnsFwd(fwdEntries, forwarderPort) - e.dnsFwdPort = forwarderPort - default: + log.Infof("started domain router service with %d entries", len(fwdEntries)) + } else { e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { @@ -1887,20 +1875,6 @@ func (e *Engine) updateDNSForwarder( } e.dnsForwardMgr = nil } - -} - -func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) { - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - // stop and start the forwarder to apply the new port - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { - log.Errorf("failed to start DNS forward: %v", err) - e.dnsForwardMgr = nil - } } func (e *Engine) GetNet() (*netstack.Net, error) { diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index 899faf108..a033a2a7c 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -10,10 +10,10 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/dns" ) type rcvChan chan *types.EventFields @@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection - if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { + if !l.dnsCollection.Load() && event.Protocol == types.UDP && + (event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) { return false } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 47c2ffcda..a8e697626 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -18,9 +18,9 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" + pkgdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" @@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), pkgdns.ForwarderClientPort) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/dns/dns.go b/dns/dns.go index f889a32ec..40586f24d 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -19,6 +19,10 @@ const ( RootZone = "." // DefaultClass is the class supported by the system DefaultClass = "IN" + // ForwarderClientPort is the port clients connect to. DNAT rewrites packets from ForwarderClientPort to ForwarderServerPort. + ForwarderClientPort uint16 = 5353 + // ForwarderServerPort is the port the DNS forwarder actually listens on. Packets to ForwarderClientPort are DNATed here. + ForwarderServerPort uint16 = 22054 ) const invalidHostLabel = "[^a-zA-Z0-9-]+" diff --git a/management/server/dns.go b/management/server/dns.go index 534f43ec6..e5166ce47 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -21,8 +21,8 @@ import ( ) const ( - dnsForwarderPort = 22054 - oldForwarderPort = 5353 + dnsForwarderPort = nbdns.ForwarderServerPort + oldForwarderPort = nbdns.ForwarderClientPort ) const dnsForwarderPortMinVersion = "v0.59.0" @@ -196,7 +196,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID // If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { if len(peers) == 0 { - return oldForwarderPort + return int64(oldForwarderPort) } reqVer := semver.Canonical(requiredVersion) @@ -211,17 +211,17 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) if peerVersion == "" { // If any peer doesn't have version info, return 0 - return oldForwarderPort + return int64(oldForwarderPort) } // Compare versions if semver.Compare(peerVersion, reqVer) < 0 { - return oldForwarderPort + return int64(oldForwarderPort) } } // All peers have the required version or newer - return dnsForwarderPort + return int64(dnsForwarderPort) } // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 83caf74ef..96f73a390 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -394,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) @@ -402,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) } @@ -455,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) + result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort) + result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort)) // Third run with config1 again - result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) + result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Verify that result1 and result3 are identical if !reflect.DeepEqual(result1, result3) { @@ -486,7 +486,7 @@ func TestComputeForwarderPort(t *testing.T) { // Test with empty peers list peers := []*nbpeer.Peer{} result := computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) } @@ -504,7 +504,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) } @@ -522,7 +522,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != dnsForwarderPort { + if result != int64(dnsForwarderPort) { t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) } @@ -540,7 +540,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) } @@ -553,7 +553,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) } @@ -565,7 +565,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result == oldForwarderPort { + if result == int64(oldForwarderPort) { t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) } @@ -578,7 +578,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) } } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index fd795b926..3b2ab87fc 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1161,7 +1161,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index ad82d37d9..3982ea2af 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -410,7 +410,7 @@ message DNSConfig { bool ServiceEnable = 1; repeated NameServerGroup NameServerGroups = 2; repeated CustomZone CustomZones = 3; - int64 ForwarderPort = 4; + int64 ForwarderPort = 4 [deprecated = true]; } // CustomZone represents a dns.CustomZone From d7321c130b56bc831bfdd8cbec5ff211d37e7de4 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 28 Oct 2025 16:11:35 +0100 Subject: [PATCH 045/120] [client] The status cmd will not be blocked by the ICE probe (#4597) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The status cmd will not be blocked by the ICE probe Refactor the TURN and STUN probe, and cache the results. The NetBird status command will indicate a "checking…" state. --- client/cmd/debug.go | 4 +- client/cmd/status.go | 6 +- client/internal/engine.go | 20 ++-- client/internal/relay/relay.go | 211 +++++++++++++++++++++++++++++---- client/server/server.go | 9 +- client/status/status.go | 11 +- 6 files changed, 216 insertions(+), 45 deletions(-) diff --git a/client/cmd/debug.go b/client/cmd/debug.go index d53c5f06b..430012a17 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { client := proto.NewDaemonServiceClient(conn) - stat, err := client.Status(cmd.Context(), &proto.StatusRequest{}) + stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true}) if err != nil { return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) } @@ -303,7 +303,7 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error { func getStatusOutput(cmd *cobra.Command, anon bool) string { var statusOutputString string - statusResp, err := getStatus(cmd.Context()) + statusResp, err := getStatus(cmd.Context(), true) if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { diff --git a/client/cmd/status.go b/client/cmd/status.go index 723f2367c..6e57ceb89 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { ctx := internal.CtxInitState(cmd.Context()) - resp, err := getStatus(ctx) + resp, err := getStatus(ctx, false) if err != nil { return err } @@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } -func getStatus(ctx context.Context) (*proto.StatusResponse, error) { +func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) { conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ @@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) { } defer conn.Close() - resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true}) + resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes}) if err != nil { return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 8f75c0646..48a23f4ad 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -202,6 +202,8 @@ type Engine struct { // WireGuard interface monitor wgIfaceMonitor *WGIfaceMonitor wgIfaceMonitorWg sync.WaitGroup + + probeStunTurn *relay.StunTurnProbe } // Peer is an instance of the Connection Peer @@ -244,6 +246,7 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), + probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), } sm := profilemanager.NewServiceManager("") @@ -1663,7 +1666,7 @@ func (e *Engine) getRosenpassAddr() string { // RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services // and updates the status recorder with the latest states. -func (e *Engine) RunHealthProbes() bool { +func (e *Engine) RunHealthProbes(waitForResult bool) bool { e.syncMsgMux.Lock() signalHealthy := e.signal.IsHealthy() @@ -1695,8 +1698,12 @@ func (e *Engine) RunHealthProbes() bool { } e.syncMsgMux.Unlock() - - results := e.probeICE(stuns, turns) + var results []relay.ProbeResult + if waitForResult { + results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns) + } else { + results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns) + } e.statusRecorder.UpdateRelayStates(results) relayHealthy := true @@ -1713,13 +1720,6 @@ func (e *Engine) RunHealthProbes() bool { return allHealthy } -func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult { - return append( - relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns), - relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)..., - ) -} - // restartEngine restarts the engine by cancelling the client context func (e *Engine) restartEngine() { e.syncMsgMux.Lock() diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index fa208716f..693ea1f31 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -2,6 +2,8 @@ package relay import ( "context" + "crypto/sha256" + "errors" "fmt" "net" "sync" @@ -15,6 +17,15 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) +const ( + DefaultCacheTTL = 20 * time.Second + probeTimeout = 6 * time.Second +) + +var ( + ErrCheckInProgress = errors.New("probe check is already in progress") +) + // ProbeResult holds the info about the result of a relay probe request type ProbeResult struct { URI string @@ -22,8 +33,164 @@ type ProbeResult struct { Addr string } +type StunTurnProbe struct { + cacheResults []ProbeResult + cacheTimestamp time.Time + cacheKey string + cacheTTL time.Duration + probeInProgress bool + probeDone chan struct{} + mu sync.Mutex +} + +func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe { + return &StunTurnProbe{ + cacheTTL: cacheTTL, + } +} + +func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + if p.probeInProgress { + doneChan := p.probeDone + p.mu.Unlock() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-doneChan: + return p.getCachedResults(cacheKey, stuns, turns) + } + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + p.mu.Unlock() + + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + + return p.getCachedResults(cacheKey, stuns, turns) +} + +// ProbeAll probes all given servers asynchronously and returns the results +func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + + if results := p.checkCache(cacheKey); results != nil { + p.mu.Unlock() + return results + } + + if p.probeInProgress { + p.mu.Unlock() + return createErrorResults(stuns, turns) + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + log.Infof("started new probe for STUN, TURN servers") + go func() { + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + }() + + p.mu.Unlock() + + timer := time.NewTimer(1300 * time.Millisecond) + defer timer.Stop() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-probeDone: + // when the probe is return fast, return the results right away + return p.getCachedResults(cacheKey, stuns, turns) + case <-timer.C: + // if the probe takes longer than 1.3s, return error results to avoid blocking + return createErrorResults(stuns, turns) + } +} + +func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult { + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + age := time.Since(p.cacheTimestamp) + if age < p.cacheTTL { + results := append([]ProbeResult(nil), p.cacheResults...) + log.Debugf("returning cached probe results (age: %v)", age) + return results + } + } + return nil +} + +func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + p.mu.Lock() + defer p.mu.Unlock() + + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + return append([]ProbeResult(nil), p.cacheResults...) + } + return createErrorResults(stuns, turns) +} + +func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) { + defer func() { + p.mu.Lock() + p.probeInProgress = false + p.mu.Unlock() + }() + results := make([]ProbeResult, len(stuns)+len(turns)) + + var wg sync.WaitGroup + for i, uri := range stuns { + wg.Add(1) + go func(idx int, stunURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = stunURI.String() + results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI) + }(i, uri) + } + + stunOffset := len(stuns) + for i, uri := range turns { + wg.Add(1) + go func(idx int, turnURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = turnURI.String() + results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI) + }(stunOffset+i, uri) + } + + wg.Wait() + + p.mu.Lock() + p.cacheResults = results + p.cacheTimestamp = time.Now() + p.cacheKey = cacheKey + p.mu.Unlock() + + log.Debug("Stored new probe results in cache") +} + // ProbeSTUN tries binding to the given STUN uri and acquiring an address -func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("stun probe error from %s: %s", uri, probeErr) @@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } // ProbeTURN tries allocating a session from the given TURN URI -func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("turn probe error from %s: %s", uri, probeErr) @@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) return relayConn.LocalAddr().String(), nil } -// ProbeAll probes all given servers asynchronously and returns the results -func ProbeAll( - ctx context.Context, - fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error), - relays []*stun.URI, -) []ProbeResult { - results := make([]ProbeResult, len(relays)) +func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + total := len(stuns) + len(turns) + results := make([]ProbeResult, total) - var wg sync.WaitGroup - for i, uri := range relays { - ctx, cancel := context.WithTimeout(ctx, 6*time.Second) - defer cancel() - - wg.Add(1) - go func(res *ProbeResult, stunURI *stun.URI) { - defer wg.Done() - res.URI = stunURI.String() - res.Addr, res.Err = fn(ctx, stunURI) - }(&results[i], uri) + allURIs := append(append([]*stun.URI{}, stuns...), turns...) + for i, uri := range allURIs { + results[i] = ProbeResult{ + URI: uri.String(), + Err: ErrCheckInProgress, + } } - wg.Wait() - return results } + +func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string { + h := sha256.New() + for _, uri := range stuns { + h.Write([]byte(uri.String())) + } + for _, uri := range turns { + h.Write([]byte(uri.String())) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} diff --git a/client/server/server.go b/client/server/server.go index 89f50a1ef..3641e6f92 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -1057,10 +1057,7 @@ func (s *Server) Status( s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) if msg.GetFullPeerStatus { - if msg.ShouldRunProbes { - s.runProbes() - } - + s.runProbes(msg.ShouldRunProbes) fullStatus := s.statusRecorder.GetFullStatus() pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus.Events = s.statusRecorder.GetEventHistory() @@ -1070,7 +1067,7 @@ func (s *Server) Status( return &statusResponse, nil } -func (s *Server) runProbes() { +func (s *Server) runProbes(waitForProbeResult bool) { if s.connectClient == nil { return } @@ -1081,7 +1078,7 @@ func (s *Server) runProbes() { } if time.Since(s.lastProbe) > probeThreshold { - if engine.RunHealthProbes() { + if engine.RunHealthProbes(waitForProbeResult) { s.lastProbe = time.Now() } } diff --git a/client/status/status.go b/client/status/status.go index 5e4fcd8dc..8a0b7bae0 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" + probeRelay "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/version" @@ -340,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, for _, relay := range overview.Relays.Details { available := "Available" reason := "" + if !relay.Available { - available = "Unavailable" - reason = fmt.Sprintf(", reason: %s", relay.Error) + if relay.Error == probeRelay.ErrCheckInProgress.Error() { + available = "Checking..." + } else { + available = "Unavailable" + reason = fmt.Sprintf(", reason: %s", relay.Error) + } } + relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) } } else { From d3a34adcc948800e7d96b59e2b814348ef5569a0 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 28 Oct 2025 21:21:40 +0100 Subject: [PATCH 046/120] [client] Fix Connect/Disconnect buttons being enabled or disabled at the same time (#4711) --- client/ui/client_ui.go | 83 +++++++++++++++----------------------- client/ui/event_handler.go | 49 +++++++++++++++++++--- client/ui/profile.go | 18 +++++++-- 3 files changed, 91 insertions(+), 59 deletions(-) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 0043f228e..22f18b948 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -296,6 +296,8 @@ type serviceClient struct { mExitNodeDeselectAll *systray.MenuItem logFile string wLoginURL fyne.Window + + connectCancel context.CancelFunc } type menuHandler struct { @@ -592,17 +594,15 @@ func (s *serviceClient) getSettingsForm() *widget.Form { } } -func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { +func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - return nil, err + return nil, fmt.Errorf("get daemon client: %w", err) } activeProf, err := s.profileManager.GetActiveProfile() if err != nil { - log.Errorf("get active profile: %v", err) - return nil, err + return nil, fmt.Errorf("get active profile: %w", err) } currUser, err := user.Current() @@ -610,84 +610,71 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { return nil, fmt.Errorf("get current user: %w", err) } - loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ + loginResp, err := conn.Login(ctx, &proto.LoginRequest{ IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", ProfileName: &activeProf.Name, Username: &currUser.Username, }) if err != nil { - log.Errorf("login to management URL with: %v", err) - return nil, err + return nil, fmt.Errorf("login to management: %w", err) } if loginResp.NeedsSSOLogin && openURL { - err = s.handleSSOLogin(loginResp, conn) - if err != nil { - log.Errorf("handle SSO login failed: %v", err) - return nil, err + if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil { + return nil, fmt.Errorf("SSO login: %w", err) } } return loginResp, nil } -func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { - err := openURL(loginResp.VerificationURIComplete) - if err != nil { - log.Errorf("opening the verification uri in the browser failed: %v", err) - return err +func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { + if err := openURL(loginResp.VerificationURIComplete); err != nil { + return fmt.Errorf("open browser: %w", err) } - resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) + resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) if err != nil { - log.Errorf("waiting sso login failed with: %v", err) - return err + return fmt.Errorf("wait for SSO login: %w", err) } if resp.Email != "" { - err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ + if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ Email: resp.Email, - }) - if err != nil { - log.Warnf("failed to set profile state: %v", err) + }); err != nil { + log.Debugf("failed to set profile state: %v", err) } else { s.mProfile.refresh() } - } return nil } -func (s *serviceClient) menuUpClick() error { +func (s *serviceClient) menuUpClick(ctx context.Context) error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { systray.SetTemplateIcon(iconErrorMacOS, s.icError) - log.Errorf("get client: %v", err) - return err + return fmt.Errorf("get daemon client: %w", err) } - _, err = s.login(true) + _, err = s.login(ctx, true) if err != nil { - log.Errorf("login failed with: %v", err) - return err + return fmt.Errorf("login: %w", err) } - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - return err + return fmt.Errorf("get status: %w", err) } if status.Status == string(internal.StatusConnected) { - log.Warnf("already connected") return nil } - if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { - log.Errorf("up service: %v", err) - return err + if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil { + return fmt.Errorf("start connection: %w", err) } return nil @@ -697,24 +684,20 @@ func (s *serviceClient) menuDownClick() error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - return err + return fmt.Errorf("get daemon client: %w", err) } status, err := conn.Status(s.ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - return err + return fmt.Errorf("get status: %w", err) } if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) { - log.Warnf("already down") return nil } - if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil { - log.Errorf("down service: %v", err) - return err + if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { + return fmt.Errorf("stop connection: %w", err) } return nil @@ -1381,7 +1364,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - resp, err := s.login(false) + resp, err := s.login(ctx, false) if err != nil { log.Errorf("failed to fetch login URL: %v", err) return @@ -1401,7 +1384,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - _, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) + _, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) if err != nil { log.Errorf("Waiting sso login failed with: %v", err) label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.") @@ -1409,7 +1392,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { } label.SetText("Re-authentication successful.\nReconnecting") - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { log.Errorf("get service status: %v", err) return @@ -1422,7 +1405,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - _, err = conn.Up(s.ctx, &proto.UpRequest{}) + _, err = conn.Up(ctx, &proto.UpRequest{}) if err != nil { label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.") log.Errorf("Reconnecting failed with: %v", err) diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index e9b7f4f30..e0b619411 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -12,6 +12,8 @@ import ( "fyne.io/fyne/v2" "fyne.io/systray" log "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" @@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) { func (h *eventHandler) handleConnectClick() { h.client.mUp.Disable() + + if h.client.connectCancel != nil { + h.client.connectCancel() + } + + connectCtx, connectCancel := context.WithCancel(h.client.ctx) + h.client.connectCancel = connectCancel + go func() { - defer h.client.mUp.Enable() - if err := h.client.menuUpClick(); err != nil { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service")) + defer connectCancel() + + if err := h.client.menuUpClick(connectCtx); err != nil { + st, ok := status.FromError(err) + if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { + log.Debugf("connect operation cancelled by user") + } else { + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect")) + log.Errorf("connect failed: %v", err) + } + } + + if err := h.client.updateStatus(); err != nil { + log.Debugf("failed to update status after connect: %v", err) } }() } func (h *eventHandler) handleDisconnectClick() { h.client.mDown.Disable() + + if h.client.connectCancel != nil { + log.Debugf("cancelling ongoing connect operation") + h.client.connectCancel() + h.client.connectCancel = nil + } + go func() { - defer h.client.mDown.Enable() if err := h.client.menuDownClick(); err != nil { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon")) + st, ok := status.FromError(err) + if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) { + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect")) + log.Errorf("disconnect failed: %v", err) + } else { + log.Debugf("disconnect cancelled or already disconnecting") + } + } + + if err := h.client.updateStatus(); err != nil { + log.Debugf("failed to update status after disconnect: %v", err) } }() } @@ -245,6 +282,6 @@ func (h *eventHandler) logout(ctx context.Context) error { } h.client.getSrvConfig() - + return nil } diff --git a/client/ui/profile.go b/client/ui/profile.go index 075223795..74189c9a0 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -387,6 +387,7 @@ type subItem struct { type profileMenu struct { mu sync.Mutex ctx context.Context + serviceClient *serviceClient profileManager *profilemanager.ProfileManager eventHandler *eventHandler profileMenuItem *systray.MenuItem @@ -396,7 +397,7 @@ type profileMenu struct { logoutSubItem *subItem profilesState []Profile downClickCallback func() error - upClickCallback func() error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -404,12 +405,13 @@ type profileMenu struct { type newProfileMenuArgs struct { ctx context.Context + serviceClient *serviceClient profileManager *profilemanager.ProfileManager eventHandler *eventHandler profileMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem downClickCallback func() error - upClickCallback func() error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -418,6 +420,7 @@ type newProfileMenuArgs struct { func newProfileMenu(args newProfileMenuArgs) *profileMenu { p := profileMenu{ ctx: args.ctx, + serviceClient: args.serviceClient, profileManager: args.profileManager, eventHandler: args.eventHandler, profileMenuItem: args.profileMenuItem, @@ -569,10 +572,19 @@ func (p *profileMenu) refresh() { } } - if err := p.upClickCallback(); err != nil { + if p.serviceClient.connectCancel != nil { + p.serviceClient.connectCancel() + } + + connectCtx, connectCancel := context.WithCancel(p.ctx) + p.serviceClient.connectCancel = connectCancel + + if err := p.upClickCallback(connectCtx); err != nil { log.Errorf("failed to handle up click after switching profile: %v", err) } + connectCancel() + p.refresh() p.loadSettingsCallback() } From 1ee575befe351fefb400d70fe07d4c7c58742d35 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 28 Oct 2025 22:58:43 +0100 Subject: [PATCH 047/120] [client] Use management-provided dns forwarder port on the client side (#4712) --- client/internal/engine.go | 12 +++++++++++- client/internal/routemanager/common/params.go | 2 ++ .../internal/routemanager/dnsinterceptor/handler.go | 6 ++++-- client/internal/routemanager/manager.go | 11 +++++++++++ client/internal/routemanager/mock.go | 4 ++++ dns/dns.go | 2 ++ 6 files changed, 34 insertions(+), 3 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 48a23f4ad..19d37eee1 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1059,10 +1059,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { + dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network) + + if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil { log.Errorf("failed to update dns server, err: %v", err) } + e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort) + // apply routes first, route related actions might depend on routing being enabled routes := toRoutes(networkMap.GetRoutes()) serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) @@ -1207,10 +1211,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { + forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) + if forwarderPort == 0 { + forwarderPort = nbdns.ForwarderClientPort + } + dnsUpdate := nbdns.Config{ ServiceEnable: protoDNSConfig.GetServiceEnable(), CustomZones: make([]nbdns.CustomZone, 0), NameServerGroups: make([]*nbdns.NameServerGroup, 0), + ForwarderPort: forwarderPort, } for _, zone := range protoDNSConfig.GetCustomZones() { diff --git a/client/internal/routemanager/common/params.go b/client/internal/routemanager/common/params.go index def18411f..8b5407850 100644 --- a/client/internal/routemanager/common/params.go +++ b/client/internal/routemanager/common/params.go @@ -1,6 +1,7 @@ package common import ( + "sync/atomic" "time" "github.com/netbirdio/netbird/client/firewall/manager" @@ -25,4 +26,5 @@ type HandlerParams struct { UseNewDNSRoute bool Firewall manager.Manager FakeIPManager *fakeip.Manager + ForwarderPort *atomic.Uint32 } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index a8e697626..348338dac 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -8,6 +8,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "time" "github.com/hashicorp/go-multierror" @@ -20,7 +21,6 @@ import ( nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" - pkgdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" @@ -55,6 +55,7 @@ type DnsInterceptor struct { peerStore *peerstore.Store firewall firewall.Manager fakeIPManager *fakeip.Manager + forwarderPort *atomic.Uint32 } func New(params common.HandlerParams) *DnsInterceptor { @@ -69,6 +70,7 @@ func New(params common.HandlerParams) *DnsInterceptor { firewall: params.Firewall, fakeIPManager: params.FakeIPManager, interceptedDomains: make(domainMap), + forwarderPort: params.ForwarderPort, } } @@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), pkgdns.ForwarderClientPort) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load())) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index d590dba0d..37974cd17 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -10,6 +10,7 @@ import ( "runtime" "slices" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -23,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/client" @@ -54,6 +56,7 @@ type Manager interface { SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string SetFirewall(firewall.Manager) error + SetDNSForwarderPort(port uint16) Stop(stateManager *statemanager.Manager) } @@ -101,6 +104,7 @@ type DefaultManager struct { disableServerRoutes bool activeRoutes map[route.HAUniqueID]client.RouteHandler fakeIPManager *fakeip.Manager + dnsForwarderPort atomic.Uint32 } func NewManager(config ManagerConfig) *DefaultManager { @@ -130,6 +134,7 @@ func NewManager(config ManagerConfig) *DefaultManager { disableServerRoutes: config.DisableServerRoutes, activeRoutes: make(map[route.HAUniqueID]client.RouteHandler), } + dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort)) useNoop := netstack.IsEnabled() || config.DisableClientRoutes dm.setupRefCounters(useNoop) @@ -270,6 +275,11 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error { return nil } +// SetDNSForwarderPort sets the DNS forwarder port for route handlers +func (m *DefaultManager) SetDNSForwarderPort(port uint16) { + m.dnsForwarderPort.Store(uint32(port)) +} + // Stop stops the manager watchers and clean firewall rules func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() @@ -345,6 +355,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { UseNewDNSRoute: m.useNewDNSRoute, Firewall: m.firewall, FakeIPManager: m.fakeIPManager, + ForwarderPort: &m.dnsForwarderPort, } handler := client.HandlerFromRoute(params) if err := handler.AddRoute(m.ctx); err != nil { diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index be633c3fa..6b06144b2 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error { panic("implement me") } +// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface +func (m *MockManager) SetDNSForwarderPort(port uint16) { +} + // Stop mock implementation of Stop from Manager interface func (m *MockManager) Stop(stateManager *statemanager.Manager) { if m.StopFunc != nil { diff --git a/dns/dns.go b/dns/dns.go index 40586f24d..cf089d4ed 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -35,6 +35,8 @@ type Config struct { NameServerGroups []*NameServerGroup // CustomZones contains a list of custom zone CustomZones []CustomZone + // ForwarderPort is the port clients should connect to on routing peers for DNS forwarding + ForwarderPort uint16 } // CustomZone represents a custom zone to be resolved by the dns server From c530db145564a907eef67937b4b1e33eaa3a3ae7 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 29 Oct 2025 17:27:18 +0100 Subject: [PATCH 048/120] [client] Fix UI panic when switching profiles (#4718) --- client/ui/client_ui.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 22f18b948..f4350b251 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -833,6 +833,7 @@ func (s *serviceClient) onTrayReady() { newProfileMenuArgs := &newProfileMenuArgs{ ctx: s.ctx, + serviceClient: s, profileManager: s.profileManager, eventHandler: s.eventHandler, profileMenuItem: profileMenuItem, From 43c9a519131ef69c86c74a8d6a579e4f2c1d236d Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 30 Oct 2025 10:14:27 +0100 Subject: [PATCH 049/120] [client] Migrate deprecated grpc client code (#4687) --- client/grpc/dialer.go | 42 ++++++++++++++++++++++++++------ client/grpc/dialer_generic.go | 3 +-- client/net/conn.go | 12 +++++---- client/net/dial.go | 1 - client/net/dialer_dial.go | 3 ++- client/net/listener_listen.go | 4 +-- shared/management/client/grpc.go | 3 +-- shared/signal/client/grpc.go | 3 +-- 8 files changed, 49 insertions(+), 22 deletions(-) diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 6aff53b92..7763f2417 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,12 +4,15 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" + "fmt" "runtime" "time" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" @@ -17,6 +20,9 @@ import ( "github.com/netbirdio/netbird/util/embeddedroots" ) +// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready +var ErrConnectionShutdown = errors.New("connection shutdown before ready") + // Backoff returns a backoff configuration for gRPC calls func Backoff(ctx context.Context) backoff.BackOff { b := backoff.NewExponentialBackOff() @@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } +// waitForConnectionReady blocks until the connection becomes ready or fails. +// Returns an error if the connection times out, is cancelled, or enters shutdown state. +func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error { + conn.Connect() + + state := conn.GetState() + for state != connectivity.Ready && state != connectivity.Shutdown { + if !conn.WaitForStateChange(ctx, state) { + return fmt.Errorf("wait state change from %s: %w", state, ctx.Err()) + } + state = conn.GetState() + } + + if state == connectivity.Shutdown { + return ErrConnectionShutdown + } + + return nil +} + // CreateConnection creates a gRPC client connection with the appropriate transport options. // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { @@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone })) } - connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - conn, err := grpc.DialContext( - connCtx, + conn, err := grpc.NewClient( addr, transportOption, WithCustomDialer(tlsEnabled, component), - grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, Timeout: 10 * time.Second, }), ) if err != nil { - log.Printf("DialContext error: %v", err) + return nil, fmt.Errorf("new client: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + if err := waitForConnectionReady(ctx, conn); err != nil { + _ = conn.Close() return nil, err } diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go index 96f347c64..479575996 100644 --- a/client/grpc/dialer_generic.go +++ b/client/grpc/dialer_generic.go @@ -18,7 +18,7 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { +func WithCustomDialer(_ bool, _ string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { if runtime.GOOS == "linux" { currentUser, err := user.Current() @@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { - log.Errorf("Failed to dial: %s", err) return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) } return conn, nil diff --git a/client/net/conn.go b/client/net/conn.go index 918e7f628..bf54c792d 100644 --- a/client/net/conn.go +++ b/client/net/conn.go @@ -17,8 +17,7 @@ type Conn struct { ID hooks.ConnectionID } -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection. func (c *Conn) Close() error { return closeConn(c.ID, c.Conn) } @@ -29,7 +28,7 @@ type TCPConn struct { ID hooks.ConnectionID } -// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection. func (c *TCPConn) Close() error { return closeConn(c.ID, c.TCPConn) } @@ -37,13 +36,16 @@ func (c *TCPConn) Close() error { // closeConn is a helper function to close connections and execute close hooks. func closeConn(id hooks.ConnectionID, conn io.Closer) error { err := conn.Close() + cleanupConnID(id) + return err +} +// cleanupConnID executes close hooks for a connection ID. +func cleanupConnID(id hooks.ConnectionID) { closeHooks := hooks.GetCloseHooks() for _, hook := range closeHooks { if err := hook(id); err != nil { log.Errorf("Error executing close hook: %v", err) } } - - return err } diff --git a/client/net/dial.go b/client/net/dial.go index 041a00e5d..17c9ff98a 100644 --- a/client/net/dial.go +++ b/client/net/dial.go @@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro } return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil } - if err := conn.Close(); err != nil { log.Errorf("failed to close connection: %v", err) } diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go index 2e1eb53d8..1e275013f 100644 --- a/client/net/dialer_dial.go +++ b/client/net/dialer_dial.go @@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { + cleanupConnID(connID) return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) } @@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str ips, err := resolver.LookupIPAddr(ctx, host) if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) + return fmt.Errorf("resolve address %s: %w", address, err) } log.Debugf("Dialer resolved IPs for %s: %v", address, ips) diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go index 0bb5ad67d..a150172b4 100644 --- a/client/net/listener_listen.go +++ b/client/net/listener_listen.go @@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.PacketConn.WriteTo(b, addr) } -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection. func (c *PacketConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.PacketConn) @@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.UDPConn.WriteTo(b, addr) } -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection. func (c *UDPConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.UDPConn) diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 076f2532b..520a83e36 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil } diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 31f3372c0..5368b57a2 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil } From 86eff0d75047cd17a450778439f0112e4817a35e Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 31 Oct 2025 14:18:09 +0100 Subject: [PATCH 050/120] [client] Fix netstack dns forwarder (#4727) --- client/firewall/uspfilter/filter.go | 97 +++++++++++++++++++++++- client/internal/dnsfwd/cache_test.go | 1 - client/internal/dnsfwd/forwarder.go | 66 ++++++++++++---- client/internal/dnsfwd/forwarder_test.go | 20 ++--- client/internal/dnsfwd/manager.go | 37 ++++++--- client/internal/engine.go | 70 ++++++++++++----- 6 files changed, 231 insertions(+), 60 deletions(-) diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index fbc39b740..ec2d2c57f 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -50,6 +50,12 @@ const ( var errNatNotSupported = errors.New("nat not supported with userspace firewall") +// serviceKey represents a protocol/port combination for netstack service registry +type serviceKey struct { + protocol gopacket.LayerType + port uint16 +} + // RuleSet is a set of rules grouped by a string key type RuleSet map[string]PeerRule @@ -113,6 +119,9 @@ type Manager struct { portDNATEnabled atomic.Bool portDNATRules []portDNATRule portDNATMutex sync.RWMutex + + netstackServices map[serviceKey]struct{} + netstackServiceMutex sync.RWMutex } // decoder for packages @@ -203,6 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe localForwarding: enableLocalForwarding, dnatMappings: make(map[netip.Addr]netip.Addr), portDNATRules: []portDNATRule{}, + netstackServices: make(map[serviceKey]struct{}), } m.routingEnabled.Store(false) @@ -838,9 +848,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet return true } - // If requested we pass local traffic to internal interfaces to the forwarder. - // netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder. - if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) { + if m.shouldForward(d, dstIP) { return m.handleForwardedLocalTraffic(packetData) } @@ -1274,3 +1282,86 @@ func (m *Manager) DisableRouting() error { return nil } + +// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port +func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) { + m.netstackServiceMutex.Lock() + defer m.netstackServiceMutex.Unlock() + layerType := m.protocolToLayerType(protocol) + key := serviceKey{protocol: layerType, port: port} + m.netstackServices[key] = struct{}{} + m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType) + m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices)) +} + +// UnregisterNetstackService removes a service from the netstack registry +func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) { + m.netstackServiceMutex.Lock() + defer m.netstackServiceMutex.Unlock() + layerType := m.protocolToLayerType(protocol) + key := serviceKey{protocol: layerType, port: port} + delete(m.netstackServices, key) + m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port) +} + +// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use +func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType { + switch protocol { + case nftypes.TCP: + return layers.LayerTypeTCP + case nftypes.UDP: + return layers.LayerTypeUDP + case nftypes.ICMP: + return layers.LayerTypeICMPv4 + default: + return gopacket.LayerType(0) // Invalid/unknown + } +} + +// shouldForward determines if a packet should be forwarded to the forwarder. +// The forwarder handles routing packets to the native OS network stack. +// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly. +func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool { + // not enabled, never forward + if !m.localForwarding { + return false + } + + // netstack always needs to forward because it's lacking a native interface + // exception for registered netstack services, those should go to netstack listeners + if m.netstack { + return !m.hasMatchingNetstackService(d) + } + + // traffic to our other local interfaces (not NetBird IP) - always forward + if dstIP != m.wgIface.Address().IP { + return true + } + + // traffic to our NetBird IP, not netstack mode - send to netstack listeners + return false +} + +// hasMatchingNetstackService checks if there's a registered netstack service for this packet +func (m *Manager) hasMatchingNetstackService(d *decoder) bool { + if len(d.decoded) < 2 { + return false + } + + var dstPort uint16 + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + dstPort = uint16(d.udp.DstPort) + default: + return false + } + + key := serviceKey{protocol: d.decoded[1], port: dstPort} + m.netstackServiceMutex.RLock() + _, exists := m.netstackServices[key] + m.netstackServiceMutex.RUnlock() + + return exists +} diff --git a/client/internal/dnsfwd/cache_test.go b/client/internal/dnsfwd/cache_test.go index c23f0f31d..44ebe290b 100644 --- a/client/internal/dnsfwd/cache_test.go +++ b/client/internal/dnsfwd/cache_test.go @@ -83,4 +83,3 @@ func TestCacheMiss(t *testing.T) { t.Fatalf("expected cache miss, got=%v ok=%v", got, ok) } } - diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 7a262fa4c..aef16a8cf 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -33,7 +34,7 @@ type firewaller interface { } type DNSForwarder struct { - listenAddress string + listenAddress netip.AddrPort ttl uint32 statusRecorder *peer.Status @@ -47,9 +48,11 @@ type DNSForwarder struct { firewall firewaller resolver resolver cache *cache + + wgIface wgIface } -func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { +func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ listenAddress: listenAddress, @@ -58,30 +61,46 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat statusRecorder: statusRecorder, resolver: net.DefaultResolver, cache: newCache(), + wgIface: wgIface, } } func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { - log.Infof("starting DNS forwarder on address=%s", f.listenAddress) + var netstackNet *netstack.Net + if f.wgIface != nil { + netstackNet = f.wgIface.GetNet() + } + + addrDesc := f.listenAddress.String() + if netstackNet != nil { + addrDesc = fmt.Sprintf("netstack %s", f.listenAddress) + } + log.Infof("starting DNS forwarder on address=%s", addrDesc) + + udpLn, err := f.createUDPListener(netstackNet) + if err != nil { + return fmt.Errorf("create UDP listener: %w", err) + } + + tcpLn, err := f.createTCPListener(netstackNet) + if err != nil { + return fmt.Errorf("create TCP listener: %w", err) + } - // UDP server mux := dns.NewServeMux() f.mux = mux mux.HandleFunc(".", f.handleDNSQueryUDP) f.dnsServer = &dns.Server{ - Addr: f.listenAddress, - Net: "udp", - Handler: mux, + PacketConn: udpLn, + Handler: mux, } - // TCP server tcpMux := dns.NewServeMux() f.tcpMux = tcpMux tcpMux.HandleFunc(".", f.handleDNSQueryTCP) f.tcpServer = &dns.Server{ - Addr: f.listenAddress, - Net: "tcp", - Handler: tcpMux, + Listener: tcpLn, + Handler: tcpMux, } f.UpdateDomains(entries) @@ -89,18 +108,33 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { errCh := make(chan error, 2) go func() { - log.Infof("DNS UDP listener running on %s", f.listenAddress) - errCh <- f.dnsServer.ListenAndServe() + log.Infof("DNS UDP listener running on %s", addrDesc) + errCh <- f.dnsServer.ActivateAndServe() }() go func() { - log.Infof("DNS TCP listener running on %s", f.listenAddress) - errCh <- f.tcpServer.ListenAndServe() + log.Infof("DNS TCP listener running on %s", addrDesc) + errCh <- f.tcpServer.ActivateAndServe() }() - // return the first error we get (e.g. bind failure or shutdown) return <-errCh } +func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) { + if netstackNet != nil { + return netstackNet.ListenUDPAddrPort(f.listenAddress) + } + + return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress)) +} + +func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) { + if netstackNet != nil { + return netstackNet.ListenTCPAddrPort(f.listenAddress) + } + + return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress)) +} + func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index c1c95a2c1..4d0b96a75 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil) } - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString(tt.configuredDomain) @@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { mockResolver := &MockResolver{} // Set up forwarder - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Create entries and track sets @@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Configure a single domain @@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) d, err := domain.FromString(tt.configured) require.NoError(t, err) @@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { func TestDNSForwarder_TCPTruncation(t *testing.T) { // Test that large UDP responses are truncated with TC bit set mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, _ := domain.FromString("example.com") @@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) { // a subsequent upstream failure still returns a successful response from cache. func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("example.com") @@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { // Verifies that cache normalization works across casing and trailing dot variations. func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("ExAmPlE.CoM") @@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Set up complex overlapping patterns @@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("example.com") @@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { func TestDNSForwarder_EmptyQuery(t *testing.T) { // Test handling of malformed query with no questions - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) query := &dns.Msg{} // Don't set any question diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index a1c0dff98..b26836d17 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -10,9 +10,11 @@ import ( "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/peer" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" @@ -24,6 +26,12 @@ const ( envServerPort = "NB_DNS_FORWARDER_PORT" ) +// wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder. +type wgIface interface { + GetNet() *netstack.Net + Address() wgaddr.Address +} + // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. type ForwarderEntry struct { Domain domain.Domain @@ -34,7 +42,7 @@ type ForwarderEntry struct { type Manager struct { firewall firewall.Manager statusRecorder *peer.Status - localAddr netip.Addr + wgIface wgIface serverPort uint16 fwRules []firewall.Rule @@ -42,7 +50,7 @@ type Manager struct { dnsForwarder *DNSForwarder } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager { +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager { serverPort := nbdns.ForwarderServerPort if envPort := os.Getenv(envServerPort); envPort != "" { if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 { @@ -56,7 +64,7 @@ func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr neti return &Manager{ firewall: fw, statusRecorder: statusRecorder, - localAddr: localAddr, + wgIface: wgIface, serverPort: serverPort, } } @@ -71,21 +79,25 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - if m.localAddr.IsValid() && m.firewall != nil { - if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + localAddr := m.wgIface.Address().IP + + if localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { log.Warnf("failed to add DNS UDP DNAT rule: %v", err) } else { - log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort) } - if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { log.Warnf("failed to add DNS TCP DNAT rule: %v", err) } else { - log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort) } } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder) + listenAddress := netip.AddrPortFrom(localAddr, m.serverPort) + m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface) + go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -111,12 +123,13 @@ func (m *Manager) Stop(ctx context.Context) error { var mErr *multierror.Error - if m.localAddr.IsValid() && m.firewall != nil { - if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + localAddr := m.wgIface.Address().IP + if localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err)) } - if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) } } diff --git a/client/internal/engine.go b/client/internal/engine.go index 19d37eee1..ad69bcf43 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1855,35 +1855,69 @@ func (e *Engine) updateDNSForwarder( } if !enabled { - if e.dnsForwardMgr == nil { - return - } - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } + e.stopDNSForwarder() return } if len(fwdEntries) > 0 { if e.dnsForwardMgr == nil { - localAddr := e.wgInterface.Address().IP - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr) - - if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { - log.Errorf("failed to start DNS forward: %v", err) - e.dnsForwardMgr = nil - } - - log.Infof("started domain router service with %d entries", len(fwdEntries)) + e.startDNSForwarder(fwdEntries) } else { e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { log.Infof("disable domain router service") - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } + e.stopDNSForwarder() + } +} + +func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) { + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface) + e.registerDNSServices() + + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { + log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil + return + } + + log.Infof("started domain router service with %d entries", len(fwdEntries)) +} + +func (e *Engine) stopDNSForwarder() { + if e.dnsForwardMgr == nil { + return + } + + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + + e.unregisterDNSServices() + e.dnsForwardMgr = nil +} + +func (e *Engine) registerDNSServices() { + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + if registrar, ok := e.firewall.(interface { + RegisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.RegisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort) + registrar.RegisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort) + log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort) + } + } +} + +func (e *Engine) unregisterDNSServices() { + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + if registrar, ok := e.firewall.(interface { + UnregisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.UnregisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort) + registrar.UnregisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort) + log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort) + } } } From 8c108ccad34d1c6aba6a009bd1820fe4ff6d71b5 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 31 Oct 2025 19:19:14 +0100 Subject: [PATCH 051/120] [client] Extend Darwin network monitoring with wakeup detection --- .../networkmonitor/check_change_bsd.go | 77 +-------- .../networkmonitor/check_change_common.go | 92 +++++++++++ .../networkmonitor/check_change_darwin.go | 149 ++++++++++++++++++ client/internal/networkmonitor/monitor.go | 1 + 4 files changed, 246 insertions(+), 73 deletions(-) create mode 100644 client/internal/networkmonitor/check_change_common.go create mode 100644 client/internal/networkmonitor/check_change_darwin.go diff --git a/client/internal/networkmonitor/check_change_bsd.go b/client/internal/networkmonitor/check_change_bsd.go index f5eb2c739..b3482f54e 100644 --- a/client/internal/networkmonitor/check_change_bsd.go +++ b/client/internal/networkmonitor/check_change_bsd.go @@ -1,4 +1,4 @@ -//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd +//go:build dragonfly || freebsd || netbsd || openbsd package networkmonitor @@ -6,21 +6,19 @@ import ( "context" "errors" "fmt" - "syscall" - "unsafe" log "github.com/sirupsen/logrus" - "golang.org/x/net/route" "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { - fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) + fd, err := prepareFd() if err != nil { return fmt.Errorf("open routing socket: %v", err) } + defer func() { err := unix.Close(fd) if err != nil && !errors.Is(err, unix.EBADF) { @@ -28,72 +26,5 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er } }() - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - buf := make([]byte, 2048) - n, err := unix.Read(fd, buf) - if err != nil { - if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { - log.Warnf("Network monitor: failed to read from routing socket: %v", err) - } - continue - } - if n < unix.SizeofRtMsghdr { - log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) - continue - } - - msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) - - switch msg.Type { - // handle route changes - case unix.RTM_ADD, syscall.RTM_DELETE: - route, err := parseRouteMessage(buf[:n]) - if err != nil { - log.Debugf("Network monitor: error parsing routing message: %v", err) - continue - } - - if route.Dst.Bits() != 0 { - continue - } - - intf := "" - if route.Interface != nil { - intf = route.Interface.Name - } - switch msg.Type { - case unix.RTM_ADD: - log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) - return nil - case unix.RTM_DELETE: - if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { - log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) - return nil - } - } - } - } - } -} - -func parseRouteMessage(buf []byte) (*systemops.Route, error) { - msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) - if err != nil { - return nil, fmt.Errorf("parse RIB: %v", err) - } - - if len(msgs) != 1 { - return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) - } - - msg, ok := msgs[0].(*route.RouteMessage) - if !ok { - return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) - } - - return systemops.MsgToRoute(msg) + return routeCheck(ctx, fd, nexthopv4, nexthopv6) } diff --git a/client/internal/networkmonitor/check_change_common.go b/client/internal/networkmonitor/check_change_common.go new file mode 100644 index 000000000..c287236e8 --- /dev/null +++ b/client/internal/networkmonitor/check_change_common.go @@ -0,0 +1,92 @@ +//go:build dragonfly || freebsd || netbsd || openbsd || darwin + +package networkmonitor + +import ( + "context" + "errors" + "fmt" + "syscall" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/net/route" + "golang.org/x/sys/unix" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +func prepareFd() (int, error) { + return unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) +} + +func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + buf := make([]byte, 2048) + n, err := unix.Read(fd, buf) + if err != nil { + if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { + log.Warnf("Network monitor: failed to read from routing socket: %v", err) + } + continue + } + if n < unix.SizeofRtMsghdr { + log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) + continue + } + + msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) + + switch msg.Type { + // handle route changes + case unix.RTM_ADD, syscall.RTM_DELETE: + route, err := parseRouteMessage(buf[:n]) + if err != nil { + log.Debugf("Network monitor: error parsing routing message: %v", err) + continue + } + + if route.Dst.Bits() != 0 { + continue + } + + intf := "" + if route.Interface != nil { + intf = route.Interface.Name + } + switch msg.Type { + case unix.RTM_ADD: + log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) + return nil + case unix.RTM_DELETE: + if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { + log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) + return nil + } + } + } + } + } +} + +func parseRouteMessage(buf []byte) (*systemops.Route, error) { + msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) + if err != nil { + return nil, fmt.Errorf("parse RIB: %v", err) + } + + if len(msgs) != 1 { + return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) + } + + msg, ok := msgs[0].(*route.RouteMessage) + if !ok { + return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) + } + + return systemops.MsgToRoute(msg) +} diff --git a/client/internal/networkmonitor/check_change_darwin.go b/client/internal/networkmonitor/check_change_darwin.go new file mode 100644 index 000000000..ddc6e1736 --- /dev/null +++ b/client/internal/networkmonitor/check_change_darwin.go @@ -0,0 +1,149 @@ +//go:build darwin && !ios + +package networkmonitor + +import ( + "context" + "errors" + "fmt" + "hash/fnv" + "os/exec" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +// todo: refactor to not use static functions + +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + fd, err := prepareFd() + if err != nil { + return fmt.Errorf("open routing socket: %v", err) + } + + defer func() { + if err := unix.Close(fd); err != nil { + if !errors.Is(err, unix.EBADF) { + log.Warnf("Network monitor: failed to close routing socket: %v", err) + } + } + }() + + routeChanged := make(chan struct{}) + go func() { + _ = routeCheck(ctx, fd, nexthopv4, nexthopv6) + close(routeChanged) + }() + + wakeUp := make(chan struct{}) + go func() { + wakeUpListen(ctx) + close(wakeUp) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-routeChanged: + if ctx.Err() != nil { + return ctx.Err() + } + log.Infof("route change detected") + return nil + case <-wakeUp: + if ctx.Err() != nil { + return ctx.Err() + } + log.Infof("wakeup detected") + return nil + } +} + +func wakeUpListen(ctx context.Context) { + log.Infof("start to watch for system wakeups") + var ( + initialHash uint32 + err error + ) + + // Keep retrying until initial sysctl succeeds or context is canceled + for { + select { + case <-ctx.Done(): + log.Info("exit from wakeUpListen initial hash detection due to context cancellation") + return + default: + initialHash, err = readSleepTimeHash() + if err != nil { + log.Errorf("failed to detect initial sleep time: %v", err) + select { + case <-ctx.Done(): + log.Info("exit from wakeUpListen initial hash detection due to context cancellation") + return + case <-time.After(3 * time.Second): + continue + } + } + log.Debugf("initial wakeup hash: %d", initialHash) + break + } + break + } + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info("context canceled, stopping wakeUpListen") + return + + case <-ticker.C: + newHash, err := readSleepTimeHash() + if err != nil { + log.Errorf("failed to read sleep time hash: %v", err) + continue + } + + if newHash == initialHash { + log.Tracef("no wakeup detected") + continue + } + + upOut, err := exec.Command("uptime").Output() + if err != nil { + log.Errorf("failed to run uptime command: %v", err) + upOut = []byte("unknown") + } + log.Infof("Wakeup detected: %d -> %d, uptime: %s", initialHash, newHash, upOut) + return + } + } +} + +func readSleepTimeHash() (uint32, error) { + cmd := exec.Command("sysctl", "kern.sleeptime") + out, err := cmd.Output() + if err != nil { + return 0, fmt.Errorf("failed to run sysctl: %w", err) + } + + h, err := hash(out) + if err != nil { + return 0, fmt.Errorf("failed to compute hash: %w", err) + } + + return h, nil +} + +func hash(data []byte) (uint32, error) { + hasher := fnv.New32a() // Create a new 32-bit FNV-1a hasher + if _, err := hasher.Write(data); err != nil { + return 0, err + } + return hasher.Sum32(), nil +} diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go index accdd9c9d..6d019258d 100644 --- a/client/internal/networkmonitor/monitor.go +++ b/client/internal/networkmonitor/monitor.go @@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) { event := make(chan struct{}, 1) go nw.checkChanges(ctx, event, nexthop4, nexthop6) + log.Infof("start watching for network changes") // debounce changes timer := time.NewTimer(0) timer.Stop() From a2313a5ba4959a5550c242947a4faf3f9b0cde9b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 1 Nov 2025 15:27:22 +0100 Subject: [PATCH 052/120] [client] Bump github.com/quic-go/quic-go from 0.48.2 to 0.49.1 (#4621) Bumps [github.com/quic-go/quic-go](https://github.com/quic-go/quic-go) from 0.48.2 to 0.49.1. - [Release notes](https://github.com/quic-go/quic-go/releases) - [Commits](https://github.com/quic-go/quic-go/compare/v0.48.2...v0.49.1) --- updated-dependencies: - dependency-name: github.com/quic-go/quic-go dependency-version: 0.49.1 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 06dc921f1..98b158411 100644 --- a/go.mod +++ b/go.mod @@ -76,7 +76,7 @@ require ( github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.22.0 - github.com/quic-go/quic-go v0.48.2 + github.com/quic-go/quic-go v0.49.1 github.com/redis/go-redis/v9 v9.7.3 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 @@ -241,7 +241,7 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect - go.uber.org/mock v0.4.0 // indirect + go.uber.org/mock v0.5.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/text v0.27.0 // indirect diff --git a/go.sum b/go.sum index ce68ed99e..9705f3aed 100644 --- a/go.sum +++ b/go.sum @@ -590,8 +590,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= -github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= -github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= +github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0= +github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s= github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -749,8 +749,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8 go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= -go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= From 719283c792a46311daa334254c5f90a39e0afa00 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 3 Nov 2025 17:40:12 +0100 Subject: [PATCH 053/120] [management] update db connection lifecycle configuration (#4740) --- management/server/activity/store/sql_store.go | 23 ++++++++++++------- management/server/store/sql_store.go | 6 ++++- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/management/server/activity/store/sql_store.go b/management/server/activity/store/sql_store.go index 80b165938..ffecb6b8f 100644 --- a/management/server/activity/store/sql_store.go +++ b/management/server/activity/store/sql_store.go @@ -7,6 +7,7 @@ import ( "path/filepath" "runtime" "strconv" + "time" log "github.com/sirupsen/logrus" "gorm.io/driver/postgres" @@ -273,15 +274,21 @@ func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, e return nil, err } - if storeEngine == types.SqliteStoreEngine { - sqlDB.SetMaxOpenConns(1) - } else { - conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv)) - if err != nil { - conns = runtime.NumCPU() - } - sqlDB.SetMaxOpenConns(conns) + conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv)) + if err != nil { + conns = runtime.NumCPU() } + if storeEngine == types.SqliteStoreEngine { + conns = 1 + } + + sqlDB.SetMaxOpenConns(conns) + sqlDB.SetMaxIdleConns(conns) + sqlDB.SetConnMaxLifetime(time.Hour) + sqlDB.SetConnMaxIdleTime(3 * time.Minute) + + log.Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v", + conns, conns, time.Hour, 3*time.Minute) return db, nil } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 382d026c8..4201b68f6 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -89,8 +89,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met } sql.SetMaxOpenConns(conns) + sql.SetMaxIdleConns(conns) + sql.SetConnMaxLifetime(time.Hour) + sql.SetConnMaxIdleTime(3 * time.Minute) - log.WithContext(ctx).Infof("Set max open db connections to %d", conns) + log.WithContext(ctx).Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v", + conns, conns, time.Hour, 3*time.Minute) if skipMigration { log.WithContext(ctx).Infof("skipping migration") From 679c58ce472d8763d8bea987662104fa6a381dc2 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 4 Nov 2025 17:06:35 +0100 Subject: [PATCH 054/120] [client] Set up networkd to ignore ip rules (#4730) --- client/cmd/service_installer.go | 54 +++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 075ead44e..2a87e538d 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -10,6 +10,8 @@ import ( "path/filepath" "runtime" + log "github.com/sirupsen/logrus" + "github.com/kardianos/service" "github.com/spf13/cobra" @@ -81,6 +83,10 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error { svcConfig.Option["LogDirectory"] = dir } } + + if err := configureSystemdNetworkd(); err != nil { + log.Warnf("failed to configure systemd-networkd: %v", err) + } } if runtime.GOOS == "windows" { @@ -160,6 +166,12 @@ var uninstallCmd = &cobra.Command{ return fmt.Errorf("uninstall service: %w", err) } + if runtime.GOOS == "linux" { + if err := cleanupSystemdNetworkd(); err != nil { + log.Warnf("failed to cleanup systemd-networkd configuration: %v", err) + } + } + cmd.Println("NetBird service has been uninstalled") return nil }, @@ -245,3 +257,45 @@ func isServiceRunning() (bool, error) { return status == service.StatusRunning, nil } + +const ( + networkdConfDir = "/etc/systemd/networkd.conf.d" + networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf" + networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing +# routes and policy rules managed by NetBird. + +[Network] +ManageForeignRoutes=no +ManageForeignRoutingPolicyRules=no +` +) + +// configureSystemdNetworkd creates a drop-in configuration file to prevent +// systemd-networkd from removing NetBird's routes and policy rules. +func configureSystemdNetworkd() error { + parentDir := filepath.Dir(networkdConfDir) + if _, err := os.Stat(parentDir); os.IsNotExist(err) { + log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration") + return nil + } + + // nolint:gosec // standard networkd permissions + if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil { + return fmt.Errorf("write networkd configuration: %w", err) + } + + return nil +} + +// cleanupSystemdNetworkd removes the NetBird systemd-networkd configuration file. +func cleanupSystemdNetworkd() error { + if _, err := os.Stat(networkdConfFile); os.IsNotExist(err) { + return nil + } + + if err := os.Remove(networkdConfFile); err != nil { + return fmt.Errorf("remove networkd configuration: %w", err) + } + + return nil +} From 45c25dca84150c0da85b06875facb01f0bcd119a Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 4 Nov 2025 17:18:51 +0100 Subject: [PATCH 055/120] [client] Clamp MSS on outbound traffic (#4735) --- client/firewall/create.go | 4 +- client/firewall/create_linux.go | 22 +- client/firewall/iptables/manager_linux.go | 5 +- .../firewall/iptables/manager_linux_test.go | 9 +- client/firewall/iptables/router_linux.go | 50 +++- client/firewall/iptables/router_linux_test.go | 14 +- client/firewall/iptables/state_linux.go | 8 +- client/firewall/nftables/manager_linux.go | 5 +- .../firewall/nftables/manager_linux_test.go | 9 +- client/firewall/nftables/router_linux.go | 98 ++++++- client/firewall/nftables/router_linux_test.go | 9 +- client/firewall/nftables/state_linux.go | 8 +- client/firewall/uspfilter/filter.go | 154 ++++++++-- .../firewall/uspfilter/filter_bench_test.go | 277 +++++++++++++++--- .../firewall/uspfilter/filter_filter_test.go | 7 +- client/firewall/uspfilter/filter_test.go | 215 +++++++++++++- .../firewall/uspfilter/forwarder/forwarder.go | 8 +- client/firewall/uspfilter/forwarder/udp.go | 2 +- client/firewall/uspfilter/nat_bench_test.go | 11 +- client/firewall/uspfilter/nat_test.go | 9 +- client/firewall/uspfilter/tracer_test.go | 3 +- client/internal/acl/manager_test.go | 7 +- client/internal/dns/server_test.go | 2 +- client/internal/engine.go | 2 +- 24 files changed, 804 insertions(+), 134 deletions(-) diff --git a/client/firewall/create.go b/client/firewall/create.go index 7b265e1d1..24f12bc6d 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -15,13 +15,13 @@ import ( ) // NewFirewall creates a firewall manager instance -func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } // use userspace packet filtering firewall - fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger) + fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu) if err != nil { return nil, err } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index aa2f0d4d1..12dcaee8a 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes) + fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu) if !iface.IsUserspaceBind() { return fm, err @@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg if err != nil { log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) } - return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger) + return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu) } -func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { - fm, err := createFW(iface) +func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) { + fm, err := createFW(iface, mtu) if err != nil { return nil, fmt.Errorf("create firewall: %s", err) } @@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, return fm, nil } -func createFW(iface IFaceMapper) (firewall.Manager, error) { +func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) { switch check() { case IPTABLES: log.Info("creating an iptables firewall manager") - return nbiptables.Create(iface) + return nbiptables.Create(iface, mtu) case NFTABLES: log.Info("creating an nftables firewall manager") - return nbnftables.Create(iface) + return nbnftables.Create(iface, mtu) default: log.Info("no firewall manager found, trying to use userspace packet filtering firewall") return nil, errors.New("no firewall manager found") } } -func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) { +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) { var errUsp error if fm != nil { - fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger) + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu) } else { - fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger) + fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu) } if errUsp != nil { diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 16b50211e..2563a9052 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -36,7 +36,7 @@ type iFaceMapper interface { } // Create iptables firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { return nil, fmt.Errorf("init iptables: %w", err) @@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) { ipv4Client: iptablesClient, } - m.router, err = newRouter(iptablesClient, wgIface) + m.router, err = newRouter(iptablesClient, wgIface, mtu) if err != nil { return nil, fmt.Errorf("create router: %w", err) } @@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { NameStr: m.wgIface.Name(), WGAddress: m.wgIface.Address(), UserspaceBind: m.wgIface.IsUserspaceBind(), + MTU: m.router.mtu, }, } stateManager.RegisterState(state) diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index a5cc62feb..6b5401e2b 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) { require.NoError(t, err) // just check on the local interface - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) { } // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 80aea7cf8..305b0bf28 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -30,17 +30,20 @@ const ( chainPOSTROUTING = "POSTROUTING" chainPREROUTING = "PREROUTING" + chainFORWARD = "FORWARD" chainRTNAT = "NETBIRD-RT-NAT" chainRTFWDIN = "NETBIRD-RT-FWD-IN" chainRTFWDOUT = "NETBIRD-RT-FWD-OUT" chainRTPRE = "NETBIRD-RT-PRE" chainRTRDR = "NETBIRD-RT-RDR" + chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" jumpManglePre = "jump-mangle-pre" jumpNatPre = "jump-nat-pre" jumpNatPost = "jump-nat-post" + jumpMSSClamp = "jump-mss-clamp" markManglePre = "mark-mangle-pre" markManglePost = "mark-mangle-post" matchSet = "--match-set" @@ -48,6 +51,9 @@ const ( dnatSuffix = "_dnat" snatSuffix = "_snat" fwdSuffix = "_fwd" + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 ) type ruleInfo struct { @@ -77,16 +83,18 @@ type router struct { ipsetCounter *ipsetCounter wgIface iFaceMapper legacyManagement bool + mtu uint16 stateManager *statemanager.Manager ipFwdState *ipfwdstate.IPForwardingState } -func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { +func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) { r := &router{ iptablesClient: iptablesClient, rules: make(map[string][]string), wgIface: wgIface, + mtu: mtu, ipFwdState: ipfwdstate.NewIPForwardingState(), } @@ -392,6 +400,7 @@ func (r *router) cleanUpDefaultForwardRules() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + {chainRTMSSCLAMP, tableMangle}, } { ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) if err != nil { @@ -416,6 +425,7 @@ func (r *router) createContainers() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + {chainRTMSSCLAMP, tableMangle}, } { if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) @@ -438,6 +448,10 @@ func (r *router) createContainers() error { return fmt.Errorf("add jump rules: %w", err) } + if err := r.addMSSClampingRules(); err != nil { + log.Errorf("failed to add MSS clamping rules: %s", err) + } + return nil } @@ -518,6 +532,35 @@ func (r *router) addPostroutingRules() error { return nil } +// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. +// TODO: Add IPv6 support +func (r *router) addMSSClampingRules() error { + mss := r.mtu - ipTCPHeaderMinSize + + // Add jump rule from FORWARD chain in mangle table to our custom chain + jumpRule := []string{ + "-j", chainRTMSSCLAMP, + } + if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil { + return fmt.Errorf("add jump to MSS clamp chain: %w", err) + } + r.rules[jumpMSSClamp] = jumpRule + + ruleOut := []string{ + "-o", r.wgIface.Name(), + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", fmt.Sprintf("%d", mss), + } + if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil { + return fmt.Errorf("add outbound MSS clamp rule: %w", err) + } + r.rules["mss-clamp-out"] = ruleOut + + return nil +} + func (r *router) insertEstablishedRule(chain string) error { establishedRule := getConntrackEstablished() @@ -558,7 +601,7 @@ func (r *router) addJumpRules() error { } func (r *router) cleanJumpRules() error { - for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} { + for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} { if rule, exists := r.rules[ruleKey]; exists { var table, chain string switch ruleKey { @@ -571,6 +614,9 @@ func (r *router) cleanJumpRules() error { case jumpNatPre: table = tableNat chain = chainPREROUTING + case jumpMSSClamp: + table = tableMangle + chain = chainFORWARD default: return fmt.Errorf("unknown jump rule: %s", ruleKey) } diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 3490c5dad..6707573be 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -14,6 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + "github.com/netbirdio/netbird/client/iface" nbnet "github.com/netbirdio/netbird/client/net" ) @@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, manager.init(nil)) @@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { assert.NoError(t, manager.Reset(), "shouldn't return error") }() - // Now 5 rules: // 1. established rule forward in // 2. estbalished rule forward out // 3. jump rule to POST nat chain @@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { // 7. static return masquerade rule // 8. mangle prerouting mark rule // 9. mangle postrouting mark rule - require.Len(t, manager.rules, 9, "should have created rules map") + // 10. jump rule to MSS clamping chain + // 11. MSS clamping rule for outbound traffic + require.Len(t, manager.rules, 11, "should have created rules map") exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) @@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) @@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) defer func() { @@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "Failed to create iptables client") - r, err := newRouter(iptablesClient, ifaceMock) + r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router manager") require.NoError(t, r.init(nil)) diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index 6ef159e01..c88774c1f 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -4,6 +4,7 @@ import ( "fmt" "sync" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -11,6 +12,7 @@ type InterfaceState struct { NameStr string `json:"name"` WGAddress wgaddr.Address `json:"wg_address"` UserspaceBind bool `json:"userspace_bind"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - ipt, err := Create(s.InterfaceState) + mtu := s.InterfaceState.MTU + if mtu == 0 { + mtu = iface.DefaultMTU + } + ipt, err := Create(s.InterfaceState, mtu) if err != nil { return fmt.Errorf("create iptables manager: %w", err) } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index aa90d3b9a..d864914fe 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -44,7 +44,7 @@ type Manager struct { } // Create nftables firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { m := &Manager{ rConn: &nftables.Conn{}, wgIface: wgIface, @@ -53,7 +53,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) { workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} var err error - m.router, err = newRouter(workTable, wgIface) + m.router, err = newRouter(workTable, wgIface, mtu) if err != nil { return nil, fmt.Errorf("create router: %w", err) } @@ -93,6 +93,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { NameStr: m.wgIface.Name(), WGAddress: m.wgIface.Address(), UserspaceBind: m.wgIface.IsUserspaceBind(), + MTU: m.router.mtu, }, }); err != nil { log.Errorf("failed to update state: %v", err) diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index c7f05dcb7..adec802c8 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -16,6 +16,7 @@ import ( "golang.org/x/sys/unix" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false } func TestNftablesManager(t *testing.T) { // just check on the local interface - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) @@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) { func TestNftablesManagerRuleOrder(t *testing.T) { // This test verifies rule insertion order in nftables peer ACLs // We add accept rule first, then deny rule to test ordering behavior - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) @@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { stdout, stderr := runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err, "failed to create manager") require.NoError(t, manager.Init(nil)) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 648a6aedf..0a2c79186 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -16,6 +16,7 @@ import ( "github.com/google/nftables/xt" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -32,12 +33,16 @@ const ( chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingRdr = "netbird-rt-redirect" chainNameForward = "FORWARD" + chainNameMangleForward = "netbird-mangle-forward" userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" dnatSuffix = "_dnat" snatSuffix = "_snat" + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 ) const refreshRulesMapError = "refresh rules map: %w" @@ -63,9 +68,10 @@ type router struct { wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState legacyManagement bool + mtu uint16 } -func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { +func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) { r := &router{ conn: &nftables.Conn{}, workTable: workTable, @@ -73,6 +79,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) rules: make(map[string]*nftables.Rule), wgIface: wgIface, ipFwdState: ipfwdstate.NewIPForwardingState(), + mtu: mtu, } r.ipsetCounter = refcounter.New( @@ -220,11 +227,23 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeFilter, }) + r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameMangleForward, + Table: r.workTable, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityMangle, + Type: nftables.ChainTypeFilter, + }) + // Add the single NAT rule that matches on mark if err := r.addPostroutingRules(); err != nil { return fmt.Errorf("add single nat rule: %v", err) } + if err := r.addMSSClampingRules(); err != nil { + log.Errorf("failed to add MSS clamping rules: %s", err) + } + if err := r.acceptForwardRules(); err != nil { log.Errorf("failed to add accept rules for the forward chain: %s", err) } @@ -745,6 +764,83 @@ func (r *router) addPostroutingRules() error { return nil } +// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. +// TODO: Add IPv6 support +func (r *router) addMSSClampingRules() error { + mss := r.mtu - ipTCPHeaderMinSize + + exprsOut := []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyOIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{ + Key: expr.MetaKeyL4PROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 13, + Len: 1, + }, + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 1, + Mask: []byte{0x02}, + Xor: []byte{0x00}, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0x00}, + }, + &expr.Counter{}, + &expr.Exthdr{ + DestRegister: 1, + Type: 2, + Offset: 2, + Len: 2, + Op: expr.ExthdrOpTcpopt, + }, + &expr.Cmp{ + Op: expr.CmpOpGt, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(mss)), + }, + &expr.Immediate{ + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(mss)), + }, + &expr.Exthdr{ + SourceRegister: 1, + Type: 2, + Offset: 2, + Len: 2, + Op: expr.ExthdrOpTcpopt, + }, + } + + r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameMangleForward], + Exprs: exprsOut, + }) + + return nil +} + // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { sourceExp, err := r.applyNetwork(pair.Source, nil, true) diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 4fdbf3505..3531b014b 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -17,6 +17,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + "github.com/netbirdio/netbird/client/iface" ) const ( @@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) { for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { // need fw manager to init both acl mgr and router for all chains to be present - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) @@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) @@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router") require.NoError(t, r.init(workTable)) @@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router") require.NoError(t, r.init(workTable)) diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go index f805623d6..48b7b3741 100644 --- a/client/firewall/nftables/state_linux.go +++ b/client/firewall/nftables/state_linux.go @@ -3,6 +3,7 @@ package nftables import ( "fmt" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -10,6 +11,7 @@ type InterfaceState struct { NameStr string `json:"name"` WGAddress wgaddr.Address `json:"wg_address"` UserspaceBind bool `json:"userspace_bind"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - nft, err := Create(s.InterfaceState) + mtu := s.InterfaceState.MTU + if mtu == 0 { + mtu = iface.DefaultMTU + } + nft, err := Create(s.InterfaceState, mtu) if err != nil { return fmt.Errorf("create nftables manager: %w", err) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index ec2d2c57f..990630ee4 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -1,6 +1,7 @@ package uspfilter import ( + "encoding/binary" "errors" "fmt" "net" @@ -27,7 +28,12 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -const layerTypeAll = 0 +const ( + layerTypeAll = 0 + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 +) const ( // EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed. @@ -36,6 +42,9 @@ const ( // EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped. EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING" + // EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic. + EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING" + // EnvForceUserspaceRouter forces userspace routing even if native routing is available. EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" @@ -122,6 +131,10 @@ type Manager struct { netstackServices map[serviceKey]struct{} netstackServiceMutex sync.RWMutex + + mtu uint16 + mssClampValue uint16 + mssClampEnabled bool } // decoder for packages @@ -140,16 +153,16 @@ type decoder struct { } // Create userspace firewall manager constructor -func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { - return create(iface, nil, disableServerRoutes, flowLogger) +func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { + return create(iface, nil, disableServerRoutes, flowLogger, mtu) } -func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { +func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { if nativeFirewall == nil { return nil, errors.New("native firewall is nil") } - mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger) + mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu) if err != nil { return nil, err } @@ -157,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall. return mgr, nil } -func parseCreateEnv() (bool, bool) { - var disableConntrack, enableLocalForwarding bool +func parseCreateEnv() (bool, bool, bool) { + var disableConntrack, enableLocalForwarding, disableMSSClamping bool var err error if val := os.Getenv(EnvDisableConntrack); val != "" { disableConntrack, err = strconv.ParseBool(val) @@ -177,12 +190,18 @@ func parseCreateEnv() (bool, bool) { log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err) } } + if val := os.Getenv(EnvDisableMSSClamping); val != "" { + disableMSSClamping, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err) + } + } - return disableConntrack, enableLocalForwarding + return disableConntrack, enableLocalForwarding, disableMSSClamping } -func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { - disableConntrack, enableLocalForwarding := parseCreateEnv() +func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { + disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv() m := &Manager{ decoders: sync.Pool{ @@ -213,13 +232,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe dnatMappings: make(map[netip.Addr]netip.Addr), portDNATRules: []portDNATRule{}, netstackServices: make(map[serviceKey]struct{}), + mtu: mtu, } m.routingEnabled.Store(false) + if !disableMSSClamping { + m.mssClampEnabled = true + m.mssClampValue = mtu - ipTCPHeaderMinSize + } if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { return nil, fmt.Errorf("update local IPs: %w", err) } - if disableConntrack { log.Info("conntrack is disabled") } else { @@ -227,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger) } - - // netstack needs the forwarder for local traffic if m.netstack && m.localForwarding { if err := m.initForwarder(); err != nil { log.Errorf("failed to initialize forwarder: %v", err) } } - if err := iface.SetFilter(m); err != nil { return nil, fmt.Errorf("set filter: %w", err) } @@ -337,7 +357,7 @@ func (m *Manager) initForwarder() error { return errors.New("forwarding not supported") } - forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack) + forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu) if err != nil { m.routingEnabled.Store(false) return fmt.Errorf("create forwarder: %w", err) @@ -643,8 +663,17 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return false } - if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) { - return true + switch d.decoded[1] { + case layers.LayerTypeUDP: + if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) { + return true + } + case layers.LayerTypeTCP: + // Clamp MSS on all TCP SYN packets, including those from local IPs. + // SNATed routed traffic may appear as local IP but still requires clamping. + if m.mssClampEnabled { + m.clampTCPMSS(packetData, d) + } } m.trackOutbound(d, srcIP, dstIP, packetData, size) @@ -691,6 +720,97 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } +// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation. +// Both sides advertise their MSS during connection establishment, so we need to clamp both. +func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool { + if !d.tcp.SYN { + return false + } + if len(d.tcp.Options) == 0 { + return false + } + + mssOptionIndex := -1 + var currentMSS uint16 + for i, opt := range d.tcp.Options { + if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 { + currentMSS = binary.BigEndian.Uint16(opt.OptionData) + if currentMSS > m.mssClampValue { + mssOptionIndex = i + break + } + } + } + + if mssOptionIndex == -1 { + return false + } + + ipHeaderSize := int(d.ip4.IHL) * 4 + if ipHeaderSize < 20 { + return false + } + + if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) { + return false + } + + m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue) + return true +} + +func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool { + tcpHeaderStart := ipHeaderSize + tcpOptionsStart := tcpHeaderStart + 20 + + optOffset := tcpOptionsStart + for j := 0; j < mssOptionIndex; j++ { + switch d.tcp.Options[j].OptionType { + case layers.TCPOptionKindEndList, layers.TCPOptionKindNop: + optOffset++ + default: + optOffset += 2 + len(d.tcp.Options[j].OptionData) + } + } + + mssValueOffset := optOffset + 2 + binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue) + + m.recalculateTCPChecksum(packetData, d, tcpHeaderStart) + return true +} + +func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) { + tcpLayer := packetData[tcpHeaderStart:] + tcpLength := len(packetData) - tcpHeaderStart + + tcpLayer[16] = 0 + tcpLayer[17] = 0 + + var pseudoSum uint32 + pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) + pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) + pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) + pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) + pseudoSum += uint32(d.ip4.Protocol) + pseudoSum += uint32(tcpLength) + + var sum uint32 = pseudoSum + for i := 0; i < tcpLength-1; i += 2 { + sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) + } + if tcpLength%2 == 1 { + sum += uint32(tcpLayer[tcpLength-1]) << 8 + } + + for sum > 0xFFFF { + sum = (sum & 0xFFFF) + (sum >> 16) + } + + checksum := ^uint16(sum) + binary.BigEndian.PutUint16(tcpLayer[16:18], checksum) +} + func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) { transport := d.decoded[1] switch transport { diff --git a/client/firewall/uspfilter/filter_bench_test.go b/client/firewall/uspfilter/filter_bench_test.go index 0cffcc1a7..5a2d0410f 100644 --- a/client/firewall/uspfilter/filter_bench_test.go +++ b/client/firewall/uspfilter/filter_bench_test.go @@ -17,6 +17,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) { // Create manager and basic setup manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { } } -// generateTCPPacketWithFlags creates a TCP packet with specific flags -func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { - b.Helper() - - ipv4 := &layers.IPv4{ - TTL: 64, - Version: 4, - SrcIP: srcIP, - DstIP: dstIP, - Protocol: layers.IPProtocolTCP, - } - - tcp := &layers.TCP{ - SrcPort: layers.TCPPort(srcPort), - DstPort: layers.TCPPort(dstPort), - } - - // Set TCP flags - tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0 - tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0 - tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0 - tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0 - tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0 - - require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} - require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) - return buf.Bytes() -} - func BenchmarkRouteACLs(b *testing.B) { manager := setupRoutedManager(b, "10.10.0.100/16") @@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) { } } } + +// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound. +// This shows the overhead difference between the common case (non-SYN packets, fast path) +// and the rare case (SYN packets that need clamping, expensive path). +func BenchmarkMSSClamping(b *testing.B) { + scenarios := []struct { + name string + description string + genPacket func(*testing.B, net.IP, net.IP) []byte + frequency string + }{ + { + name: "syn_needs_clamp", + description: "SYN packet needing MSS clamping", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + frequency: "~0.1% of traffic - EXPENSIVE", + }, + { + name: "syn_no_clamp_needed", + description: "SYN packet with already-small MSS", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200) + }, + frequency: "~0.05% of traffic", + }, + { + name: "tcp_ack", + description: "Non-SYN TCP packet (ACK, data transfer)", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + frequency: "~60-70% of traffic - FAST PATH", + }, + { + name: "tcp_psh_ack", + description: "TCP data packet (PSH+ACK)", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + }, + frequency: "~10-20% of traffic - FAST PATH", + }, + { + name: "udp", + description: "UDP packet", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP) + }, + frequency: "~20-30% of traffic - FAST PATH", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = true + manager.mssClampValue = 1240 + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled +// for the common case (non-SYN TCP packets). +func BenchmarkMSSClampingOverhead(b *testing.B) { + scenarios := []struct { + name string + enabled bool + genPacket func(*testing.B, net.IP, net.IP) []byte + }{ + { + name: "disabled_tcp_ack", + enabled: false, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "enabled_tcp_ack", + enabled: true, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "disabled_syn_needs_clamp", + enabled: false, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + { + name: "enabled_syn_needs_clamp", + enabled: true, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = sc.enabled + if sc.enabled { + manager.mssClampValue = 1240 + } + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases +func BenchmarkMSSClampingMemory(b *testing.B) { + scenarios := []struct { + name string + genPacket func(*testing.B, net.IP, net.IP) []byte + }{ + { + name: "tcp_ack_fast_path", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "syn_needs_clamp", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + { + name: "udp_fast_path", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP) + }, + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = true + manager.mssClampValue = 1240 + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte { + b.Helper() + + ip := &layers.IPv4{ + Version: 4, + IHL: 5, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + Seq: 1000, + Window: 65535, + } + + require.NoError(b, tcp.SetNetworkLayerForChecksum(ip)) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + + require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{}))) + return buf.Bytes() +} diff --git a/client/firewall/uspfilter/filter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go index 73f3face8..eb5aa3343 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -12,6 +12,7 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" @@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) require.NotNil(t, manager) @@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(tb, err) require.NoError(tb, manager.EnableRouting()) require.NotNil(tb, manager) @@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index bac06814d..c56a078fc 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -1,6 +1,7 @@ package uspfilter import ( + "encoding/binary" "fmt" "net" "net/netip" @@ -17,6 +18,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + nbiface "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/netflow" @@ -66,7 +68,7 @@ func TestManagerCreate(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -86,7 +88,7 @@ func TestManagerAddPeerFiltering(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -119,7 +121,7 @@ func TestManagerDeleteRule(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -215,7 +217,7 @@ func TestAddUDPPacketHook(t *testing.T) { t.Run(tt.name, func(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -265,7 +267,7 @@ func TestManagerReset(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -304,7 +306,7 @@ func TestNotMatchByIP(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -367,7 +369,7 @@ func TestRemovePacketHook(t *testing.T) { } // creating manager instance - manager, err := Create(iface, false, flowLogger) + manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Fatalf("Failed to create Manager: %s", err) } @@ -413,7 +415,7 @@ func TestRemovePacketHook(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.udpTracker.Close() @@ -495,7 +497,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) time.Sleep(time.Second) @@ -522,7 +524,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.udpTracker.Close() // Close the existing tracker @@ -729,7 +731,7 @@ func TestUpdateSetMerge(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) @@ -815,7 +817,7 @@ func TestUpdateSetDeduplication(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) @@ -923,3 +925,192 @@ func TestUpdateSetDeduplication(t *testing.T) { require.Equal(t, tc.expected, isAllowed, tc.desc) } } + +func TestMSSClamping(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.10.0.100"), + Network: netip.MustParsePrefix("100.10.0.0/16"), + } + }, + } + + manager, err := Create(ifaceMock, false, flowLogger, 1280) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default") + expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize) + require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40") + + err = manager.UpdateLocalIPs() + require.NoError(t, err) + + srcIP := net.ParseIP("100.10.0.2") + dstIP := net.ParseIP("8.8.8.8") + + t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) { + highMSS := uint16(1460) + packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType)) + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40") + }) + + t.Run("SYN packet with low MSS unchanged", func(t *testing.T) { + lowMSS := uint16(1200) + packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified") + }) + + t.Run("SYN-ACK packet gets clamped", func(t *testing.T) { + highMSS := uint16(1460) + packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped") + }) + + t.Run("Non-SYN packet unchanged", func(t *testing.T) { + packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck)) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Empty(t, d.tcp.Options, "ACK packet should have no options") + }) +} + +func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + Window: 65535, + Options: []layers.TCPOption{ + { + OptionType: layers.TCPOptionKindMSS, + OptionLength: 4, + OptionData: binary.BigEndian.AppendUint16(nil, mss), + }, + }, + } + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} + +func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + ACK: true, + Window: 65535, + Options: []layers.TCPOption{ + { + OptionType: layers.TCPOptionKindMSS, + OptionLength: 4, + OptionData: binary.BigEndian.AppendUint16(nil, mss), + }, + }, + } + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} + +func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + Window: 65535, + } + + if flags&uint16(conntrack.TCPSyn) != 0 { + tcpLayer.SYN = true + } + if flags&uint16(conntrack.TCPAck) != 0 { + tcpLayer.ACK = true + } + if flags&uint16(conntrack.TCPFin) != 0 { + tcpLayer.FIN = true + } + if flags&uint16(conntrack.TCPRst) != 0 { + tcpLayer.RST = true + } + if flags&uint16(conntrack.TCPPush) != 0 { + tcpLayer.PSH = true + } + + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 42a3e0800..00cb3f1df 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -45,7 +45,7 @@ type Forwarder struct { netstack bool } -func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) { +func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{ @@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow HandleLocal: false, }) - mtu, err := iface.GetDevice().MTU() - if err != nil { - return nil, fmt.Errorf("get MTU: %w", err) - } nicID := tcpip.NICID(1) endpoint := &endpoint{ logger: logger, @@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow } if err := s.CreateNIC(nicID, endpoint); err != nil { - return nil, fmt.Errorf("failed to create NIC: %v", err) + return nil, fmt.Errorf("create NIC: %v", err) } protoAddr := tcpip.ProtocolAddress{ diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index d146de5e4..55743d975 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -49,7 +49,7 @@ type idleConn struct { conn *udpPacketConn } -func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { +func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { ctx, cancel := context.WithCancel(context.Background()) f := &udpForwarder{ logger: logger, diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index d726474cf..d2599e577 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) { func BenchmarkDNATConcurrency(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) { b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) { func BenchmarkDNATMemoryAllocations(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -472,7 +473,7 @@ func BenchmarkPortDNAT(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 2a285484c..400d61020 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -16,7 +17,7 @@ import ( func TestDNATTranslationCorrectness(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -100,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder { func TestDNATMappingManagement(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -148,7 +149,7 @@ func TestDNATMappingManagement(t *testing.T) { func TestInboundPortDNAT(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -198,7 +199,7 @@ func TestInboundPortDNAT(t *testing.T) { func TestInboundPortDNATNegative(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index ee1bb8a23..d9f9f1aa8 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -10,6 +10,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) if !statefulMode { diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index daf4979ce..638245bf7 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/firewall" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/netflow" @@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) @@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) @@ -321,7 +322,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 11575d500..451b83f92 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { return nil, err } - pf, err := uspfilter.Create(wgIface, false, flowLogger) + pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU) if err != nil { t.Fatalf("failed to create uspfilter: %v", err) return nil, err diff --git a/client/internal/engine.go b/client/internal/engine.go index ad69bcf43..0c7bd9f0a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -506,7 +506,7 @@ func (e *Engine) createFirewall() error { } var err error - e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) if err != nil || e.firewall == nil { log.Errorf("failed creating firewall manager: %s", err) return nil From 641eb5140bf51a5b6f2ca1b3a391662bffb86e32 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 4 Nov 2025 21:56:53 +0100 Subject: [PATCH 056/120] [client] Allow INPUT traffic on the compat iptables filter table for nftables (#4742) --- client/firewall/manager/firewall.go | 3 + client/firewall/nftables/acl_linux.go | 41 +++---- client/firewall/nftables/manager_linux.go | 141 ++++------------------ client/firewall/nftables/router_linux.go | 94 +++++++++++---- client/internal/dnsfwd/manager.go | 42 ++++++- client/internal/engine.go | 35 +----- 6 files changed, 146 insertions(+), 210 deletions(-) diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 7ee33118b..72e6a5c68 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -100,6 +100,9 @@ type Manager interface { // // If comment argument is empty firewall manager should set // rule ID as comment for the rule + // + // Note: Callers should call Flush() after adding rules to ensure + // they are applied to the kernel and rule handles are refreshed. AddPeerFiltering( id []byte, ip net.IP, diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 9ff5b8c92..a9d066e2f 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -29,8 +29,6 @@ const ( chainNameForwardFilter = "netbird-acl-forward-filter" chainNameManglePrerouting = "netbird-mangle-prerouting" chainNameManglePostrouting = "netbird-mangle-postrouting" - - allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) const flushError = "flush: %w" @@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { // createDefaultAllowRules creates default allow rules for the input and output chains func (m *AclManager) createDefaultAllowRules() error { expIn := []expr.Any{ - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - // mask - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: []byte{0, 0, 0, 0}, - Xor: []byte{0, 0, 0, 0}, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: []byte{0, 0, 0, 0}, - }, &expr.Verdict{ Kind: expr.VerdictAccept, }, @@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering( action firewall.Action, ipset *nftables.Set, ) (*Rule, error) { - ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) + ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset) if r, ok := m.rules[ruleId]; ok { return &Rule{ nftRule: r.nftRule, @@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering( } if err := m.rConn.Flush(); err != nil { - return nil, fmt.Errorf(flushError, err) + return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err) } ruleStruct := &Rule{ - nftRule: nftRule, + nftRule: nftRule, + // best effort mangle rule mangleRule: m.createPreroutingRule(expressions, userData), nftSet: ipset, ruleID: ruleId, @@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt }, ) - return m.rConn.AddRule(&nftables.Rule{ + nfRule := m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, Chain: m.chainPrerouting, Exprs: preroutingExprs, UserData: userData, }) + + if err := m.rConn.Flush(); err != nil { + log.Errorf("failed to flush mangle rule %s: %v", string(userData), err) + return nil + } + + return nfRule } func (m *AclManager) createDefaultChains() (err error) { @@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro return nil } -func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { - rulesetID := ":" +func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { + rulesetID := ":" + string(proto) + ":" if sPort != nil { rulesetID += sPort.String() } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index d864914fe..bd19f1067 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -1,11 +1,11 @@ package nftables import ( - "bytes" "context" "fmt" "net" "net/netip" + "os" "sync" "github.com/google/nftables" @@ -19,13 +19,22 @@ import ( ) const ( - // tableNameNetbird is the name of the table that is used for filtering by the Netbird client + // tableNameNetbird is the default name of the table that is used for filtering by the Netbird client tableNameNetbird = "netbird" + // envTableName is the environment variable to override the table name + envTableName = "NB_NFTABLES_TABLE" tableNameFilter = "filter" chainNameInput = "INPUT" ) +func getTableName() string { + if name := os.Getenv(envTableName); name != "" { + return name + } + return tableNameNetbird +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMapper interface { Name() string @@ -50,7 +59,7 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { wgIface: wgIface, } - workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} + workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} var err error m.router, err = newRouter(workTable, wgIface, mtu) @@ -198,44 +207,11 @@ func (m *Manager) AllowNetbird() error { m.mutex.Lock() defer m.mutex.Unlock() - err := m.aclManager.createDefaultAllowRules() - if err != nil { - return fmt.Errorf("failed to create default allow rules: %v", err) + if err := m.aclManager.createDefaultAllowRules(); err != nil { + return fmt.Errorf("create default allow rules: %w", err) } - - chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return fmt.Errorf("list of chains: %w", err) - } - - var chain *nftables.Chain - for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameInput { - chain = c - break - } - } - - if chain == nil { - log.Debugf("chain INPUT not found. Skipping add allow netbird rule") - return nil - } - - rules, err := m.rConn.GetRules(chain.Table, chain) - if err != nil { - return fmt.Errorf("failed to get rules for the INPUT chain: %v", err) - } - - if rule := m.detectAllowNetbirdRule(rules); rule != nil { - log.Debugf("allow netbird rule already exists: %v", rule) - return nil - } - - m.applyAllowNetbirdRules(chain) - - err = m.rConn.Flush() - if err != nil { - return fmt.Errorf("failed to flush allow input netbird rules: %v", err) + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf("flush allow input netbird rules: %w", err) } return nil @@ -251,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - if err := m.resetNetbirdInputRules(); err != nil { - return fmt.Errorf("reset netbird input rules: %v", err) - } - if err := m.router.Reset(); err != nil { return fmt.Errorf("reset router: %v", err) } @@ -274,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { return nil } -func (m *Manager) resetNetbirdInputRules() error { - chains, err := m.rConn.ListChains() - if err != nil { - return fmt.Errorf("list chains: %w", err) - } - - m.deleteNetbirdInputRules(chains) - - return nil -} - -func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { - for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameInput { - rules, err := m.rConn.GetRules(c.Table, c) - if err != nil { - log.Errorf("get rules for chain %q: %v", c.Name, err) - continue - } - - m.deleteMatchingRules(rules) - } - } -} - -func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) { - for _, r := range rules { - if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { - if err := m.rConn.DelRule(r); err != nil { - log.Errorf("delete rule: %v", err) - } - } - } -} - func (m *Manager) cleanupNetbirdTables() error { tables, err := m.rConn.ListTables() if err != nil { return fmt.Errorf("list tables: %w", err) } + tableName := getTableName() for _, t := range tables { - if t.Name == tableNameNetbird { + if t.Name == tableName { m.rConn.DelTable(t) } } @@ -399,55 +337,18 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { return nil, fmt.Errorf("list of tables: %w", err) } + tableName := getTableName() for _, t := range tables { - if t.Name == tableNameNetbird { + if t.Name == tableName { m.rConn.DelTable(t) } } - table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) err = m.rConn.Flush() return table, err } -func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { - rule := &nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - }, - UserData: []byte(allowNetbirdInputRuleID), - } - _ = m.rConn.InsertRule(rule) -} - -func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { - ifName := ifname(m.wgIface.Name()) - for _, rule := range existedRules { - if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { - if len(rule.Exprs) < 4 { - if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { - continue - } - if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) { - continue - } - return rule - } - } - } - return nil -} - func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) { rule := &nftables.Rule{ Table: table, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 0a2c79186..6192c92aa 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -37,6 +37,7 @@ const ( userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" + userDataAcceptInputRule = "inputaccept" dnatSuffix = "_dnat" snatSuffix = "_snat" @@ -103,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou func (r *router) init(workTable *nftables.Table) error { r.workTable = workTable - if err := r.removeAcceptForwardRules(); err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + if err := r.removeAcceptFilterRules(); err != nil { + log.Errorf("failed to clean up rules from filter table: %s", err) } if err := r.createContainers(); err != nil { @@ -118,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error { return nil } -// Reset cleans existing nftables default forward rules from the system +// Reset cleans existing nftables filter table rules from the system func (r *router) Reset() error { // clear without deleting the ipsets, the nf table will be deleted by the caller r.ipsetCounter.Clear() var merr *multierror.Error - if err := r.removeAcceptForwardRules(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err)) + if err := r.removeAcceptFilterRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) } if err := r.removeNatPreroutingRules(); err != nil { @@ -936,6 +937,7 @@ func (r *router) RemoveAllLegacyRouteRules() error { // that our traffic is not dropped by existing rules there. // The existing FORWARD rules/policies decide outbound traffic towards our interface. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. +// This method also adds INPUT chain rules to allow traffic to the local interface. func (r *router) acceptForwardRules() error { if r.filterTable == nil { log.Debugf("table 'filter' not found for forward rules, skipping accept rules") @@ -945,7 +947,7 @@ func (r *router) acceptForwardRules() error { fw := "iptables" defer func() { - log.Debugf("Used %s to add accept forward rules", fw) + log.Debugf("Used %s to add accept forward and input rules", fw) }() // Try iptables first and fallback to nftables if iptables is not available @@ -955,22 +957,30 @@ func (r *router) acceptForwardRules() error { log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" - return r.acceptForwardRulesNftables() + return r.acceptFilterRulesNftables() } - return r.acceptForwardRulesIptables(ipt) + return r.acceptFilterRulesIptables(ipt) } -func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error { +func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err)) + merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err)) } else { - log.Debugf("added iptables rule: %v", rule) + log.Debugf("added iptables forward rule: %v", rule) } } + inputRule := r.getAcceptInputRule() + if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err)) + } else { + log.Debugf("added iptables input rule: %v", inputRule) + } + return nberrors.FormatErrorOrNil(merr) } @@ -982,10 +992,13 @@ func (r *router) getAcceptForwardRules() [][]string { } } -func (r *router) acceptForwardRulesNftables() error { +func (r *router) getAcceptInputRule() []string { + return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"} +} + +func (r *router) acceptFilterRulesNftables() error { intf := ifname(r.wgIface.Name()) - // Rule for incoming interface (iif) with counter iifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ @@ -1018,11 +1031,10 @@ func (r *router) acceptForwardRulesNftables() error { }, } - // Rule for outgoing interface (oif) with counter oifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ - Name: "FORWARD", + Name: chainNameForward, Table: r.filterTable, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookForward, @@ -1031,35 +1043,60 @@ func (r *router) acceptForwardRulesNftables() error { Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } - r.conn.InsertRule(oifRule) + inputRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: chainNameInput, + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptInputRule), + } + r.conn.InsertRule(inputRule) + return nil } -func (r *router) removeAcceptForwardRules() error { +func (r *router) removeAcceptFilterRules() error { if r.filterTable == nil { return nil } - // Try iptables first and fallback to nftables if iptables is not available ipt, err := iptables.New() if err != nil { log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) - return r.removeAcceptForwardRulesNftables() + return r.removeAcceptFilterRulesNftables() } - return r.removeAcceptForwardRulesIptables(ipt) + return r.removeAcceptFilterRulesIptables(ipt) } -func (r *router) removeAcceptForwardRulesNftables() error { +func (r *router) removeAcceptFilterRulesNftables() error { chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) if err != nil { return fmt.Errorf("list chains: %v", err) } for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + if chain.Table.Name != r.filterTable.Name { + continue + } + + if chain.Name != chainNameForward && chain.Name != chainNameInput { continue } @@ -1070,7 +1107,8 @@ func (r *router) removeAcceptForwardRulesNftables() error { for _, rule := range rules { if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) { if err := r.conn.DelRule(rule); err != nil { return fmt.Errorf("delete rule: %v", err) } @@ -1085,14 +1123,20 @@ func (r *router) removeAcceptForwardRulesNftables() error { return nil } -func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error { +func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err)) + merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err)) } } + inputRule := r.getAcceptInputRule() + if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err)) + } + return nberrors.FormatErrorOrNil(merr) } diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index b26836d17..58b88d9ef 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -15,6 +15,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" @@ -134,6 +135,8 @@ func (m *Manager) Stop(ctx context.Context) error { } } + m.unregisterNetstackServices() + if err := m.dropDNSFirewall(); err != nil { mErr = multierror.Append(mErr, err) } @@ -158,21 +161,50 @@ func (m *Manager) allowDNSFirewall() error { dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") if err != nil { - log.Errorf("failed to add allow DNS router rules, err: %v", err) - return err + return fmt.Errorf("add udp firewall rule: %w", err) } - m.fwRules = dnsRules tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "") if err != nil { - log.Errorf("failed to add allow DNS router rules, err: %v", err) - return err + return fmt.Errorf("add tcp firewall rule: %w", err) } + + if err := m.firewall.Flush(); err != nil { + return fmt.Errorf("flush: %w", err) + } + + m.fwRules = dnsRules m.tcpRules = tcpRules + m.registerNetstackServices() + return nil } +func (m *Manager) registerNetstackServices() { + if netstackNet := m.wgIface.GetNet(); netstackNet != nil { + if registrar, ok := m.firewall.(interface { + RegisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.RegisterNetstackService(nftypes.TCP, m.serverPort) + registrar.RegisterNetstackService(nftypes.UDP, m.serverPort) + log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort) + } + } +} + +func (m *Manager) unregisterNetstackServices() { + if netstackNet := m.wgIface.GetNet(); netstackNet != nil { + if registrar, ok := m.firewall.(interface { + UnregisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort) + registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort) + log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort) + } + } +} + func (m *Manager) dropDNSFirewall() error { var mErr *multierror.Error for _, rule := range m.fwRules { diff --git a/client/internal/engine.go b/client/internal/engine.go index 0c7bd9f0a..3c7d52cb3 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -298,17 +298,12 @@ func (e *Engine) Stop() error { e.ingressGatewayMgr = nil } + e.stopDNSForwarder() + if e.routeManager != nil { e.routeManager.Stop(e.stateManager) } - if e.dnsForwardMgr != nil { - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = nil - } - if e.srWatcher != nil { e.srWatcher.Close() } @@ -1873,7 +1868,6 @@ func (e *Engine) updateDNSForwarder( func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) { e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface) - e.registerDNSServices() if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) @@ -1893,34 +1887,9 @@ func (e *Engine) stopDNSForwarder() { log.Errorf("failed to stop DNS forward: %v", err) } - e.unregisterDNSServices() e.dnsForwardMgr = nil } -func (e *Engine) registerDNSServices() { - if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { - if registrar, ok := e.firewall.(interface { - RegisterNetstackService(protocol nftypes.Protocol, port uint16) - }); ok { - registrar.RegisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort) - registrar.RegisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort) - log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort) - } - } -} - -func (e *Engine) unregisterDNSServices() { - if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { - if registrar, ok := e.firewall.(interface { - UnregisterNetstackService(protocol nftypes.Protocol, port uint16) - }); ok { - registrar.UnregisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort) - registrar.UnregisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort) - log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort) - } - } -} - func (e *Engine) GetNet() (*netstack.Net, error) { e.syncMsgMux.Lock() intf := e.wgInterface From c92e6c1b5fd040248b4fd36a2e397daeeb21ba9f Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:15:37 +0100 Subject: [PATCH 057/120] [client] Block on all subsystems on shutdown (#4709) --- client/internal/connect.go | 32 +++-- client/internal/dns/server.go | 9 +- client/internal/engine.go | 119 +++++++++++++----- client/internal/netflow/manager.go | 17 ++- client/internal/peer/guard/sr_watcher.go | 9 +- client/internal/routemanager/manager.go | 14 ++- .../internal/routeselector/routeselector.go | 4 - 7 files changed, 139 insertions(+), 65 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index c9331baf5..bb7c2b38b 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -25,6 +25,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + nbnet "github.com/netbirdio/netbird/client/net" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -34,7 +35,6 @@ import ( relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/util" - nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/version" ) @@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } <-engineCtx.Done() + c.engineMutex.Lock() - if c.engine != nil && c.engine.wgInterface != nil { - log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name()) - if err := c.engine.Stop(); err != nil { + engine := c.engine + c.engine = nil + c.engineMutex.Unlock() + + if engine != nil && engine.wgInterface != nil { + log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name()) + if err := engine.Stop(); err != nil { log.Errorf("Failed to stop engine: %v", err) } - c.engine = nil } - c.engineMutex.Unlock() c.statusRecorder.ClientTeardown() backOff.Reset() @@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType { } func (c *ConnectClient) Stop() error { - if c == nil { - return nil + engine := c.Engine() + if engine != nil { + if err := engine.Stop(); err != nil { + return fmt.Errorf("stop engine: %w", err) + } } - c.engineMutex.Lock() - defer c.engineMutex.Unlock() - - if c.engine == nil { - return nil - } - if err := c.engine.Stop(); err != nil { - return fmt.Errorf("stop engine: %w", err) - } - return nil } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 8cb886203..afaf0579f 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface { // DefaultServer dns server object type DefaultServer struct { - ctx context.Context - ctxCancel context.CancelFunc + ctx context.Context + ctxCancel context.CancelFunc + shutdownWg sync.WaitGroup // disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running. // This is different from ServiceEnable=false from management which completely disables the DNS service. disableSys bool @@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr { // Stop stops the server func (s *DefaultServer) Stop() { s.ctxCancel() + s.shutdownWg.Wait() s.mux.Lock() defer s.mux.Unlock() @@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.applyHostConfig() + s.shutdownWg.Add(1) go func() { - // persist dns state right away + defer s.shutdownWg.Done() if err := s.stateManager.PersistState(s.ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 3c7d52cb3..ebc05c453 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -148,6 +148,8 @@ type Engine struct { // syncMsgMux is used to guarantee sequential Management Service message processing syncMsgMux *sync.Mutex + // sshMux protects sshServer field access + sshMux sync.Mutex config *EngineConfig mobileDep MobileDependency @@ -200,8 +202,10 @@ type Engine struct { flowManager nftypes.FlowManager // WireGuard interface monitor - wgIfaceMonitor *WGIfaceMonitor - wgIfaceMonitorWg sync.WaitGroup + wgIfaceMonitor *WGIfaceMonitor + + // shutdownWg tracks all long-running goroutines to ensure clean shutdown + shutdownWg sync.WaitGroup probeStunTurn *relay.StunTurnProbe } @@ -320,10 +324,6 @@ func (e *Engine) Stop() error { e.cancel() } - // very ugly but we want to remove peers from the WireGuard interface first before removing interface. - // Removing peers happens in the conn.Close() asynchronously - time.Sleep(500 * time.Millisecond) - e.close() // stop flow manager after wg interface is gone @@ -331,8 +331,6 @@ func (e *Engine) Stop() error { e.flowManager.Close() } - log.Infof("stopped Netbird Engine") - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -343,12 +341,52 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } - // Stop WireGuard interface monitor and wait for it to exit - e.wgIfaceMonitorWg.Wait() + timeout := e.calculateShutdownTimeout() + log.Debugf("waiting for goroutines to finish with timeout: %v", timeout) + shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil { + log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout) + } + + log.Infof("stopped Netbird Engine") return nil } +// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s. +func (e *Engine) calculateShutdownTimeout() time.Duration { + peerCount := len(e.peerStore.PeersPubKey()) + + baseTimeout := 10 * time.Second + perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond + timeout := baseTimeout + perPeerTimeout + + maxTimeout := 30 * time.Second + if timeout > maxTimeout { + timeout = maxTimeout + } + + return timeout +} + +// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout. +func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Connections to remote peers are not established here. // However, they will be established once an event with a list of peers to connect to will be received from Management Service @@ -478,14 +516,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // monitor WireGuard interface lifecycle and restart engine on changes e.wgIfaceMonitor = NewWGIfaceMonitor() - e.wgIfaceMonitorWg.Add(1) + e.shutdownWg.Add(1) go func() { - defer e.wgIfaceMonitorWg.Done() + defer e.shutdownWg.Done() if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { log.Infof("WireGuard interface monitor: %s, restarting engine", err) - e.restartEngine() + e.triggerClientRestart() } else if err != nil { log.Warnf("WireGuard interface monitor: %s", err) } @@ -669,9 +707,11 @@ func (e *Engine) removeAllPeers() error { func (e *Engine) removePeer(peerKey string) error { log.Debugf("removing peer from engine %s", peerKey) + e.sshMux.Lock() if !isNil(e.sshServer) { e.sshServer.RemoveAuthorizedKey(peerKey) } + e.sshMux.Unlock() e.connMgr.RemovePeerConn(peerKey) @@ -873,6 +913,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { log.Warnf("running SSH server on %s is not supported", runtime.GOOS) return nil } + e.sshMux.Lock() // start SSH server if it wasn't running if isNil(e.sshServer) { listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort) @@ -880,34 +921,42 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort) } // nil sshServer means it has not yet been started - var err error - e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr) - + server, err := e.sshServerFunc(e.config.SSHKey, listenAddr) if err != nil { + e.sshMux.Unlock() return fmt.Errorf("create ssh server: %w", err) } + + e.sshServer = server + e.sshMux.Unlock() + go func() { // blocking - err = e.sshServer.Start() + err = server.Start() if err != nil { // will throw error when we stop it even if it is a graceful stop log.Debugf("stopped SSH server with error %v", err) } - e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() + e.sshMux.Lock() e.sshServer = nil + e.sshMux.Unlock() log.Infof("stopped SSH server") }() } else { + e.sshMux.Unlock() log.Debugf("SSH server is already running") } - } else if !isNil(e.sshServer) { - // Disable SSH server request, so stop it if it was running - err := e.sshServer.Stop() - if err != nil { - log.Warnf("failed to stop SSH server %v", err) + } else { + e.sshMux.Lock() + if !isNil(e.sshServer) { + // Disable SSH server request, so stop it if it was running + err := e.sshServer.Stop() + if err != nil { + log.Warnf("failed to stop SSH server %v", err) + } + e.sshServer = nil } - e.sshServer = nil + e.sshMux.Unlock() } return nil } @@ -944,7 +993,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { // receiveManagementEvents connects to the Management Service event stream to receive updates from the management service // E.g. when a new peer has been registered and we are allowed to connect to it. func (e *Engine) receiveManagementEvents() { + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() info, err := system.GetInfoWithChecks(e.ctx, e.checks) if err != nil { log.Warnf("failed to get system info with checks: %v", err) @@ -1120,6 +1171,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.statusRecorder.FinishPeerListModifications() // update SSHServer by adding remote peer SSH keys + e.sshMux.Lock() if !isNil(e.sshServer) { for _, config := range networkMap.GetRemotePeers() { if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil { @@ -1130,6 +1182,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } } + e.sshMux.Unlock() } // must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store @@ -1372,7 +1425,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV // receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers func (e *Engine) receiveSignalEvents() { + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() // connect to a stream of messages coming from the signal server err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error { e.syncMsgMux.Lock() @@ -1489,12 +1544,14 @@ func (e *Engine) close() { e.statusRecorder.SetWgIface(nil) } + e.sshMux.Lock() if !isNil(e.sshServer) { err := e.sshServer.Stop() if err != nil { log.Warnf("failed stopping the SSH server: %v", err) } } + e.sshMux.Unlock() if e.firewall != nil { err := e.firewall.Close(e.stateManager) @@ -1725,8 +1782,10 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool { return allHealthy } -// restartEngine restarts the engine by cancelling the client context -func (e *Engine) restartEngine() { +// triggerClientRestart triggers a full client restart by cancelling the client context. +// Note: This does NOT just restart the engine - it cancels the entire client context, +// which causes the connect client's retry loop to create a completely new engine. +func (e *Engine) triggerClientRestart() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -1748,7 +1807,9 @@ func (e *Engine) startNetworkMonitor() { } e.networkMonitor = networkmonitor.New() + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() if err := e.networkMonitor.Listen(e.ctx); err != nil { if errors.Is(err, context.Canceled) { log.Infof("network monitor stopped") @@ -1758,8 +1819,8 @@ func (e *Engine) startNetworkMonitor() { return } - log.Infof("Network monitor: detected network change, restarting engine") - e.restartEngine() + log.Infof("Network monitor: detected network change, triggering client restart") + e.triggerClientRestart() }() } diff --git a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go index e3b188468..7752c97b0 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -24,6 +24,7 @@ import ( // Manager handles netflow tracking and logging type Manager struct { mux sync.Mutex + shutdownWg sync.WaitGroup logger nftypes.FlowLogger flowConfig *nftypes.FlowConfig conntrack nftypes.ConnTracker @@ -105,8 +106,15 @@ func (m *Manager) resetClient() error { ctx, cancel := context.WithCancel(context.Background()) m.cancel = cancel - go m.receiveACKs(ctx, flowClient) - go m.startSender(ctx) + m.shutdownWg.Add(2) + go func() { + defer m.shutdownWg.Done() + m.receiveACKs(ctx, flowClient) + }() + go func() { + defer m.shutdownWg.Done() + m.startSender(ctx) + }() return nil } @@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error { // Close cleans up all resources func (m *Manager) Close() { m.mux.Lock() - defer m.mux.Unlock() - if err := m.disableFlow(); err != nil { log.Warnf("failed to disable flow manager: %v", err) } + m.mux.Unlock() + + m.shutdownWg.Wait() } // GetLogger returns the flow logger diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 686430752..6f4f5ad4f 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -19,11 +19,10 @@ type SRWatcher struct { signalClient chNotifier relayManager chNotifier - listeners map[chan struct{}]struct{} - mu sync.Mutex - iFaceDiscover stdnet.ExternalIFaceDiscover - iceConfig ice.Config - + listeners map[chan struct{}]struct{} + mu sync.Mutex + iFaceDiscover stdnet.ExternalIFaceDiscover + iceConfig ice.Config cancelIceMonitor context.CancelFunc } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 37974cd17..26cf758d9 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -81,6 +81,7 @@ type DefaultManager struct { ctx context.Context stop context.CancelFunc mux sync.Mutex + shutdownWg sync.WaitGroup clientNetworks map[route.HAUniqueID]*client.Watcher routeSelector *routeselector.RouteSelector serverRouter *server.Router @@ -283,6 +284,7 @@ func (m *DefaultManager) SetDNSForwarderPort(port uint16) { // Stop stops the manager watchers and clean firewall rules func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() + m.shutdownWg.Wait() if m.serverRouter != nil { m.serverRouter.CleanUp() } @@ -485,7 +487,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { } clientNetworkWatcher := client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.Start() + m.shutdownWg.Add(1) + go func() { + defer m.shutdownWg.Done() + clientNetworkWatcher.Start() + }() clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes}) } @@ -527,7 +533,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout } clientNetworkWatcher = client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.Start() + m.shutdownWg.Add(1) + go func() { + defer m.shutdownWg.Done() + clientNetworkWatcher.Start() + }() } update := client.RoutesUpdate{ UpdateSerial: updateSerial, diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index e4a78599e..61c8bbc79 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -9,8 +9,6 @@ import ( "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/route" ) @@ -128,13 +126,11 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { defer rs.mu.RUnlock() if rs.deselectAll { - log.Debugf("Route %s not selected (deselect all)", routeID) return false } _, deselected := rs.deselectedRoutes[routeID] isSelected := !deselected - log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected) return isSelected } From 75327d951941f60229565a38ece91d4ceae97303 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 5 Nov 2025 17:00:20 +0100 Subject: [PATCH 058/120] [client] Add login_hint to oidc flows (#4724) --- client/android/login.go | 2 +- client/cmd/login.go | 26 +++++++++++++++++++++----- client/cmd/up.go | 9 ++++++++- client/internal/auth/device_flow.go | 25 +++++++++++++++++++++++++ client/internal/auth/oauth.go | 18 +++++++++++------- client/internal/auth/pkce_flow.go | 3 +++ client/internal/device_auth.go | 2 ++ client/internal/pkce_auth.go | 2 ++ client/ios/NetBirdSDK/client.go | 2 +- client/proto/daemon.pb.go | 21 ++++++++++++++++----- client/proto/daemon.proto | 3 +++ client/server/server.go | 6 +++++- client/ui/client_ui.go | 13 +++++++++++-- 13 files changed, 109 insertions(+), 23 deletions(-) diff --git a/client/android/login.go b/client/android/login.go index 0df78dbc3..16df24ba8 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error { } func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false) + oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "") if err != nil { return nil, err } diff --git a/client/cmd/login.go b/client/cmd/login.go index 40b55f858..b0c877faa 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -106,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str Username: &username, } + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginRequest.Hint = &profileState.Email + } + if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { loginRequest.OptionalPreSharedKey = &preSharedKey } @@ -241,7 +248,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, return fmt.Errorf("read config file %s: %v", configFilePath, err) } - err = foregroundLogin(ctx, cmd, config, setupKey) + err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name) if err != nil { return fmt.Errorf("foreground login failed: %v", err) } @@ -269,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo return nil } -func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error { +func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error { needsLogin := false err := WithBackOff(func() error { @@ -286,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman jwtToken := "" if setupKey == "" && needsLogin { - tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config) + tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName) if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } @@ -315,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman return nil } -func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop()) +func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) { + hint := "" + pm := profilemanager.NewProfileManager() + profileState, err := pm.GetProfileState(profileName) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + hint = profileState.Email + } + + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint) if err != nil { return nil, err } diff --git a/client/cmd/up.go b/client/cmd/up.go index d047c041e..80175f7be 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr _, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath) - err = foregroundLogin(ctx, cmd, config, providedSetupKey) + err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name) if err != nil { return fmt.Errorf("foreground login failed: %v", err) } @@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ loginRequest.ProfileName = &activeProf.Name loginRequest.Username = &username + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginRequest.Hint = &profileState.Email + } + var loginErr error var loginResp *proto.LoginResponse diff --git a/client/internal/auth/device_flow.go b/client/internal/auth/device_flow.go index da4f16c8d..8ca760742 100644 --- a/client/internal/auth/device_flow.go +++ b/client/internal/auth/device_flow.go @@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow deviceCode.VerificationURIComplete = deviceCode.VerificationURI } + if d.providerConfig.LoginHint != "" { + deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint) + if deviceCode.VerificationURI != "" { + deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint) + } + } + return deviceCode, err } +func appendLoginHint(uri, loginHint string) string { + if uri == "" || loginHint == "" { + return uri + } + + parsedURL, err := url.Parse(uri) + if err != nil { + log.Debugf("failed to parse verification URI for login_hint: %v", err) + return uri + } + + query := parsedURL.Query() + query.Set("login_hint", loginHint) + parsedURL.RawQuery = query.Encode() + + return parsedURL.String() +} + func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) { form := url.Values{} form.Add("client_id", d.providerConfig.ClientID) diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 4458f600c..9fbd6cf5f 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -66,32 +66,34 @@ func (t TokenInfo) GetTokenToUse() string { // and if that also fails, the authentication process is deemed unsuccessful // // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow -func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) { +func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) { if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { - return authenticateWithDeviceCodeFlow(ctx, config) + return authenticateWithDeviceCodeFlow(ctx, config, hint) } - pkceFlow, err := authenticateWithPKCEFlow(ctx, config) + pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint) if err != nil { - // fallback to device code flow log.Debugf("failed to initialize pkce authentication with error: %v\n", err) log.Debug("falling back to device code flow") - return authenticateWithDeviceCodeFlow(ctx, config) + return authenticateWithDeviceCodeFlow(ctx, config, hint) } return pkceFlow, nil } // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow -func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { +func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) if err != nil { return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) } + + pkceFlowInfo.ProviderConfig.LoginHint = hint + return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) } // authenticateWithDeviceCodeFlow initializes the Device Code auth Flow -func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { +func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) if err != nil { switch s, ok := gstatus.FromError(err); { @@ -107,5 +109,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager. } } + deviceFlowInfo.ProviderConfig.LoginHint = hint + return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) } diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 8741e8636..738d3e34f 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -109,6 +109,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn params = append(params, oauth2.SetAuthURLParam("max_age", "0")) } } + if p.providerConfig.LoginHint != "" { + params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint)) + } authURL := p.oAuthConfig.AuthCodeURL(state, params...) diff --git a/client/internal/device_auth.go b/client/internal/device_auth.go index 6bd29801d..7f7d06130 100644 --- a/client/internal/device_auth.go +++ b/client/internal/device_auth.go @@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct { Scope string // UseIDToken indicates if the id token should be used for authentication UseIDToken bool + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string } // GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it diff --git a/client/internal/pkce_auth.go b/client/internal/pkce_auth.go index a713bb342..23c92e8af 100644 --- a/client/internal/pkce_auth.go +++ b/client/internal/pkce_auth.go @@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct { DisablePromptLogin bool // LoginFlag is used to configure the PKCE flow login behavior LoginFlag common.LoginFlag + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string } // GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 2109d4b15..fa1c89aab 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string { ConfigPath: c.cfgFile, }) - oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false) + oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "") if err != nil { return err.Error() } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 841e3c0f7..02f09b08a 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -279,8 +279,10 @@ type LoginRequest struct { ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // hint is used to pre-fill the email/username field during SSO authentication + Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *LoginRequest) Reset() { @@ -538,6 +540,13 @@ func (x *LoginRequest) GetMtu() int64 { return 0 } +func (x *LoginRequest) GetHint() string { + if x != nil && x.Hint != nil { + return *x.Hint + } + return "" +} + type LoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` @@ -4608,7 +4617,7 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xc3\x0e\n" + + "\fEmptyRequest\"\xe5\x0e\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -4645,7 +4654,8 @@ const file_daemon_proto_rawDesc = "" + "\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" + - "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" + + "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" + + "\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4665,7 +4675,8 @@ const file_daemon_proto_rawDesc = "" + "\x0e_block_inboundB\x0e\n" + "\f_profileNameB\v\n" + "\t_usernameB\x06\n" + - "\x04_mtu\"\xb5\x01\n" + + "\x04_mtuB\a\n" + + "\x05_hint\"\xb5\x01\n" + "\rLoginResponse\x12$\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 5b27b4d98..8d1080051 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -158,6 +158,9 @@ message LoginRequest { optional string username = 31; optional int64 mtu = 32; + + // hint is used to pre-fill the email/username field during SSO authentication + optional string hint = 33; } message LoginResponse { diff --git a/client/server/server.go b/client/server/server.go index 3641e6f92..6699cdadc 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -483,7 +483,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro state.Set(internal.StatusConnecting) if msg.SetupKey == "" { - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient) + hint := "" + if msg.Hint != nil { + hint = *msg.Hint + } + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint) if err != nil { state.Set(internal.StatusLoginFailed) return nil, err diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index f4350b251..e580be56d 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -610,11 +610,20 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe return nil, fmt.Errorf("get current user: %w", err) } - loginResp, err := conn.Login(ctx, &proto.LoginRequest{ + loginReq := &proto.LoginRequest{ IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", ProfileName: &activeProf.Name, Username: &currUser.Username, - }) + } + + profileState, err := s.profileManager.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginReq.Hint = &profileState.Email + } + + loginResp, err := conn.Login(ctx, loginReq) if err != nil { return nil, fmt.Errorf("login to management: %w", err) } From 229e0038ee643ec775d0670c030db70670ae298c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 5 Nov 2025 17:30:17 +0100 Subject: [PATCH 059/120] [client] Add dns config to debug bundle (#4704) --- client/internal/debug/debug.go | 20 +++++++++ client/internal/debug/debug_darwin.go | 53 ++++++++++++++++++++++++ client/internal/debug/debug_mobile.go | 4 ++ client/internal/debug/debug_nondarwin.go | 16 +++++++ client/internal/debug/debug_nonunix.go | 7 ++++ client/internal/debug/debug_unix.go | 29 +++++++++++++ 6 files changed, 129 insertions(+) create mode 100644 client/internal/debug/debug_darwin.go create mode 100644 client/internal/debug/debug_nondarwin.go create mode 100644 client/internal/debug/debug_nonunix.go create mode 100644 client/internal/debug/debug_unix.go diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 442f54e71..fbec29ce3 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -44,6 +44,8 @@ interfaces.txt: Anonymized network interface information, if --system-info flag ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided. iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. +resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided. +scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided. resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. @@ -184,6 +186,20 @@ The ip_rules.txt file contains detailed IP routing rule information: The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing. For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged. + +DNS Configuration +The debug bundle includes platform-specific DNS configuration files: + +resolv.conf (Unix systems): +- Contains DNS resolver configuration from /etc/resolv.conf +- Includes nameserver entries, search domains, and resolver options +- All IP addresses and domain names are anonymized following the same rules as other files + +scutil_dns.txt (macOS only): +- Contains detailed DNS configuration from scutil --dns +- Shows DNS configuration for all network interfaces +- Includes search domains, nameservers, and DNS resolver settings +- All IP addresses and domain names are anonymized ` const ( @@ -357,6 +373,10 @@ func (g *BundleGenerator) addSystemInfo() { if err := g.addFirewallRules(); err != nil { log.Errorf("failed to add firewall rules to debug bundle: %v", err) } + + if err := g.addDNSInfo(); err != nil { + log.Errorf("failed to add DNS info to debug bundle: %v", err) + } } func (g *BundleGenerator) addReadme() error { diff --git a/client/internal/debug/debug_darwin.go b/client/internal/debug/debug_darwin.go new file mode 100644 index 000000000..91e10214f --- /dev/null +++ b/client/internal/debug/debug_darwin.go @@ -0,0 +1,53 @@ +//go:build darwin && !ios + +package debug + +import ( + "bytes" + "context" + "fmt" + "os/exec" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +// addDNSInfo collects and adds DNS configuration information to the archive +func (g *BundleGenerator) addDNSInfo() error { + if err := g.addResolvConf(); err != nil { + log.Errorf("failed to add resolv.conf: %v", err) + } + + if err := g.addScutilDNS(); err != nil { + log.Errorf("failed to add scutil DNS output: %v", err) + } + + return nil +} + +func (g *BundleGenerator) addScutilDNS() error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "scutil", "--dns") + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("execute scutil --dns: %w", err) + } + + if len(bytes.TrimSpace(output)) == 0 { + return fmt.Errorf("no scutil DNS output") + } + + content := string(output) + if g.anonymize { + content = g.anonymizer.AnonymizeString(content) + } + + if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil { + return fmt.Errorf("add scutil DNS output to zip: %w", err) + } + + return nil +} diff --git a/client/internal/debug/debug_mobile.go b/client/internal/debug/debug_mobile.go index c00c65132..3c1745ff3 100644 --- a/client/internal/debug/debug_mobile.go +++ b/client/internal/debug/debug_mobile.go @@ -5,3 +5,7 @@ package debug func (g *BundleGenerator) addRoutes() error { return nil } + +func (g *BundleGenerator) addDNSInfo() error { + return nil +} diff --git a/client/internal/debug/debug_nondarwin.go b/client/internal/debug/debug_nondarwin.go new file mode 100644 index 000000000..dfc2eace5 --- /dev/null +++ b/client/internal/debug/debug_nondarwin.go @@ -0,0 +1,16 @@ +//go:build unix && !darwin && !android + +package debug + +import ( + log "github.com/sirupsen/logrus" +) + +// addDNSInfo collects and adds DNS configuration information to the archive +func (g *BundleGenerator) addDNSInfo() error { + if err := g.addResolvConf(); err != nil { + log.Errorf("failed to add resolv.conf: %v", err) + } + + return nil +} diff --git a/client/internal/debug/debug_nonunix.go b/client/internal/debug/debug_nonunix.go new file mode 100644 index 000000000..18d017050 --- /dev/null +++ b/client/internal/debug/debug_nonunix.go @@ -0,0 +1,7 @@ +//go:build !unix + +package debug + +func (g *BundleGenerator) addDNSInfo() error { + return nil +} diff --git a/client/internal/debug/debug_unix.go b/client/internal/debug/debug_unix.go new file mode 100644 index 000000000..7e8a74eb0 --- /dev/null +++ b/client/internal/debug/debug_unix.go @@ -0,0 +1,29 @@ +//go:build unix && !android + +package debug + +import ( + "fmt" + "os" + "strings" +) + +const resolvConfPath = "/etc/resolv.conf" + +func (g *BundleGenerator) addResolvConf() error { + data, err := os.ReadFile(resolvConfPath) + if err != nil { + return fmt.Errorf("read %s: %w", resolvConfPath, err) + } + + content := string(data) + if g.anonymize { + content = g.anonymizer.AnonymizeString(content) + } + + if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil { + return fmt.Errorf("add resolv.conf to zip: %w", err) + } + + return nil +} From 5c29d395b29d3d152117c11e3a375111863a204f Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 6 Nov 2025 12:51:14 +0100 Subject: [PATCH 060/120] [management] activity events on group updates (#4750) --- management/server/group.go | 20 +++++++++++++++----- management/server/user.go | 38 ++++++++++++++++++++++++++++++++++---- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index 487cb6d97..a29c28892 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -138,6 +138,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use return err } + newGroup.AccountID = accountID + + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID) if err != nil { return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID) @@ -157,11 +162,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } } - newGroup.AccountID = accountID - - events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) - eventsToStore = append(eventsToStore, events...) - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID}) if err != nil { return err @@ -335,6 +335,16 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac if err == nil && oldGroup != nil { addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) + + if oldGroup.Name != newGroup.Name { + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "old_name": oldGroup.Name, + "new_name": newGroup.Name, + } + am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupUpdated, meta) + }) + } } else { addedPeers = append(addedPeers, newGroup.Peers...) eventsToStore = append(eventsToStore, func() { diff --git a/management/server/user.go b/management/server/user.go index d40d33c6a..25c87df9c 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -595,7 +595,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() { var eventsToStore []func() if oldUser.IsBlocked() != newUser.IsBlocked() { @@ -621,6 +621,35 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac }) } + addedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, addedGroupIDs) + if err != nil { + log.WithContext(ctx).Errorf("failed to get added groups for user %s update event: %v", oldUser.Id, err) + } + + for _, group := range addedGroups { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName, + } + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupAddedToUser, meta) + }) + } + + removedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, removedGroupIDs) + if err != nil { + log.WithContext(ctx).Errorf("failed to get removed groups for user %s update event: %v", oldUser.Id, err) + } + for _, group := range removedGroups { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName, + } + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta) + }) + } + return eventsToStore } @@ -667,9 +696,10 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact peersToExpire = userPeers } + var removedGroups, addedGroups []string if update.AutoGroups != nil && settings.GroupsPropagationEnabled { - removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups) - addedGroups := util.Difference(update.AutoGroups, oldUser.AutoGroups) + removedGroups = util.Difference(oldUser.AutoGroups, update.AutoGroups) + addedGroups = util.Difference(update.AutoGroups, oldUser.AutoGroups) for _, peer := range userPeers { for _, groupID := range removedGroups { if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil { @@ -685,7 +715,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } updateAccountPeers := len(userPeers) > 0 - userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole) + userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, removedGroups, addedGroups, transaction) return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil } From 2e16c9914a7ba92fa92f56c66c079e4dd1e0fd34 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 6 Nov 2025 19:01:44 +0300 Subject: [PATCH 061/120] [management] Bump github.com/containerd/containerd from 1.7.27 to 1.7.29 (#4756) Bumps [github.com/containerd/containerd](https://github.com/containerd/containerd) from 1.7.27 to 1.7.29. - [Release notes](https://github.com/containerd/containerd/releases) - [Changelog](https://github.com/containerd/containerd/blob/main/RELEASES.md) - [Commits](https://github.com/containerd/containerd/compare/v1.7.27...v1.7.29) --- updated-dependencies: - dependency-name: github.com/containerd/containerd dependency-version: 1.7.29 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 98b158411..d93d2c064 100644 --- a/go.mod +++ b/go.mod @@ -102,9 +102,9 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/mod v0.25.0 + golang.org/x/mod v0.26.0 golang.org/x/net v0.42.0 - golang.org/x/oauth2 v0.28.0 + golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.16.0 golang.org/x/term v0.33.0 google.golang.org/api v0.177.0 @@ -146,7 +146,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/containerd/containerd v1.7.27 // indirect + github.com/containerd/containerd v1.7.29 // indirect github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect @@ -245,7 +245,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/text v0.27.0 // indirect - golang.org/x/time v0.5.0 // indirect + golang.org/x/time v0.12.0 // indirect golang.org/x/tools v0.34.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect diff --git a/go.sum b/go.sum index 9705f3aed..61ad8740e 100644 --- a/go.sum +++ b/go.sum @@ -142,8 +142,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= -github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII= -github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0= +github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE= +github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A= @@ -818,8 +818,8 @@ golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= -golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -880,8 +880,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= -golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -993,8 +993,8 @@ golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= From 6aa4ba7af441c9c07fd76d7946e179ea5257e975 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 7 Nov 2025 10:44:46 +0100 Subject: [PATCH 062/120] [management] incremental network map builder (#4753) --- go.mod | 2 +- management/main.go | 10 +- management/server/account.go | 36 + management/server/account/manager.go | 1 + management/server/account_test.go | 55 + management/server/dns.go | 3 + management/server/group.go | 24 + management/server/grpcserver.go | 85 +- management/server/holder.go | 39 + management/server/mock_server/account_mock.go | 14 +- management/server/nameserver.go | 9 + management/server/networkmap.go | 80 + management/server/networks/manager.go | 3 + .../server/networks/resources/manager.go | 9 + management/server/networks/routers/manager.go | 9 + management/server/peer.go | 133 +- management/server/peer_test.go | 23 +- management/server/policy.go | 6 + management/server/posture_checks.go | 3 + management/server/route.go | 9 + management/server/settings/manager.go | 8 + management/server/store/sql_store.go | 1265 ++++++++++- .../store/sql_store_get_account_test.go | 1089 ++++++++++ .../server/store/sqlstore_bench_test.go | 951 ++++++++ management/server/store/store.go | 104 +- management/server/types/account.go | 13 + management/server/types/holder.go | 43 + management/server/types/networkmap.go | 58 + .../server/types/networkmap_golden_test.go | 1069 +++++++++ management/server/types/networkmapbuilder.go | 1932 +++++++++++++++++ management/server/updatechannel.go | 6 +- management/server/user.go | 4 + route/route.go | 1 + 33 files changed, 7018 insertions(+), 78 deletions(-) create mode 100644 management/server/holder.go create mode 100644 management/server/networkmap.go create mode 100644 management/server/store/sql_store_get_account_test.go create mode 100644 management/server/store/sqlstore_bench_test.go create mode 100644 management/server/types/holder.go create mode 100644 management/server/types/networkmap.go create mode 100644 management/server/types/networkmap_golden_test.go create mode 100644 management/server/types/networkmapbuilder.go diff --git a/go.mod b/go.mod index d93d2c064..68a12908d 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 + github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 github.com/libp2p/go-netroute v0.2.1 github.com/mdlayher/socket v0.5.1 @@ -183,7 +184,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect github.com/jinzhu/inflection v1.0.0 // indirect diff --git a/management/main.go b/management/main.go index 561ed8f26..ff8482f97 100644 --- a/management/main.go +++ b/management/main.go @@ -1,11 +1,19 @@ package main import ( - "github.com/netbirdio/netbird/management/cmd" + "log" + "net/http" + // nolint:gosec + _ "net/http/pprof" "os" + + "github.com/netbirdio/netbird/management/cmd" ) func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() if err := cmd.Execute(); err != nil { os.Exit(1) } diff --git a/management/server/account.go b/management/server/account.go index dca105ddf..0aecbd586 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -53,6 +53,9 @@ const ( peerSchedulerRetryInterval = 3 * time.Second emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + + envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" + envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" ) type userLoggedInOnce bool @@ -109,6 +112,11 @@ type DefaultAccountManager struct { loginFilter *loginFilter disableDefaultPolicy bool + + holder *types.Holder + + expNewNetworkMap bool + expNewNetworkMapAIDs map[string]struct{} } func isUniqueConstraintError(err error) bool { @@ -196,6 +204,18 @@ func BuildManager( log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start)) }() + newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder)) + if err != nil { + log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err) + newNetworkMapBuilder = false + } + + ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",") + expIDs := make(map[string]struct{}, len(ids)) + for _, id := range ids { + expIDs[id] = struct{}{} + } + am := &DefaultAccountManager{ Store: store, geo: geo, @@ -217,6 +237,10 @@ func BuildManager( permissionsManager: permissionsManager, loginFilter: newLoginFilter(), disableDefaultPolicy: disableDefaultPolicy, + holder: types.NewHolder(), + + expNewNetworkMap: newNetworkMapBuilder, + expNewNetworkMapAIDs: expIDs, } am.startWarmup(ctx) @@ -395,6 +419,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } go am.UpdateAccountPeers(ctx, accountID) } @@ -1477,6 +1504,10 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } if removedGroupAffectsPeers || newGroupsAffectsPeers { + if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil { + return err + } + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) } @@ -2129,6 +2160,11 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us } if updateNetworkMap { + peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + am.updatePeerInNetworkMapCache(peer.AccountID, peer) am.BufferUpdateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/account/manager.go b/management/server/account/manager.go index fe9fb25c6..db377865a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -128,4 +128,5 @@ type Manager interface { GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) SetEphemeralManager(em ephemeral.Manager) AllowSync(string, uint64) bool + RecalculateNetworkMapCache(ctx context.Context, accountId string) error } diff --git a/management/server/account_test.go b/management/server/account_test.go index 07d2f2383..200ba6b98 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1154,7 +1154,16 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } +func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_SaveGroup(t) +} + func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { + testAccountManager_NetworkUpdates_SaveGroup(t) +} + +func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) group := types.Group{ @@ -1205,7 +1214,16 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeletePolicy(t) +} + func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { + testAccountManager_NetworkUpdates_DeletePolicy(t) +} + +func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { manager, account, peer1, _, _ := setupNetworkMapTest(t) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1239,7 +1257,16 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_SavePolicy(t) +} + func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { + testAccountManager_NetworkUpdates_SavePolicy(t) +} + +func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { manager, account, peer1, peer2, _ := setupNetworkMapTest(t) group := types.Group{ @@ -1288,7 +1315,16 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeletePeer(t) +} + func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { + testAccountManager_NetworkUpdates_DeletePeer(t) +} + +func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { manager, account, peer1, _, peer3 := setupNetworkMapTest(t) group := types.Group{ @@ -1341,7 +1377,16 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeleteGroup(t) +} + func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { + testAccountManager_NetworkUpdates_DeleteGroup(t) +} + +func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1377,6 +1422,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } + for drained := false; !drained; { + select { + case <-updMsg: + default: + drained = true + } + } + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -1736,7 +1789,9 @@ func TestAccount_Copy(t *testing.T) { Address: "172.12.6.1/24", }, }, + NetworkMapCache: &types.NetworkMapBuilder{}, } + account.InitOnce() err := hasNilField(account) if err != nil { t.Fatal(err) diff --git a/management/server/dns.go b/management/server/dns.go index e5166ce47..decc5175d 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -117,6 +117,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/group.go b/management/server/group.go index a29c28892..3cf9290a2 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -114,6 +114,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -182,6 +185,9 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -250,6 +256,9 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -318,6 +327,9 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -481,6 +493,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -519,6 +534,9 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -547,6 +565,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -585,6 +606,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 12b59b691..0a5236cb3 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -7,8 +7,10 @@ import ( "net" "net/netip" "os" + "strconv" "strings" "sync" + "sync/atomic" "time" pb "github.com/golang/protobuf/proto" // nolint @@ -44,6 +46,9 @@ import ( const ( envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS" envBlockPeers = "NB_BLOCK_SAME_PEERS" + envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS" + + defaultSyncLim = 1000 ) // GRPCServer an instance of a Management gRPC API server @@ -63,6 +68,9 @@ type GRPCServer struct { logBlockedPeers bool blockPeersWithSameConfig bool integratedPeerValidator integrated_validator.IntegratedValidator + + syncSem atomic.Int32 + syncLim int32 } // NewServer creates a new Management server @@ -96,6 +104,17 @@ func NewServer( logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true" blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true" + syncLim := int32(defaultSyncLim) + if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" { + syncLimParsed, err := strconv.Atoi(syncLimStr) + if err != nil { + log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim) + } else { + //nolint:gosec + syncLim = int32(syncLimParsed) + } + } + return &GRPCServer{ wgKey: key, // peerKey -> event channel @@ -110,6 +129,8 @@ func NewServer( logBlockedPeers: logBlockedPeers, blockPeersWithSameConfig: blockPeersWithSameConfig, integratedPeerValidator: integratedPeerValidator, + + syncLim: syncLim, }, nil } @@ -151,6 +172,11 @@ func getRealIP(ctx context.Context) net.IP { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { + if s.syncSem.Load() >= s.syncLim { + return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later") + } + s.syncSem.Add(1) + reqStart := time.Now() ctx := srv.Context() @@ -158,6 +184,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi syncReq := &proto.SyncRequest{} peerKey, err := s.parseRequest(ctx, req, syncReq) if err != nil { + s.syncSem.Add(-1) return err } realIP := getRealIP(ctx) @@ -172,6 +199,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) } if s.blockPeersWithSameConfig { + s.syncSem.Add(-1) return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn) } } @@ -183,27 +211,34 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) - unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) - defer func() { - if unlock != nil { - unlock() - } - }() - accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) if err != nil { // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN") log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String()) if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound { + s.syncSem.Add(-1) return status.Errorf(codes.PermissionDenied, "peer is not registered") } + s.syncSem.Add(-1) return err } + log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) + // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + start := time.Now() + unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) + defer func() { + if unlock != nil { + unlock() + } + }() + log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start)) + log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart)) + log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) if syncReq.GetMeta() == nil { @@ -213,21 +248,32 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) return mapError(ctx, err) } + log.WithContext(ctx).Debugf("Sync: SyncAndMarkPeer since start %v", time.Since(reqStart)) + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) return err } + log.WithContext(ctx).Debugf("Sync: sendInitialSync since start %v", time.Since(reqStart)) updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) + log.WithContext(ctx).Debugf("Sync: CreateChannel since start %v", time.Since(reqStart)) + s.ephemeralManager.OnPeerConnected(ctx, peer) + log.WithContext(ctx).Debugf("Sync: OnPeerConnected since start %v", time.Since(reqStart)) + s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) + log.WithContext(ctx).Debugf("Sync: SetupRefresh since start %v", time.Since(reqStart)) + if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) } @@ -237,6 +283,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart)) + s.syncSem.Add(-1) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } @@ -509,10 +557,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) + defer func() { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) } + took := time.Since(reqStart) + if took > 7*time.Second { + log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart)) + } }() if loginReq.GetMeta() == nil { @@ -546,9 +600,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, mapError(ctx, err) } + log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart)) + // if the login request contains setup key then it is a registration request if loginReq.GetSetupKey() != "" { s.ephemeralManager.OnPeerDisconnected(ctx, peer) + log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart)) } loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) @@ -557,6 +614,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, status.Errorf(codes.Internal, "failed logging in peer") } + log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart)) + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) if err != nil { log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) @@ -716,6 +775,11 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("toSyncResponse: took %s", time.Since(start)) + }() + response := &proto.SyncResponse{ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), NetworkMap: &proto.NetworkMap{ @@ -780,6 +844,11 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("sendInitialSync: took %s", time.Since(start)) + }() + var err error var turnToken *Token @@ -822,10 +891,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "error handling request") } + sendStart := time.Now() err = srv.Send(&proto.EncryptedMessage{ WgPubKey: s.wgKey.PublicKey().String(), Body: encryptedResp, }) + log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart)) if err != nil { log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err) diff --git a/management/server/holder.go b/management/server/holder.go new file mode 100644 index 000000000..e8a26e1d0 --- /dev/null +++ b/management/server/holder.go @@ -0,0 +1,39 @@ +package server + +import ( + "github.com/netbirdio/netbird/management/server/types" +) + +func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) { + a := am.holder.GetAccount(account.Id) + if a == nil { + am.holder.AddAccount(account) + return + } + account.NetworkMapCache = a.NetworkMapCache + if account.NetworkMapCache == nil { + return + } + account.NetworkMapCache.UpdateAccountPointer(account) + am.holder.AddAccount(account) +} + +func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account { + return am.holder.GetAccount(accountID) +} + +func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account { + a := am.holder.GetAccount(accountID) + if a != nil { + return a + } + account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure) + if err != nil { + return nil + } + return account +} + +func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) { + am.holder.AddAccount(account) +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index e87043f26..8baffa58b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -125,9 +125,10 @@ type MockAccountManager struct { UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) - AllowSyncFunc func(string, uint64) bool - UpdateAccountPeersFunc func(ctx context.Context, accountID string) - BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + AllowSyncFunc func(string, uint64) bool + UpdateAccountPeersFunc func(ctx context.Context, accountID string) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error } func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { @@ -986,3 +987,10 @@ func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { } return true } + +func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error { + if am.RecalculateNetworkMapCacheFunc != nil { + return am.RecalculateNetworkMapCacheFunc(ctx, accountID) + } + return nil +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index f278e1761..ee77a65bb 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -83,6 +83,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } @@ -134,6 +137,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -177,6 +183,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/networkmap.go b/management/server/networkmap.go new file mode 100644 index 000000000..2a0627643 --- /dev/null +++ b/management/server/networkmap.go @@ -0,0 +1,80 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" +) + +func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { + am.enrichAccountFromHolder(account) + account.InitNetworkMapBuilderIfNeeded(validatedPeers) +} + +func (am *DefaultAccountManager) getPeerNetworkMapExp( + ctx context.Context, + accountId string, + peerId string, + validatedPeers map[string]struct{}, + customZone nbdns.CustomZone, + metrics *telemetry.AccountManagerMetrics, +) *types.NetworkMap { + account := am.getAccountFromHolderOrInit(accountId) + if account == nil { + log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) + return &types.NetworkMap{ + Network: &types.Network{}, + } + } + return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) +} + +func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error { + am.enrichAccountFromHolder(account) + return account.OnPeerAddedUpdNetworkMapCache(peerId) +} + +func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { + am.enrichAccountFromHolder(account) + return account.OnPeerDeletedUpdNetworkMapCache(peerId) +} + +func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { + account := am.getAccountFromHolder(accountId) + if account == nil { + return + } + account.UpdatePeerInNetworkMapCache(peer) +} + +func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { + account.RecalculateNetworkMapCache(validatedPeers) + am.updateAccountInHolder(account) +} + +func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { + if am.experimentalNetworkMap(accountId) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + return err + } + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) + return err + } + am.recalculateNetworkMapCache(account, validatedPeers) + } + return nil +} + +func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool { + _, ok := am.expNewNetworkMapAIDs[accountId] + return am.expNewNetworkMap || ok +} diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index b6706ca45..0e6d1631b 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -177,6 +177,9 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 66484d120..b740610c2 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -157,6 +157,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -257,6 +260,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -331,6 +337,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 82cac424a..89ac419fd 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -119,6 +119,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) + if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -183,6 +186,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) + if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -217,6 +223,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo event() + if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/peer.go b/management/server/peer.go index 30b7073ef..80ab7fc69 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -145,6 +145,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if expired { + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. am.BufferUpdateAccountPeers(ctx, accountID) @@ -321,6 +324,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } + if peerLabelChanged || requiresPeerUpdates { am.UpdateAccountPeers(ctx, accountID) } else if sshChanged { @@ -381,6 +388,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } + if am.experimentalNetworkMap(accountID) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } + + if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) + } + + } + if userID != activity.SystemInitiator { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -417,7 +436,13 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin return nil, err } - networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + var networkMap *types.NetworkMap + + if am.experimentalNetworkMap(peer.AccountID) { + networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -690,6 +715,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + if am.experimentalNetworkMap(accountID) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, err + } + + if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) + } + } + am.BufferUpdateAccountPeers(ctx, accountID) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) @@ -776,6 +812,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } am.BufferUpdateAccountPeers(ctx, accountID) } @@ -783,6 +822,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("handlePeerNotFound: took %s", time.Since(start)) + }() if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. @@ -804,6 +847,11 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("LoginPeer: took %s", time.Since(start)) + }() + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { return am.handlePeerLoginNotFound(ctx, login, err) @@ -831,6 +879,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } + startTransaction := time.Now() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { @@ -900,8 +949,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } + log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) + if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } + startBuffer := time.Now() am.BufferUpdateAccountPeers(ctx, accountID) + log.WithContext(ctx).Debugf("LoginPeer: BufferUpdateAccountPeers took %v", time.Since(startBuffer)) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) @@ -909,6 +965,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer // getPeerPostureChecks returns the posture checks for the peer. func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("getPostureChecks: took %s", time.Since(start)) + }() + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err @@ -1014,9 +1075,17 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return peer, emptyMap, nil, nil } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err + var ( + account *types.Account + err error + ) + if am.experimentalNetworkMap(accountID) { + account = am.getAccountFromHolderOrInit(accountID) + } else { + account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, err + } } approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) @@ -1024,10 +1093,12 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } + startPosture := time.Now() postureChecks, err := am.getPeerPostureChecks(account, peer.ID) if err != nil { return nil, nil, nil, err } + log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture)) customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) @@ -1037,7 +1108,13 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) + var networkMap *types.NetworkMap + + if am.experimentalNetworkMap(accountID) { + networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) + } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -1167,11 +1244,18 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) - return + var ( + account *types.Account + err error + ) + if am.experimentalNetworkMap(accountID) { + account = am.getAccountFromHolderOrInit(accountID) + } else { + account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) + return + } } globalStart := time.Now() @@ -1204,6 +1288,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() + if am.experimentalNetworkMap(accountID) { + am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) + } + proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) @@ -1241,7 +1329,13 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start)) start = time.Now() - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + var remotePeerNetworkMap *types.NetworkMap + + if am.experimentalNetworkMap(accountID) { + remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + } am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start)) start = time.Now() @@ -1257,7 +1351,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) - am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) + am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) }(peer) } @@ -1351,7 +1445,13 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + var remotePeerNetworkMap *types.NetworkMap + + if am.experimentalNetworkMap(accountId) { + remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -1368,7 +1468,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) } // getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. @@ -1511,6 +1611,10 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p // getPeerGroupIDs returns the IDs of the groups that the peer is part of. func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("getPeerGroupIDs: took %s", time.Since(start)) + }() return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) } @@ -1580,7 +1684,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto }, }, }, - NetworkMap: &types.NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 3b2ab87fc..e151f5abb 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -168,6 +168,15 @@ func TestPeer_SessionExpired(t *testing.T) { } func TestAccountManager_GetNetworkMap(t *testing.T) { + testGetNetworkMapGeneral(t) +} + +func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testGetNetworkMapGeneral(t) +} + +func testGetNetworkMapGeneral(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -1003,7 +1012,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } } +func TestUpdateAccountPeers_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testUpdateAccountPeers(t) +} + func TestUpdateAccountPeers(t *testing.T) { + testUpdateAccountPeers(t) +} + +func testUpdateAccountPeers(t *testing.T) { testCases := []struct { name string peers int @@ -1043,8 +1061,8 @@ func TestUpdateAccountPeers(t *testing.T) { for _, channel := range peerChannels { update := <-channel assert.Nil(t, update.Update.NetbirdConfig) - assert.Equal(t, tc.peers, len(update.NetworkMap.Peers)) - assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules)) + assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers)) + assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules)) } }) } @@ -1548,6 +1566,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { } func Test_LoginPeer(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } diff --git a/management/server/policy.go b/management/server/policy.go index 9e4b3f73a..ff02d46aa 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -77,6 +77,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } @@ -120,6 +123,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 943f2a970..f457b994b 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -80,6 +80,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/route.go b/management/server/route.go index 4510426bb..05f7acf9e 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -192,6 +192,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } @@ -246,6 +249,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) if oldRouteAffectsPeers || newRouteAffectsPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -289,6 +295,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index 2b2896572..f16b609f8 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -5,6 +5,9 @@ package settings import ( "context" "fmt" + "time" + + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/extra_settings" @@ -45,6 +48,11 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager { } func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start)) + }() + if userID != activity.SystemInitiator { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) if err != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 4201b68f6..d83d160c3 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2,6 +2,7 @@ package store import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -15,6 +16,8 @@ import ( "sync" "time" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -46,6 +49,11 @@ const ( accountAndIDsQueryCondition = "account_id = ? AND id IN ?" accountIDCondition = "account_id = ?" peerNotFoundFMT = "peer %s not found" + + pgMaxConnections = 30 + pgMinConnections = 1 + pgMaxConnLifetime = 60 * time.Minute + pgHealthCheckPeriod = 1 * time.Minute ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -55,6 +63,7 @@ type SqlStore struct { metrics telemetry.AppMetrics installationPK int storeEngine types.Engine + pool *pgxpool.Pool } type installation struct { @@ -307,6 +316,10 @@ func (s *SqlStore) GetInstallationID() string { } func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("SavePeer: took %s", time.Since(start)) + }() // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. peerCopy := peer.Copy() peerCopy.AccountID = accountID @@ -778,6 +791,13 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types. } func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { + if s.pool != nil { + return s.getAccountPgx(ctx, accountID) + } + return s.getAccountGorm(ctx, accountID) +} + +func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -788,9 +808,19 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc var account types.Account result := s.db.Model(&account). - Omit("GroupsG"). - Preload("UsersG.PATsG"). // have to be specifies as this is nester reference - Preload(clause.Associations). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). Take(&account, idQueryCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) @@ -800,70 +830,1147 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc return nil, status.NewGetAccountFromStoreError(result.Error) } - // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us - for i, policy := range account.Policies { - var rules []*types.PolicyRule - err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error - if err != nil { - return nil, status.Errorf(status.NotFound, "rule not found") - } - account.Policies[i].Rules = rules - } - account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { - account.SetupKeys[key.Key] = key.Copy() + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key } account.SetupKeysG = nil account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) for _, peer := range account.PeersG { - account.Peers[peer.ID] = peer.Copy() + account.Peers[peer.ID] = &peer } account.PeersG = nil - account.Users = make(map[string]*types.User, len(account.UsersG)) for _, user := range account.UsersG { user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) for _, pat := range user.PATsG { - user.PATs[pat.ID] = pat.Copy() + pat.UserID = "" + user.PATs[pat.ID] = &pat } - account.Users[user.Id] = user.Copy() + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil } account.UsersG = nil - account.Groups = make(map[string]*types.Group, len(account.GroupsG)) for _, group := range account.GroupsG { - account.Groups[group.ID] = group.Copy() + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group } account.GroupsG = nil - var groupPeers []types.GroupPeer - s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). - Find(&groupPeers) - for _, groupPeer := range groupPeers { - if group, ok := account.Groups[groupPeer.GroupID]; ok { - group.Peers = append(group.Peers, groupPeer.PeerID) - } else { - log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + account.InitOnce() + return &account, nil +} + +func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e } } - account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) - for _, route := range account.RoutesG { - account.Routes[route.ID] = route.Copy() + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route } - account.RoutesG = nil account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) - for _, ns := range account.NameServerGroupsG { - account.NameServerGroups[ns.ID] = ns.Copy() + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil account.NameServerGroupsG = nil + return account, nil +} + +func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) { + var account types.Account + account.Network = &types.Network{} + const accountQuery = ` + SELECT + id, created_by, created_at, domain, domain_category, is_domain_primary_account, + -- Embedded Network + network_identifier, network_net, network_dns, network_serial, + -- Embedded DNSSettings + dns_settings_disabled_management_groups, + -- Embedded Settings + settings_peer_login_expiration_enabled, settings_peer_login_expiration, + settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration, + settings_regular_users_view_blocked, settings_groups_propagation_enabled, + settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, + settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, + settings_lazy_connection_enabled, + -- Embedded ExtraSettings + settings_extra_peer_approval_enabled, settings_extra_user_approval_required, + settings_extra_integrated_validator, settings_extra_integrated_validator_groups + FROM accounts WHERE id = $1` + + var ( + sPeerLoginExpirationEnabled sql.NullBool + sPeerLoginExpiration sql.NullInt64 + sPeerInactivityExpirationEnabled sql.NullBool + sPeerInactivityExpiration sql.NullInt64 + sRegularUsersViewBlocked sql.NullBool + sGroupsPropagationEnabled sql.NullBool + sJWTGroupsEnabled sql.NullBool + sJWTGroupsClaimName sql.NullString + sJWTAllowGroups sql.NullString + sRoutingPeerDNSResolutionEnabled sql.NullBool + sDNSDomain sql.NullString + sNetworkRange sql.NullString + sLazyConnectionEnabled sql.NullBool + sExtraPeerApprovalEnabled sql.NullBool + sExtraUserApprovalRequired sql.NullBool + sExtraIntegratedValidator sql.NullString + sExtraIntegratedValidatorGroups sql.NullString + networkNet sql.NullString + dnsSettingsDisabledGroups sql.NullString + networkIdentifier sql.NullString + networkDns sql.NullString + networkSerial sql.NullInt64 + createdAt sql.NullTime + ) + err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( + &account.Id, &account.CreatedBy, &createdAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, + &networkIdentifier, &networkNet, &networkDns, &networkSerial, + &dnsSettingsDisabledGroups, + &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, + &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, + &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, + &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, + &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, + &sLazyConnectionEnabled, + &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, + &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, + ) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(err) + } + + account.Settings = &types.Settings{Extra: &types.ExtraSettings{}} + if networkNet.Valid { + _ = json.Unmarshal([]byte(networkNet.String), &account.Network.Net) + } + if createdAt.Valid { + account.CreatedAt = createdAt.Time + } + if dnsSettingsDisabledGroups.Valid { + _ = json.Unmarshal([]byte(dnsSettingsDisabledGroups.String), &account.DNSSettings.DisabledManagementGroups) + } + if networkIdentifier.Valid { + account.Network.Identifier = networkIdentifier.String + } + if networkDns.Valid { + account.Network.Dns = networkDns.String + } + if networkSerial.Valid { + account.Network.Serial = uint64(networkSerial.Int64) + } + if sPeerLoginExpirationEnabled.Valid { + account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool + } + if sPeerLoginExpiration.Valid { + account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64) + } + if sPeerInactivityExpirationEnabled.Valid { + account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool + } + if sPeerInactivityExpiration.Valid { + account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64) + } + if sRegularUsersViewBlocked.Valid { + account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool + } + if sGroupsPropagationEnabled.Valid { + account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool + } + if sJWTGroupsEnabled.Valid { + account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool + } + if sJWTGroupsClaimName.Valid { + account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String + } + if sRoutingPeerDNSResolutionEnabled.Valid { + account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool + } + if sDNSDomain.Valid { + account.Settings.DNSDomain = sDNSDomain.String + } + if sLazyConnectionEnabled.Valid { + account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool + } + if sJWTAllowGroups.Valid { + _ = json.Unmarshal([]byte(sJWTAllowGroups.String), &account.Settings.JWTAllowGroups) + } + if sNetworkRange.Valid { + _ = json.Unmarshal([]byte(sNetworkRange.String), &account.Settings.NetworkRange) + } + + if sExtraPeerApprovalEnabled.Valid { + account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool + } + if sExtraUserApprovalRequired.Valid { + account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool + } + if sExtraIntegratedValidator.Valid { + account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String + } + if sExtraIntegratedValidatorGroups.Valid { + _ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups) + } + account.InitOnce() return &account, nil } +func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types.SetupKey, error) { + const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, + revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { + var sk types.SetupKey + var autoGroups []byte + var skCreatedAt, expiresAt, updatedAt, lastUsed sql.NullTime + var revoked, ephemeral, allowExtraDNSLabels sql.NullBool + var usedTimes, usageLimit sql.NullInt64 + + err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &skCreatedAt, + &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) + + if err == nil { + if expiresAt.Valid { + sk.ExpiresAt = &expiresAt.Time + } + if skCreatedAt.Valid { + sk.CreatedAt = skCreatedAt.Time + } + if updatedAt.Valid { + sk.UpdatedAt = updatedAt.Time + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + } + if lastUsed.Valid { + sk.LastUsed = &lastUsed.Time + } + if revoked.Valid { + sk.Revoked = revoked.Bool + } + if usedTimes.Valid { + sk.UsedTimes = int(usedTimes.Int64) + } + if usageLimit.Valid { + sk.UsageLimit = int(usageLimit.Int64) + } + if ephemeral.Valid { + sk.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} + } + } + return sk, err + }) + if err != nil { + return nil, err + } + return keys, nil +} + +func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) { + const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, + inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, + meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, + meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, + meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, + peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, + location_geo_name_id FROM peers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { + var p nbpeer.Peer + p.Status = &nbpeer.PeerStatus{} + var ( + lastLogin, createdAt sql.NullTime + sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool + peerStatusLastSeen sql.NullTime + peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool + ip, extraDNS, netAddr, env, flags, files, connIP []byte + metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString + metaOS, metaOSVersion, metaWtVersion, metaUIVersion, metaKernelVersion sql.NullString + metaSystemSerialNumber, metaSystemProductName, metaSystemManufacturer sql.NullString + locationCountryCode, locationCityName sql.NullString + locationGeoNameID sql.NullInt64 + ) + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, + &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &createdAt, &ephemeral, &extraDNS, + &allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform, + &metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr, + &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, + &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, + &locationCountryCode, &locationCityName, &locationGeoNameID) + + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if createdAt.Valid { + p.CreatedAt = createdAt.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + if metaHostname.Valid { + p.Meta.Hostname = metaHostname.String + } + if metaGoOS.Valid { + p.Meta.GoOS = metaGoOS.String + } + if metaKernel.Valid { + p.Meta.Kernel = metaKernel.String + } + if metaCore.Valid { + p.Meta.Core = metaCore.String + } + if metaPlatform.Valid { + p.Meta.Platform = metaPlatform.String + } + if metaOS.Valid { + p.Meta.OS = metaOS.String + } + if metaOSVersion.Valid { + p.Meta.OSVersion = metaOSVersion.String + } + if metaWtVersion.Valid { + p.Meta.WtVersion = metaWtVersion.String + } + if metaUIVersion.Valid { + p.Meta.UIVersion = metaUIVersion.String + } + if metaKernelVersion.Valid { + p.Meta.KernelVersion = metaKernelVersion.String + } + if metaSystemSerialNumber.Valid { + p.Meta.SystemSerialNumber = metaSystemSerialNumber.String + } + if metaSystemProductName.Valid { + p.Meta.SystemProductName = metaSystemProductName.String + } + if metaSystemManufacturer.Valid { + p.Meta.SystemManufacturer = metaSystemManufacturer.String + } + if locationCountryCode.Valid { + p.Location.CountryCode = locationCountryCode.String + } + if locationCityName.Valid { + p.Location.CityName = locationCityName.String + } + if locationGeoNameID.Valid { + p.Location.GeoNameID = uint(locationGeoNameID.Int64) + } + if ip != nil { + _ = json.Unmarshal(ip, &p.IP) + } + if extraDNS != nil { + _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) + } + if netAddr != nil { + _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) + } + if env != nil { + _ = json.Unmarshal(env, &p.Meta.Environment) + } + if flags != nil { + _ = json.Unmarshal(flags, &p.Meta.Flags) + } + if files != nil { + _ = json.Unmarshal(files, &p.Meta.Files) + } + if connIP != nil { + _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) + } + } + return p, err + }) + if err != nil { + return nil, err + } + return peers, nil +} + +func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { + var u types.User + var autoGroups []byte + var lastLogin, createdAt sql.NullTime + var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil { + if lastLogin.Valid { + u.LastLogin = &lastLogin.Time + } + if createdAt.Valid { + u.CreatedAt = createdAt.Time + } + if isServiceUser.Valid { + u.IsServiceUser = isServiceUser.Bool + } + if nonDeletable.Valid { + u.NonDeletable = nonDeletable.Bool + } + if blocked.Valid { + u.Blocked = blocked.Bool + } + if pendingApproval.Valid { + u.PendingApproval = pendingApproval.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} + } + } + return u, err + }) + if err != nil { + return nil, err + } + return users, nil +} + +func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) { + const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { + var g types.Group + var resources []byte + var refID sql.NullInt64 + var refType sql.NullString + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) + if err == nil { + if refID.Valid { + g.IntegrationReference.ID = int(refID.Int64) + } + if refType.Valid { + g.IntegrationReference.IntegrationType = refType.String + } + if resources != nil { + _ = json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} + } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} + } + return &g, err + }) + if err != nil { + return nil, err + } + return groups, nil +} + +func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) { + const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { + var p types.Policy + var checks []byte + var enabled sql.NullBool + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + if err == nil { + if enabled.Valid { + p.Enabled = enabled.Bool + } + if checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } + } + return &p, err + }) + if err != nil { + return nil, err + } + return policies, nil +} + +func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) { + const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { + var r route.Route + var network, domains, peerGroups, groups, accessGroups []byte + var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) + if err == nil { + if keepRoute.Valid { + r.KeepRoute = keepRoute.Bool + } + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if skipAutoApply.Valid { + r.SkipAutoApply = skipAutoApply.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if network != nil { + _ = json.Unmarshal(network, &r.Network) + } + if domains != nil { + _ = json.Unmarshal(domains, &r.Domains) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + if groups != nil { + _ = json.Unmarshal(groups, &r.Groups) + } + if accessGroups != nil { + _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + return routes, nil +} + +func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) { + const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { + var n nbdns.NameServerGroup + var ns, groups, domains []byte + var primary, enabled, searchDomainsEnabled sql.NullBool + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) + if err == nil { + if primary.Valid { + n.Primary = primary.Bool + } + if enabled.Valid { + n.Enabled = enabled.Bool + } + if searchDomainsEnabled.Valid { + n.SearchDomainsEnabled = searchDomainsEnabled.Bool + } + if ns != nil { + _ = json.Unmarshal(ns, &n.NameServers) + } else { + n.NameServers = []nbdns.NameServer{} + } + if groups != nil { + _ = json.Unmarshal(groups, &n.Groups) + } else { + n.Groups = []string{} + } + if domains != nil { + _ = json.Unmarshal(domains, &n.Domains) + } else { + n.Domains = []string{} + } + } + return n, err + }) + if err != nil { + return nil, err + } + return nsgs, nil +} + +func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { + const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { + var c posture.Checks + var checksDef []byte + err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) + if err == nil && checksDef != nil { + _ = json.Unmarshal(checksDef, &c.Checks) + } + return &c, err + }) + if err != nil { + return nil, err + } + return checks, nil +} + +func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) { + const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) + if err != nil { + return nil, err + } + result := make([]*networkTypes.Network, len(networks)) + for i := range networks { + result[i] = &networks[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) { + const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { + var r routerTypes.NetworkRouter + var peerGroups []byte + var masquerade, enabled sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + if err == nil { + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*routerTypes.NetworkRouter, len(routers)) + for i := range routers { + result[i] = &routers[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) { + const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { + var r resourceTypes.NetworkResource + var prefix []byte + var enabled sql.NullBool + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*resourceTypes.NetworkResource, len(resources)) + for i := range resources { + result[i] = &resources[i] + } + return result, nil +} + +func (s *SqlStore) getAccountOnboarding(ctx context.Context, accountID string, account *types.Account) error { + const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + var onboardingFlowPending, signupFormPending sql.NullBool + var createdAt, updatedAt sql.NullTime + err := s.pool.QueryRow(ctx, query, accountID).Scan( + &account.Onboarding.AccountID, + &onboardingFlowPending, + &signupFormPending, + &createdAt, + &updatedAt, + ) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return err + } + if createdAt.Valid { + account.Onboarding.CreatedAt = createdAt.Time + } + if updatedAt.Valid { + account.Onboarding.UpdatedAt = updatedAt.Time + } + if onboardingFlowPending.Valid { + account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool + } + if signupFormPending.Valid { + account.Onboarding.SignupFormPending = signupFormPending.Bool + } + return nil +} + +func (s *SqlStore) getPersonalAccessTokens(ctx context.Context, userIDs []string) ([]types.PersonalAccessToken, error) { + if len(userIDs) == 0 { + return nil, nil + } + const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, userIDs) + if err != nil { + return nil, err + } + pats, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + var expirationDate, lastUsed, createdAt sql.NullTime + err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &createdAt, &lastUsed) + if err == nil { + if expirationDate.Valid { + pat.ExpirationDate = &expirationDate.Time + } + if createdAt.Valid { + pat.CreatedAt = createdAt.Time + } + if lastUsed.Valid { + pat.LastUsed = &lastUsed.Time + } + } + return pat, err + }) + if err != nil { + return nil, err + } + return pats, nil +} + +func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*types.PolicyRule, error) { + if len(policyIDs) == 0 { + return nil, nil + } + const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, policyIDs) + if err != nil { + return nil, err + } + rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { + var r types.PolicyRule + var dest, destRes, sources, sourceRes, ports, portRanges []byte + var enabled, bidirectional sql.NullBool + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if bidirectional.Valid { + r.Bidirectional = bidirectional.Bool + } + if dest != nil { + _ = json.Unmarshal(dest, &r.Destinations) + } + if destRes != nil { + _ = json.Unmarshal(destRes, &r.DestinationResource) + } + if sources != nil { + _ = json.Unmarshal(sources, &r.Sources) + } + if sourceRes != nil { + _ = json.Unmarshal(sourceRes, &r.SourceResource) + } + if ports != nil { + _ = json.Unmarshal(ports, &r.Ports) + } + if portRanges != nil { + _ = json.Unmarshal(portRanges, &r.PortRanges) + } + } + return &r, err + }) + if err != nil { + return nil, err + } + return rules, nil +} + +func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]types.GroupPeer, error) { + if len(groupIDs) == 0 { + return nil, nil + } + const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, groupIDs) + if err != nil { + return nil, err + } + groupPeers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) + if err != nil { + return nil, err + } + return groupPeers, nil +} + func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { var user types.User result := s.db.Select("account_id").Take(&user, idQueryCondition, userID) @@ -1054,6 +2161,10 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("GetAccountNetwork: took %s", time.Since(start)) + }() ctx, cancel := getDebuggingCtx(ctx) defer cancel() @@ -1095,6 +2206,11 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking } func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("getAccountSettings: took %s", time.Since(start)) + }() + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -1203,8 +2319,41 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe if err != nil { return nil, err } + pool, err := connectToPgDb(context.Background(), dsn) + if err != nil { + return nil, err + } + store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) + if err != nil { + pool.Close() + return nil, err + } + store.pool = pool + return store, nil +} - return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) +func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = pgMaxConnections + config.MinConns = pgMinConnections + config.MaxConnLifetime = pgMaxConnLifetime + config.HealthCheckPeriod = pgHealthCheckPeriod + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + return pool, nil } // NewMysqlStore creates a new MySQL store. @@ -1273,7 +2422,7 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data // NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB. func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { - store, err := NewPostgresqlStore(ctx, dsn, metrics, false) + store, err := NewPostgresqlStoreForTests(ctx, dsn, metrics, false) if err != nil { return nil, err } @@ -1293,6 +2442,50 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, return store, nil } +// used for tests only +func NewPostgresqlStoreForTests(ctx context.Context, dsn string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) { + db, err := gorm.Open(postgres.Open(dsn), getGormConfig()) + if err != nil { + return nil, err + } + pool, err := connectToPgDbForTests(context.Background(), dsn) + if err != nil { + return nil, err + } + store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) + if err != nil { + pool.Close() + return nil, err + } + store.pool = pool + return store, nil +} + +// used for tests only +func connectToPgDbForTests(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 5 + config.MinConns = 1 + config.MaxConnLifetime = 30 * time.Second + config.HealthCheckPeriod = 10 * time.Second + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + return pool, nil +} + // NewMysqlStoreFromSqlStore restores a store from SqlStore and stores MySQL DB. func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { store, err := NewMysqlStore(ctx, dsn, metrics, false) diff --git a/management/server/store/sql_store_get_account_test.go b/management/server/store/sql_store_get_account_test.go new file mode 100644 index 000000000..8ff04d68a --- /dev/null +++ b/management/server/store/sql_store_get_account_test.go @@ -0,0 +1,1089 @@ +package store + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/integration_reference" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads +// all fields and nested objects from the database, including deeply nested structures. +func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { + if testing.Short() { + t.Skip("skipping comprehensive test in short mode") + } + + ctx := context.Background() + store, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + // Create comprehensive test data + accountID := "test-account-comprehensive" + userID1 := "user-1" + userID2 := "user-2" + peerID1 := "peer-1" + peerID2 := "peer-2" + peerID3 := "peer-3" + groupID1 := "group-1" + groupID2 := "group-2" + setupKeyID1 := "setup-key-1" + setupKeyID2 := "setup-key-2" + routeID1 := route.ID("route-1") + routeID2 := route.ID("route-2") + nsGroupID1 := "ns-group-1" + nsGroupID2 := "ns-group-2" + policyID1 := "policy-1" + policyID2 := "policy-2" + postureCheckID1 := "posture-check-1" + postureCheckID2 := "posture-check-2" + networkID1 := "network-1" + routerID1 := "router-1" + resourceID1 := "resource-1" + patID1 := "pat-1" + patID2 := "pat-2" + patID3 := "pat-3" + + now := time.Now().UTC().Truncate(time.Second) + lastLogin := now.Add(-24 * time.Hour) + patLastUsed := now.Add(-1 * time.Hour) + + // Build comprehensive account with all fields populated + account := &types.Account{ + Id: accountID, + CreatedBy: userID1, + CreatedAt: now, + Domain: "example.com", + DomainCategory: "business", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "test-network", + Net: net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }, + Dns: "test-dns", + Serial: 42, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"dns-group-1", "dns-group-2"}, + }, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour * 24 * 30, + GroupsPropagationEnabled: true, + JWTGroupsEnabled: true, + JWTGroupsClaimName: "groups", + JWTAllowGroups: []string{"allowed-group-1", "allowed-group-2"}, + RegularUsersViewBlocked: false, + Extra: &types.ExtraSettings{ + PeerApprovalEnabled: true, + IntegratedValidatorGroups: []string{"validator-1"}, + }, + }, + } + + // Create Setup Keys with all fields + setupKey1ExpiresAt := now.Add(30 * 24 * time.Hour) + setupKey1LastUsed := now.Add(-2 * time.Hour) + setupKey1 := &types.SetupKey{ + Id: setupKeyID1, + AccountID: accountID, + Key: "setup-key-secret-1", + Name: "Setup Key 1", + Type: types.SetupKeyReusable, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey1ExpiresAt, + Revoked: false, + UsedTimes: 5, + LastUsed: &setupKey1LastUsed, + AutoGroups: []string{groupID1, groupID2}, + UsageLimit: 100, + Ephemeral: false, + } + + setupKey2ExpiresAt := now.Add(7 * 24 * time.Hour) + setupKey2LastUsed := now.Add(-1 * time.Hour) + setupKey2 := &types.SetupKey{ + Id: setupKeyID2, + AccountID: accountID, + Key: "setup-key-secret-2", + Name: "Setup Key 2 (One-off)", + Type: types.SetupKeyOneOff, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey2ExpiresAt, + Revoked: true, + UsedTimes: 1, + LastUsed: &setupKey2LastUsed, + AutoGroups: []string{}, + UsageLimit: 1, + Ephemeral: true, + } + + account.SetupKeys = map[string]*types.SetupKey{ + setupKey1.Key: setupKey1, + setupKey2.Key: setupKey2, + } + + // Create Peers with comprehensive fields + peer1 := &nbpeer.Peer{ + ID: peerID1, + AccountID: accountID, + Key: "peer-key-1-AAAA", + Name: "Peer 1", + IP: net.ParseIP("100.64.0.1"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer1.example.com", + GoOS: "linux", + Kernel: "5.15.0", + Core: "x86_64", + Platform: "ubuntu", + OS: "Ubuntu 22.04", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + KernelVersion: "5.15.0-78-generic", + OSVersion: "22.04", + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.1.10/32"), Mac: "00:11:22:33:44:55"}, + {NetIP: netip.MustParsePrefix("10.0.0.5/32"), Mac: "00:11:22:33:44:66"}, + }, + SystemSerialNumber: "ABC123", + SystemProductName: "Server Model X", + SystemManufacturer: "Dell Inc.", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-5 * time.Minute), + Connected: true, + LoginExpired: false, + RequiresApproval: false, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("203.0.113.10"), + CountryCode: "US", + CityName: "San Francisco", + GeoNameID: 5391959, + }, + SSHEnabled: true, + SSHKey: "ssh-rsa AAAAB3NzaC1...", + UserID: userID1, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + DNSLabel: "peer1", + CreatedAt: now.Add(-30 * 24 * time.Hour), + Ephemeral: false, + } + + peer2 := &nbpeer.Peer{ + ID: peerID2, + AccountID: accountID, + Key: "peer-key-2-BBBB", + Name: "Peer 2", + IP: net.ParseIP("100.64.0.2"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer2.example.com", + GoOS: "darwin", + Kernel: "22.0.0", + Core: "arm64", + Platform: "darwin", + OS: "macOS Ventura", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-1 * time.Hour), + Connected: false, + LoginExpired: true, + RequiresApproval: true, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("198.51.100.20"), + CountryCode: "GB", + CityName: "London", + GeoNameID: 2643743, + }, + SSHEnabled: false, + UserID: userID2, + LoginExpirationEnabled: false, + InactivityExpirationEnabled: true, + DNSLabel: "peer2", + CreatedAt: now.Add(-15 * 24 * time.Hour), + Ephemeral: false, + } + + peer3 := &nbpeer.Peer{ + ID: peerID3, + AccountID: accountID, + Key: "peer-key-3-CCCC", + Name: "Peer 3 (Ephemeral)", + IP: net.ParseIP("100.64.0.3"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer3.example.com", + GoOS: "windows", + Platform: "windows", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-10 * time.Minute), + Connected: true, + }, + DNSLabel: "peer3", + CreatedAt: now.Add(-1 * time.Hour), + Ephemeral: true, + } + + account.Peers = map[string]*nbpeer.Peer{ + peerID1: peer1, + peerID2: peer2, + peerID3: peer3, + } + + // Create Users with PATs + pat1ExpirationDate := now.Add(90 * 24 * time.Hour) + pat1 := &types.PersonalAccessToken{ + ID: patID1, + Name: "PAT 1", + HashedToken: "hashed-token-1", + ExpirationDate: &pat1ExpirationDate, + CreatedAt: now.Add(-10 * 24 * time.Hour), + CreatedBy: userID1, + LastUsed: &patLastUsed, + } + + pat2ExpirationDate := now.Add(30 * 24 * time.Hour) + pat2 := &types.PersonalAccessToken{ + ID: patID2, + Name: "PAT 2", + HashedToken: "hashed-token-2", + ExpirationDate: &pat2ExpirationDate, + CreatedAt: now.Add(-5 * 24 * time.Hour), + CreatedBy: userID1, + } + + pat3ExpirationDate := now.Add(60 * 24 * time.Hour) + pat3 := &types.PersonalAccessToken{ + ID: patID3, + Name: "PAT 3", + HashedToken: "hashed-token-3", + ExpirationDate: &pat3ExpirationDate, + CreatedAt: now.Add(-2 * 24 * time.Hour), + CreatedBy: userID2, + } + + user1 := &types.User{ + Id: userID1, + AccountID: accountID, + Role: types.UserRoleOwner, + IsServiceUser: false, + NonDeletable: true, + AutoGroups: []string{groupID1}, + Issued: types.UserIssuedAPI, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 123, + IntegrationType: "azure_ad", + }, + CreatedAt: now.Add(-60 * 24 * time.Hour), + LastLogin: &lastLogin, + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID1: pat1, + patID2: pat2, + }, + } + + user2 := &types.User{ + Id: userID2, + AccountID: accountID, + Role: types.UserRoleAdmin, + IsServiceUser: true, + NonDeletable: false, + AutoGroups: []string{groupID2}, + Issued: types.UserIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 456, + IntegrationType: "google_workspace", + }, + CreatedAt: now.Add(-30 * 24 * time.Hour), + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID3: pat3, + }, + } + + account.Users = map[string]*types.User{ + userID1: user1, + userID2: user2, + } + + // Create Groups with peers and resources + group1 := &types.Group{ + ID: groupID1, + AccountID: accountID, + Name: "Group 1", + Issued: types.GroupIssuedAPI, + Peers: []string{peerID1, peerID2}, + Resources: []types.Resource{ + { + ID: "resource-1", + Type: types.ResourceTypeHost, + }, + }, + } + + group2 := &types.Group{ + ID: groupID2, + AccountID: accountID, + Name: "Group 2", + Issued: types.GroupIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 789, + IntegrationType: "okta", + }, + Peers: []string{peerID3}, + Resources: []types.Resource{}, + } + + account.Groups = map[string]*types.Group{ + groupID1: group1, + groupID2: group2, + } + + // Create Policies with Rules + policy1 := &types.Policy{ + ID: policyID1, + AccountID: accountID, + Name: "Policy 1", + Description: "Main access policy", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "rule-1", + PolicyID: policyID1, + Name: "Rule 1", + Description: "Allow access", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Ports: []string{}, + PortRanges: []types.RulePortRange{}, + Sources: []string{groupID1}, + Destinations: []string{groupID2}, + }, + { + ID: "rule-2", + PolicyID: policyID1, + Name: "Rule 2", + Description: "Block traffic on specific ports", + Enabled: true, + Action: types.PolicyTrafficActionDrop, + Bidirectional: false, + Protocol: types.PolicyRuleProtocolTCP, + Ports: []string{"22", "3389"}, + PortRanges: []types.RulePortRange{ + {Start: 8000, End: 8999}, + }, + Sources: []string{groupID2}, + Destinations: []string{groupID1}, + }, + }, + } + + policy2 := &types.Policy{ + ID: policyID2, + AccountID: accountID, + Name: "Policy 2", + Description: "Secondary policy", + Enabled: false, + Rules: []*types.PolicyRule{ + { + ID: "rule-3", + PolicyID: policyID2, + Name: "Rule 3", + Description: "UDP access", + Enabled: false, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolUDP, + Ports: []string{"53"}, + Sources: []string{groupID1}, + Destinations: []string{groupID1}, + }, + }, + } + + account.Policies = []*types.Policy{policy1, policy2} + + // Create Routes + route1 := &route.Route{ + ID: routeID1, + AccountID: accountID, + Network: netip.MustParsePrefix("10.0.0.0/24"), + NetworkType: route.IPv4Network, + Peer: peerID1, + PeerGroups: []string{}, + Description: "Route 1", + NetID: "net-id-1", + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{groupID1}, + AccessControlGroups: []string{groupID2}, + } + + route2 := &route.Route{ + ID: routeID2, + AccountID: accountID, + Network: netip.MustParsePrefix("192.168.1.0/24"), + NetworkType: route.IPv4Network, + Peer: "", + PeerGroups: []string{groupID2}, + Description: "Route 2 (High Availability)", + NetID: "net-id-2", + Masquerade: false, + Metric: 100, + Enabled: true, + Groups: []string{groupID1, groupID2}, + AccessControlGroups: []string{groupID1}, + } + + account.Routes = map[route.ID]*route.Route{ + routeID1: route1, + routeID2: route2, + } + + // Create NameServer Groups + nsGroup1 := &nbdns.NameServerGroup{ + ID: nsGroupID1, + AccountID: accountID, + Name: "NS Group 1", + Description: "Primary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + { + IP: netip.MustParseAddr("8.8.4.4"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{groupID1, groupID2}, + Domains: []string{"example.com", "test.com"}, + Enabled: true, + Primary: true, + SearchDomainsEnabled: true, + } + + nsGroup2 := &nbdns.NameServerGroup{ + ID: nsGroupID2, + AccountID: accountID, + Name: "NS Group 2", + Description: "Secondary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{}, + Domains: []string{}, + Enabled: false, + Primary: false, + SearchDomainsEnabled: false, + } + + account.NameServerGroups = map[string]*nbdns.NameServerGroup{ + nsGroupID1: nsGroup1, + nsGroupID2: nsGroup2, + } + + // Create Posture Checks + postureCheck1 := &posture.Checks{ + ID: postureCheckID1, + AccountID: accountID, + Name: "Posture Check 1", + Description: "OS version check", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.24.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "16.0", + }, + Darwin: &posture.MinVersionCheck{ + MinVersion: "22.0.0", + }, + }, + }, + } + + postureCheck2 := &posture.Checks{ + ID: postureCheckID2, + AccountID: accountID, + Name: "Posture Check 2", + Description: "Geo location check", + Checks: posture.ChecksDefinition{ + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "US", + CityName: "San Francisco", + }, + { + CountryCode: "GB", + CityName: "London", + }, + }, + Action: "allow", + }, + PeerNetworkRangeCheck: &posture.PeerNetworkRangeCheck{ + Ranges: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + Action: "allow", + }, + }, + } + + account.PostureChecks = []*posture.Checks{postureCheck1, postureCheck2} + + // Create Networks + network1 := &networkTypes.Network{ + ID: networkID1, + AccountID: accountID, + Name: "Network 1", + Description: "Primary network", + } + + account.Networks = []*networkTypes.Network{network1} + + // Create Network Routers + router1 := &routerTypes.NetworkRouter{ + ID: routerID1, + AccountID: accountID, + NetworkID: networkID1, + Peer: peerID1, + PeerGroups: []string{}, + Masquerade: true, + Metric: 100, + } + + account.NetworkRouters = []*routerTypes.NetworkRouter{router1} + + // Create Network Resources + resource1 := &resourceTypes.NetworkResource{ + ID: resourceID1, + AccountID: accountID, + NetworkID: networkID1, + Name: "Resource 1", + Description: "Web server", + Prefix: netip.MustParsePrefix("192.168.1.100/32"), + Type: resourceTypes.Host, + } + + account.NetworkResources = []*resourceTypes.NetworkResource{resource1} + + // Create Onboarding + account.Onboarding = types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + SignupFormPending: false, + CreatedAt: now, + UpdatedAt: now, + } + + // Save the account to the database + err = store.SaveAccount(ctx, account) + require.NoError(t, err, "Failed to save comprehensive test account") + + // Retrieve the account from the database + retrievedAccount, err := store.GetAccount(ctx, accountID) + require.NoError(t, err, "Failed to retrieve account") + require.NotNil(t, retrievedAccount, "Retrieved account should not be nil") + + // ========== VALIDATE TOP-LEVEL FIELDS ========== + t.Run("TopLevelFields", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Id, "Account ID mismatch") + assert.Equal(t, userID1, retrievedAccount.CreatedBy, "CreatedBy mismatch") + assert.WithinDuration(t, now, retrievedAccount.CreatedAt, time.Second, "CreatedAt mismatch") + assert.Equal(t, "example.com", retrievedAccount.Domain, "Domain mismatch") + assert.Equal(t, "business", retrievedAccount.DomainCategory, "DomainCategory mismatch") + assert.True(t, retrievedAccount.IsDomainPrimaryAccount, "IsDomainPrimaryAccount should be true") + }) + + // ========== VALIDATE EMBEDDED NETWORK ========== + t.Run("EmbeddedNetwork", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Network, "Network should not be nil") + assert.Equal(t, "test-network", retrievedAccount.Network.Identifier, "Network Identifier mismatch") + assert.Equal(t, "test-dns", retrievedAccount.Network.Dns, "Network DNS mismatch") + assert.Equal(t, uint64(42), retrievedAccount.Network.Serial, "Network Serial mismatch") + + expectedIP := net.ParseIP("100.64.0.0") + assert.True(t, retrievedAccount.Network.Net.IP.Equal(expectedIP), "Network IP mismatch") + expectedMask := net.CIDRMask(10, 32) + assert.Equal(t, expectedMask, retrievedAccount.Network.Net.Mask, "Network Mask mismatch") + }) + + // ========== VALIDATE DNS SETTINGS ========== + t.Run("DNSSettings", func(t *testing.T) { + assert.Len(t, retrievedAccount.DNSSettings.DisabledManagementGroups, 2, "DisabledManagementGroups length mismatch") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-1", "Missing dns-group-1") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-2", "Missing dns-group-2") + }) + + // ========== VALIDATE SETTINGS ========== + t.Run("Settings", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Settings, "Settings should not be nil") + assert.True(t, retrievedAccount.Settings.PeerLoginExpirationEnabled, "PeerLoginExpirationEnabled mismatch") + assert.Equal(t, time.Hour*24*30, retrievedAccount.Settings.PeerLoginExpiration, "PeerLoginExpiration mismatch") + assert.True(t, retrievedAccount.Settings.GroupsPropagationEnabled, "GroupsPropagationEnabled mismatch") + assert.True(t, retrievedAccount.Settings.JWTGroupsEnabled, "JWTGroupsEnabled mismatch") + assert.Equal(t, "groups", retrievedAccount.Settings.JWTGroupsClaimName, "JWTGroupsClaimName mismatch") + assert.Len(t, retrievedAccount.Settings.JWTAllowGroups, 2, "JWTAllowGroups length mismatch") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-1") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-2") + assert.False(t, retrievedAccount.Settings.RegularUsersViewBlocked, "RegularUsersViewBlocked mismatch") + + // Validate Extra Settings + require.NotNil(t, retrievedAccount.Settings.Extra, "Extra settings should not be nil") + assert.True(t, retrievedAccount.Settings.Extra.PeerApprovalEnabled, "PeerApprovalEnabled mismatch") + assert.Len(t, retrievedAccount.Settings.Extra.IntegratedValidatorGroups, 1, "IntegratedValidatorGroups length mismatch") + assert.Equal(t, "validator-1", retrievedAccount.Settings.Extra.IntegratedValidatorGroups[0]) + }) + + // ========== VALIDATE SETUP KEYS ========== + t.Run("SetupKeys", func(t *testing.T) { + require.Len(t, retrievedAccount.SetupKeys, 2, "Should have 2 setup keys") + + // Validate Setup Key 1 + sk1, exists := retrievedAccount.SetupKeys["setup-key-secret-1"] + require.True(t, exists, "Setup key 1 should exist") + assert.Equal(t, "Setup Key 1", sk1.Name, "Setup key 1 name mismatch") + assert.Equal(t, types.SetupKeyReusable, sk1.Type, "Setup key 1 type mismatch") + assert.False(t, sk1.Revoked, "Setup key 1 should not be revoked") + assert.Equal(t, 5, sk1.UsedTimes, "Setup key 1 used times mismatch") + assert.Equal(t, 100, sk1.UsageLimit, "Setup key 1 usage limit mismatch") + assert.False(t, sk1.Ephemeral, "Setup key 1 should not be ephemeral") + assert.Len(t, sk1.AutoGroups, 2, "Setup key 1 auto groups length mismatch") + assert.Contains(t, sk1.AutoGroups, groupID1) + assert.Contains(t, sk1.AutoGroups, groupID2) + + // Validate Setup Key 2 + sk2, exists := retrievedAccount.SetupKeys["setup-key-secret-2"] + require.True(t, exists, "Setup key 2 should exist") + assert.Equal(t, "Setup Key 2 (One-off)", sk2.Name, "Setup key 2 name mismatch") + assert.Equal(t, types.SetupKeyOneOff, sk2.Type, "Setup key 2 type mismatch") + assert.True(t, sk2.Revoked, "Setup key 2 should be revoked") + assert.Equal(t, 1, sk2.UsedTimes, "Setup key 2 used times mismatch") + assert.Equal(t, 1, sk2.UsageLimit, "Setup key 2 usage limit mismatch") + assert.True(t, sk2.Ephemeral, "Setup key 2 should be ephemeral") + assert.Len(t, sk2.AutoGroups, 0, "Setup key 2 should have empty auto groups") + }) + + // ========== VALIDATE PEERS ========== + t.Run("Peers", func(t *testing.T) { + require.Len(t, retrievedAccount.Peers, 3, "Should have 3 peers") + + // Validate Peer 1 + p1, exists := retrievedAccount.Peers[peerID1] + require.True(t, exists, "Peer 1 should exist") + assert.Equal(t, "Peer 1", p1.Name, "Peer 1 name mismatch") + assert.Equal(t, "peer-key-1-AAAA", p1.Key, "Peer 1 key mismatch") + assert.True(t, p1.IP.Equal(net.ParseIP("100.64.0.1")), "Peer 1 IP mismatch") + assert.Equal(t, userID1, p1.UserID, "Peer 1 user ID mismatch") + assert.True(t, p1.SSHEnabled, "Peer 1 SSH should be enabled") + assert.Equal(t, "ssh-rsa AAAAB3NzaC1...", p1.SSHKey, "Peer 1 SSH key mismatch") + assert.True(t, p1.LoginExpirationEnabled, "Peer 1 login expiration should be enabled") + assert.False(t, p1.Ephemeral, "Peer 1 should not be ephemeral") + assert.Equal(t, "peer1", p1.DNSLabel, "Peer 1 DNS label mismatch") + + // Validate Peer 1 Meta + assert.Equal(t, "peer1.example.com", p1.Meta.Hostname, "Peer 1 hostname mismatch") + assert.Equal(t, "linux", p1.Meta.GoOS, "Peer 1 OS mismatch") + assert.Equal(t, "5.15.0", p1.Meta.Kernel, "Peer 1 kernel mismatch") + assert.Equal(t, "x86_64", p1.Meta.Core, "Peer 1 core mismatch") + assert.Equal(t, "ubuntu", p1.Meta.Platform, "Peer 1 platform mismatch") + assert.Equal(t, "Ubuntu 22.04", p1.Meta.OS, "Peer 1 OS version mismatch") + assert.Equal(t, "0.24.0", p1.Meta.WtVersion, "Peer 1 wt version mismatch") + assert.Equal(t, "ABC123", p1.Meta.SystemSerialNumber, "Peer 1 serial number mismatch") + assert.Equal(t, "Server Model X", p1.Meta.SystemProductName, "Peer 1 product name mismatch") + assert.Equal(t, "Dell Inc.", p1.Meta.SystemManufacturer, "Peer 1 manufacturer mismatch") + + // Validate Network Addresses + assert.Len(t, p1.Meta.NetworkAddresses, 2, "Peer 1 should have 2 network addresses") + assert.Equal(t, netip.MustParsePrefix("192.168.1.10/32"), p1.Meta.NetworkAddresses[0].NetIP, "Network address 1 IP mismatch") + assert.Equal(t, "00:11:22:33:44:55", p1.Meta.NetworkAddresses[0].Mac, "Network address 1 MAC mismatch") + assert.Equal(t, netip.MustParsePrefix("10.0.0.5/32"), p1.Meta.NetworkAddresses[1].NetIP, "Network address 2 IP mismatch") + assert.Equal(t, "00:11:22:33:44:66", p1.Meta.NetworkAddresses[1].Mac, "Network address 2 MAC mismatch") + + // Validate Peer 1 Status + require.NotNil(t, p1.Status, "Peer 1 status should not be nil") + assert.True(t, p1.Status.Connected, "Peer 1 should be connected") + assert.False(t, p1.Status.LoginExpired, "Peer 1 login should not be expired") + assert.False(t, p1.Status.RequiresApproval, "Peer 1 should not require approval") + + // Validate Peer 1 Location + assert.True(t, p1.Location.ConnectionIP.Equal(net.ParseIP("203.0.113.10")), "Peer 1 connection IP mismatch") + assert.Equal(t, "US", p1.Location.CountryCode, "Peer 1 country code mismatch") + assert.Equal(t, "San Francisco", p1.Location.CityName, "Peer 1 city name mismatch") + assert.Equal(t, uint(5391959), p1.Location.GeoNameID, "Peer 1 geo name ID mismatch") + + // Validate Peer 2 + p2, exists := retrievedAccount.Peers[peerID2] + require.True(t, exists, "Peer 2 should exist") + assert.Equal(t, "Peer 2", p2.Name, "Peer 2 name mismatch") + assert.Equal(t, "peer-key-2-BBBB", p2.Key, "Peer 2 key mismatch") + assert.False(t, p2.SSHEnabled, "Peer 2 SSH should be disabled") + assert.False(t, p2.LoginExpirationEnabled, "Peer 2 login expiration should be disabled") + assert.True(t, p2.InactivityExpirationEnabled, "Peer 2 inactivity expiration should be enabled") + + // Validate Peer 2 Status + require.NotNil(t, p2.Status, "Peer 2 status should not be nil") + assert.False(t, p2.Status.Connected, "Peer 2 should not be connected") + assert.True(t, p2.Status.LoginExpired, "Peer 2 login should be expired") + assert.True(t, p2.Status.RequiresApproval, "Peer 2 should require approval") + + // Validate Peer 3 (Ephemeral) + p3, exists := retrievedAccount.Peers[peerID3] + require.True(t, exists, "Peer 3 should exist") + assert.True(t, p3.Ephemeral, "Peer 3 should be ephemeral") + assert.Equal(t, "Peer 3 (Ephemeral)", p3.Name, "Peer 3 name mismatch") + }) + + // ========== VALIDATE USERS ========== + t.Run("Users", func(t *testing.T) { + require.Len(t, retrievedAccount.Users, 2, "Should have 2 users") + + // Validate User 1 + u1, exists := retrievedAccount.Users[userID1] + require.True(t, exists, "User 1 should exist") + assert.Equal(t, types.UserRoleOwner, u1.Role, "User 1 role mismatch") + assert.False(t, u1.IsServiceUser, "User 1 should not be a service user") + assert.True(t, u1.NonDeletable, "User 1 should be non-deletable") + assert.Equal(t, types.UserIssuedAPI, u1.Issued, "User 1 issued type mismatch") + assert.Len(t, u1.AutoGroups, 1, "User 1 auto groups length mismatch") + assert.Contains(t, u1.AutoGroups, groupID1, "User 1 should have group1") + assert.False(t, u1.Blocked, "User 1 should not be blocked") + require.NotNil(t, u1.LastLogin, "User 1 last login should not be nil") + assert.WithinDuration(t, lastLogin, *u1.LastLogin, time.Second, "User 1 last login mismatch") + + // Validate User 1 Integration Reference + assert.Equal(t, 123, u1.IntegrationReference.ID, "User 1 integration ID mismatch") + assert.Equal(t, "azure_ad", u1.IntegrationReference.IntegrationType, "User 1 integration type mismatch") + + // Validate User 1 PATs + require.Len(t, u1.PATs, 2, "User 1 should have 2 PATs") + + pat1Retrieved, exists := u1.PATs[patID1] + require.True(t, exists, "PAT 1 should exist") + assert.Equal(t, "PAT 1", pat1Retrieved.Name, "PAT 1 name mismatch") + assert.Equal(t, "hashed-token-1", pat1Retrieved.HashedToken, "PAT 1 hashed token mismatch") + require.NotNil(t, pat1Retrieved.LastUsed, "PAT 1 last used should not be nil") + assert.WithinDuration(t, patLastUsed, *pat1Retrieved.LastUsed, time.Second, "PAT 1 last used mismatch") + assert.Equal(t, userID1, pat1Retrieved.CreatedBy, "PAT 1 created by mismatch") + assert.Empty(t, pat1Retrieved.UserID, "PAT 1 UserID should be cleared") + + pat2Retrieved, exists := u1.PATs[patID2] + require.True(t, exists, "PAT 2 should exist") + assert.Equal(t, "PAT 2", pat2Retrieved.Name, "PAT 2 name mismatch") + assert.Nil(t, pat2Retrieved.LastUsed, "PAT 2 last used should be nil") + + // Validate User 2 + u2, exists := retrievedAccount.Users[userID2] + require.True(t, exists, "User 2 should exist") + assert.Equal(t, types.UserRoleAdmin, u2.Role, "User 2 role mismatch") + assert.True(t, u2.IsServiceUser, "User 2 should be a service user") + assert.False(t, u2.NonDeletable, "User 2 should be deletable") + assert.Equal(t, types.UserIssuedIntegration, u2.Issued, "User 2 issued type mismatch") + assert.Equal(t, "google_workspace", u2.IntegrationReference.IntegrationType, "User 2 integration type mismatch") + + // Validate User 2 PATs + require.Len(t, u2.PATs, 1, "User 2 should have 1 PAT") + pat3Retrieved, exists := u2.PATs[patID3] + require.True(t, exists, "PAT 3 should exist") + assert.Equal(t, "PAT 3", pat3Retrieved.Name, "PAT 3 name mismatch") + }) + + // ========== VALIDATE GROUPS ========== + t.Run("Groups", func(t *testing.T) { + require.Len(t, retrievedAccount.Groups, 2, "Should have 2 groups") + + // Validate Group 1 + g1, exists := retrievedAccount.Groups[groupID1] + require.True(t, exists, "Group 1 should exist") + assert.Equal(t, "Group 1", g1.Name, "Group 1 name mismatch") + assert.Equal(t, types.GroupIssuedAPI, g1.Issued, "Group 1 issued type mismatch") + assert.Len(t, g1.Peers, 2, "Group 1 should have 2 peers") + assert.Contains(t, g1.Peers, peerID1, "Group 1 should contain peer 1") + assert.Contains(t, g1.Peers, peerID2, "Group 1 should contain peer 2") + + // Validate Group 1 Resources + assert.Len(t, g1.Resources, 1, "Group 1 should have 1 resource") + assert.Equal(t, "resource-1", g1.Resources[0].ID, "Group 1 resource ID mismatch") + assert.Equal(t, types.ResourceTypeHost, g1.Resources[0].Type, "Group 1 resource type mismatch") + + // Validate Group 2 + g2, exists := retrievedAccount.Groups[groupID2] + require.True(t, exists, "Group 2 should exist") + assert.Equal(t, "Group 2", g2.Name, "Group 2 name mismatch") + assert.Equal(t, types.GroupIssuedIntegration, g2.Issued, "Group 2 issued type mismatch") + assert.Len(t, g2.Peers, 1, "Group 2 should have 1 peer") + assert.Contains(t, g2.Peers, peerID3, "Group 2 should contain peer 3") + assert.Len(t, g2.Resources, 0, "Group 2 should have 0 resources") + + // Validate Group 2 Integration Reference + assert.Equal(t, 789, g2.IntegrationReference.ID, "Group 2 integration ID mismatch") + assert.Equal(t, "okta", g2.IntegrationReference.IntegrationType, "Group 2 integration type mismatch") + }) + + // ========== VALIDATE POLICIES ========== + t.Run("Policies", func(t *testing.T) { + require.Len(t, retrievedAccount.Policies, 2, "Should have 2 policies") + + // Validate Policy 1 + pol1 := retrievedAccount.Policies[0] + if pol1.ID != policyID1 { + pol1 = retrievedAccount.Policies[1] + } + assert.Equal(t, policyID1, pol1.ID, "Policy 1 ID mismatch") + assert.Equal(t, "Policy 1", pol1.Name, "Policy 1 name mismatch") + assert.Equal(t, "Main access policy", pol1.Description, "Policy 1 description mismatch") + assert.True(t, pol1.Enabled, "Policy 1 should be enabled") + + // Validate Policy 1 Rules + require.Len(t, pol1.Rules, 2, "Policy 1 should have 2 rules") + + rule1 := pol1.Rules[0] + assert.Equal(t, "Rule 1", rule1.Name, "Rule 1 name mismatch") + assert.Equal(t, "Allow access", rule1.Description, "Rule 1 description mismatch") + assert.True(t, rule1.Enabled, "Rule 1 should be enabled") + assert.Equal(t, types.PolicyTrafficActionAccept, rule1.Action, "Rule 1 action mismatch") + assert.True(t, rule1.Bidirectional, "Rule 1 should be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolALL, rule1.Protocol, "Rule 1 protocol mismatch") + assert.Len(t, rule1.Sources, 1, "Rule 1 sources length mismatch") + assert.Contains(t, rule1.Sources, groupID1, "Rule 1 should have group1 as source") + assert.Len(t, rule1.Destinations, 1, "Rule 1 destinations length mismatch") + assert.Contains(t, rule1.Destinations, groupID2, "Rule 1 should have group2 as destination") + + rule2 := pol1.Rules[1] + assert.Equal(t, "Rule 2", rule2.Name, "Rule 2 name mismatch") + assert.Equal(t, types.PolicyTrafficActionDrop, rule2.Action, "Rule 2 action mismatch") + assert.False(t, rule2.Bidirectional, "Rule 2 should not be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolTCP, rule2.Protocol, "Rule 2 protocol mismatch") + assert.Len(t, rule2.Ports, 2, "Rule 2 ports length mismatch") + assert.Contains(t, rule2.Ports, "22", "Rule 2 should have port 22") + assert.Contains(t, rule2.Ports, "3389", "Rule 2 should have port 3389") + assert.Len(t, rule2.PortRanges, 1, "Rule 2 port ranges length mismatch") + assert.Equal(t, uint16(8000), rule2.PortRanges[0].Start, "Rule 2 port range start mismatch") + assert.Equal(t, uint16(8999), rule2.PortRanges[0].End, "Rule 2 port range end mismatch") + + // Validate Policy 2 + pol2 := retrievedAccount.Policies[1] + if pol2.ID != policyID2 { + pol2 = retrievedAccount.Policies[0] + } + assert.Equal(t, policyID2, pol2.ID, "Policy 2 ID mismatch") + assert.Equal(t, "Policy 2", pol2.Name, "Policy 2 name mismatch") + assert.False(t, pol2.Enabled, "Policy 2 should be disabled") + require.Len(t, pol2.Rules, 1, "Policy 2 should have 1 rule") + + rule3 := pol2.Rules[0] + assert.Equal(t, "Rule 3", rule3.Name, "Rule 3 name mismatch") + assert.False(t, rule3.Enabled, "Rule 3 should be disabled") + assert.Equal(t, types.PolicyRuleProtocolUDP, rule3.Protocol, "Rule 3 protocol mismatch") + }) + + // ========== VALIDATE ROUTES ========== + t.Run("Routes", func(t *testing.T) { + require.Len(t, retrievedAccount.Routes, 2, "Should have 2 routes") + + // Validate Route 1 + r1, exists := retrievedAccount.Routes[routeID1] + require.True(t, exists, "Route 1 should exist") + assert.Equal(t, "Route 1", r1.Description, "Route 1 description mismatch") + assert.Equal(t, route.IPv4Network, r1.NetworkType, "Route 1 network type mismatch") + assert.Equal(t, peerID1, r1.Peer, "Route 1 peer mismatch") + assert.Empty(t, r1.PeerGroups, "Route 1 peer groups should be empty") + assert.Equal(t, route.NetID("net-id-1"), r1.NetID, "Route 1 net ID mismatch") + assert.True(t, r1.Masquerade, "Route 1 masquerade should be enabled") + assert.Equal(t, 9999, r1.Metric, "Route 1 metric mismatch") + assert.True(t, r1.Enabled, "Route 1 should be enabled") + assert.Len(t, r1.Groups, 1, "Route 1 groups length mismatch") + assert.Contains(t, r1.Groups, groupID1, "Route 1 should have group1") + assert.Len(t, r1.AccessControlGroups, 1, "Route 1 ACL groups length mismatch") + assert.Contains(t, r1.AccessControlGroups, groupID2, "Route 1 should have group2 in ACL") + + // Validate Route 1 Network CIDR + assert.Equal(t, "10.0.0.0/24", r1.Network.String(), "Route 1 network CIDR mismatch") + + // Validate Route 2 + r2, exists := retrievedAccount.Routes[routeID2] + require.True(t, exists, "Route 2 should exist") + assert.Equal(t, "Route 2 (High Availability)", r2.Description, "Route 2 description mismatch") + assert.Empty(t, r2.Peer, "Route 2 peer should be empty") + assert.Len(t, r2.PeerGroups, 1, "Route 2 peer groups length mismatch") + assert.Contains(t, r2.PeerGroups, groupID2, "Route 2 should have group2 as peer group") + assert.False(t, r2.Masquerade, "Route 2 masquerade should be disabled") + assert.Equal(t, 100, r2.Metric, "Route 2 metric mismatch") + assert.Equal(t, "192.168.1.0/24", r2.Network.String(), "Route 2 network CIDR mismatch") + }) + + // ========== VALIDATE NAME SERVER GROUPS ========== + t.Run("NameServerGroups", func(t *testing.T) { + require.Len(t, retrievedAccount.NameServerGroups, 2, "Should have 2 nameserver groups") + + // Validate NS Group 1 + nsg1, exists := retrievedAccount.NameServerGroups[nsGroupID1] + require.True(t, exists, "NS Group 1 should exist") + assert.Equal(t, "NS Group 1", nsg1.Name, "NS Group 1 name mismatch") + assert.Equal(t, "Primary nameservers", nsg1.Description, "NS Group 1 description mismatch") + assert.True(t, nsg1.Enabled, "NS Group 1 should be enabled") + assert.True(t, nsg1.Primary, "NS Group 1 should be primary") + assert.True(t, nsg1.SearchDomainsEnabled, "NS Group 1 search domains should be enabled") + assert.Empty(t, nsg1.AccountID, "NS Group 1 AccountID should be cleared") + + // Validate NS Group 1 NameServers + require.Len(t, nsg1.NameServers, 2, "NS Group 1 should have 2 nameservers") + assert.Equal(t, netip.MustParseAddr("8.8.8.8"), nsg1.NameServers[0].IP, "NS Group 1 nameserver 1 IP mismatch") + assert.Equal(t, nbdns.UDPNameServerType, nsg1.NameServers[0].NSType, "NS Group 1 nameserver 1 type mismatch") + assert.Equal(t, 53, nsg1.NameServers[0].Port, "NS Group 1 nameserver 1 port mismatch") + assert.Equal(t, netip.MustParseAddr("8.8.4.4"), nsg1.NameServers[1].IP, "NS Group 1 nameserver 2 IP mismatch") + + // Validate NS Group 1 Groups and Domains + assert.Len(t, nsg1.Groups, 2, "NS Group 1 groups length mismatch") + assert.Contains(t, nsg1.Groups, groupID1, "NS Group 1 should have group1") + assert.Contains(t, nsg1.Groups, groupID2, "NS Group 1 should have group2") + assert.Len(t, nsg1.Domains, 2, "NS Group 1 domains length mismatch") + assert.Contains(t, nsg1.Domains, "example.com", "NS Group 1 should have example.com domain") + assert.Contains(t, nsg1.Domains, "test.com", "NS Group 1 should have test.com domain") + + // Validate NS Group 2 + nsg2, exists := retrievedAccount.NameServerGroups[nsGroupID2] + require.True(t, exists, "NS Group 2 should exist") + assert.Equal(t, "NS Group 2", nsg2.Name, "NS Group 2 name mismatch") + assert.False(t, nsg2.Enabled, "NS Group 2 should be disabled") + assert.False(t, nsg2.Primary, "NS Group 2 should not be primary") + assert.False(t, nsg2.SearchDomainsEnabled, "NS Group 2 search domains should be disabled") + assert.Len(t, nsg2.NameServers, 1, "NS Group 2 should have 1 nameserver") + assert.Len(t, nsg2.Groups, 0, "NS Group 2 should have empty groups") + assert.Len(t, nsg2.Domains, 0, "NS Group 2 should have empty domains") + }) + + // ========== VALIDATE POSTURE CHECKS ========== + t.Run("PostureChecks", func(t *testing.T) { + require.Len(t, retrievedAccount.PostureChecks, 2, "Should have 2 posture checks") + + // Find posture checks by ID + var pc1, pc2 *posture.Checks + for _, pc := range retrievedAccount.PostureChecks { + if pc.ID == postureCheckID1 { + pc1 = pc + } else if pc.ID == postureCheckID2 { + pc2 = pc + } + } + + // Validate Posture Check 1 + require.NotNil(t, pc1, "Posture check 1 should exist") + assert.Equal(t, "Posture Check 1", pc1.Name, "Posture check 1 name mismatch") + assert.Equal(t, "OS version check", pc1.Description, "Posture check 1 description mismatch") + + // Validate NB Version Check + require.NotNil(t, pc1.Checks.NBVersionCheck, "NB version check should not be nil") + assert.Equal(t, "0.24.0", pc1.Checks.NBVersionCheck.MinVersion, "NB version check min version mismatch") + + // Validate OS Version Check + require.NotNil(t, pc1.Checks.OSVersionCheck, "OS version check should not be nil") + require.NotNil(t, pc1.Checks.OSVersionCheck.Ios, "iOS version check should not be nil") + assert.Equal(t, "16.0", pc1.Checks.OSVersionCheck.Ios.MinVersion, "iOS min version mismatch") + require.NotNil(t, pc1.Checks.OSVersionCheck.Darwin, "Darwin version check should not be nil") + assert.Equal(t, "22.0.0", pc1.Checks.OSVersionCheck.Darwin.MinVersion, "Darwin min version mismatch") + + // Validate Posture Check 2 + require.NotNil(t, pc2, "Posture check 2 should exist") + assert.Equal(t, "Posture Check 2", pc2.Name, "Posture check 2 name mismatch") + + // Validate Geo Location Check + require.NotNil(t, pc2.Checks.GeoLocationCheck, "Geo location check should not be nil") + assert.Equal(t, "allow", pc2.Checks.GeoLocationCheck.Action, "Geo location action mismatch") + assert.Len(t, pc2.Checks.GeoLocationCheck.Locations, 2, "Geo location check should have 2 locations") + assert.Equal(t, "US", pc2.Checks.GeoLocationCheck.Locations[0].CountryCode, "Location 1 country code mismatch") + assert.Equal(t, "San Francisco", pc2.Checks.GeoLocationCheck.Locations[0].CityName, "Location 1 city name mismatch") + assert.Equal(t, "GB", pc2.Checks.GeoLocationCheck.Locations[1].CountryCode, "Location 2 country code mismatch") + assert.Equal(t, "London", pc2.Checks.GeoLocationCheck.Locations[1].CityName, "Location 2 city name mismatch") + + // Validate Peer Network Range Check + require.NotNil(t, pc2.Checks.PeerNetworkRangeCheck, "Peer network range check should not be nil") + assert.Equal(t, "allow", pc2.Checks.PeerNetworkRangeCheck.Action, "Peer network range action mismatch") + assert.Len(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, 2, "Peer network range check should have 2 ranges") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("192.168.0.0/16"), "Should have 192.168.0.0/16 range") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("10.0.0.0/8"), "Should have 10.0.0.0/8 range") + }) + + // ========== VALIDATE NETWORKS ========== + t.Run("Networks", func(t *testing.T) { + require.Len(t, retrievedAccount.Networks, 1, "Should have 1 network") + + net1 := retrievedAccount.Networks[0] + assert.Equal(t, networkID1, net1.ID, "Network 1 ID mismatch") + assert.Equal(t, "Network 1", net1.Name, "Network 1 name mismatch") + assert.Equal(t, "Primary network", net1.Description, "Network 1 description mismatch") + }) + + // ========== VALIDATE NETWORK ROUTERS ========== + t.Run("NetworkRouters", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkRouters, 1, "Should have 1 network router") + + router := retrievedAccount.NetworkRouters[0] + assert.Equal(t, routerID1, router.ID, "Router 1 ID mismatch") + assert.Equal(t, networkID1, router.NetworkID, "Router 1 network ID mismatch") + assert.Equal(t, peerID1, router.Peer, "Router 1 peer mismatch") + assert.Empty(t, router.PeerGroups, "Router 1 peer groups should be empty") + assert.True(t, router.Masquerade, "Router 1 masquerade should be enabled") + assert.Equal(t, 100, router.Metric, "Router 1 metric mismatch") + }) + + // ========== VALIDATE NETWORK RESOURCES ========== + t.Run("NetworkResources", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkResources, 1, "Should have 1 network resource") + + res := retrievedAccount.NetworkResources[0] + assert.Equal(t, resourceID1, res.ID, "Resource 1 ID mismatch") + assert.Equal(t, networkID1, res.NetworkID, "Resource 1 network ID mismatch") + assert.Equal(t, "Resource 1", res.Name, "Resource 1 name mismatch") + assert.Equal(t, "Web server", res.Description, "Resource 1 description mismatch") + assert.Equal(t, netip.MustParsePrefix("192.168.1.100/32"), res.Prefix, "Resource 1 prefix mismatch") + assert.Equal(t, resourceTypes.Host, res.Type, "Resource 1 type mismatch") + }) + + // ========== VALIDATE ONBOARDING ========== + t.Run("Onboarding", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Onboarding.AccountID, "Onboarding account ID mismatch") + assert.True(t, retrievedAccount.Onboarding.OnboardingFlowPending, "Onboarding flow should be pending") + assert.False(t, retrievedAccount.Onboarding.SignupFormPending, "Signup form should not be pending") + assert.WithinDuration(t, now, retrievedAccount.Onboarding.CreatedAt, time.Second, "Onboarding created at mismatch") + }) + + t.Log("✅ All comprehensive account field validations passed!") +} diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go new file mode 100644 index 000000000..350a1da83 --- /dev/null +++ b/management/server/store/sqlstore_bench_test.go @@ -0,0 +1,951 @@ +package store + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sort" + "sync" + "testing" + "time" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/jackc/pgx/v5/pgxpool" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/testutil" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/status" +) + +func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Omit("GroupsG"). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload(clause.Associations). + Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us + for i, policy := range account.Policies { + var rules []*types.PolicyRule + err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + if err != nil { + return nil, status.Errorf(status.NotFound, "rule not found") + } + account.Policies[i].Rules = rules + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + account.SetupKeys[key.Key] = key.Copy() + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = peer.Copy() + } + account.PeersG = nil + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = pat.Copy() + } + account.Users[user.Id] = user.Copy() + } + account.UsersG = nil + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + account.Groups[group.ID] = group.Copy() + } + account.GroupsG = nil + + var groupPeers []types.GroupPeer + s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). + Find(&groupPeers) + for _, groupPeer := range groupPeers { + if group, ok := account.Groups[groupPeer.GroupID]; ok { + group.Peers = append(group.Peers, groupPeer.PeerID) + } else { + log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + } + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = route.Copy() + } + account.RoutesG = nil + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + account.NameServerGroups[ns.ID] = ns.Copy() + } + account.NameServerGroupsG = nil + + return &account, nil +} + +func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). + Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = &peer + } + account.PeersG = nil + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + pat.UserID = "" + user.PATs[pat.ID] = &pat + } + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil + } + account.UsersG = nil + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group + } + account.GroupsG = nil + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + return &account, nil +} + +func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 12 + config.MinConns = 2 + config.MaxConnLifetime = time.Hour + config.HealthCheckPeriod = time.Minute + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + return pool, nil +} + +func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) { + cleanup, dsn, err := testutil.CreatePostgresTestContainer() + if err != nil { + b.Fatalf("failed to create test container: %v", err) + } + + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + pool, err := connectDBforTest(context.Background(), dsn) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + models := []interface{}{ + &types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, + &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, + &types.Policy{}, &types.PolicyRule{}, &route.Route{}, + &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{}, + &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, + &types.AccountOnboarding{}, + } + + for i := len(models) - 1; i >= 0; i-- { + err := db.Migrator().DropTable(models[i]) + if err != nil { + b.Fatalf("failed to drop table: %v", err) + } + } + + err = db.AutoMigrate(models...) + if err != nil { + b.Fatalf("failed to migrate database: %v", err) + } + + store := &SqlStore{ + db: db, + pool: pool, + } + + const ( + accountID = "benchmark-account-id" + numUsers = 20 + numPatsPerUser = 3 + numSetupKeys = 25 + numPeers = 200 + numGroups = 30 + numPolicies = 50 + numRulesPerPolicy = 10 + numRoutes = 40 + numNSGroups = 10 + numPostureChecks = 15 + numNetworks = 5 + numNetworkRouters = 5 + numNetworkResources = 10 + ) + + _, ipNet, _ := net.ParseCIDR("100.64.0.0/10") + acc := types.Account{ + Id: accountID, + CreatedBy: "benchmark-user", + CreatedAt: time.Now(), + Domain: "benchmark.com", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "benchmark-net", + Net: *ipNet, + Serial: 1, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"group-disabled-1"}, + }, + Settings: &types.Settings{}, + } + if err := db.Create(&acc).Error; err != nil { + b.Fatalf("create account: %v", err) + } + + var setupKeys []types.SetupKey + for i := 0; i < numSetupKeys; i++ { + setupKeys = append(setupKeys, types.SetupKey{ + Id: fmt.Sprintf("keyid-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("key-%d", i), + Name: fmt.Sprintf("Benchmark Key %d", i), + ExpiresAt: &time.Time{}, + }) + } + if err := db.Create(&setupKeys).Error; err != nil { + b.Fatalf("create setup keys: %v", err) + } + + var peers []nbpeer.Peer + for i := 0; i < numPeers; i++ { + peers = append(peers, nbpeer.Peer{ + ID: fmt.Sprintf("peer-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("peerkey-%d", i), + IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)), + Name: fmt.Sprintf("peer-name-%d", i), + Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()}, + }) + } + if err := db.Create(&peers).Error; err != nil { + b.Fatalf("create peers: %v", err) + } + + for i := 0; i < numUsers; i++ { + userID := fmt.Sprintf("user-%d", i) + user := types.User{Id: userID, AccountID: accountID} + if err := db.Create(&user).Error; err != nil { + b.Fatalf("create user %s: %v", userID, err) + } + + var pats []types.PersonalAccessToken + for j := 0; j < numPatsPerUser; j++ { + pats = append(pats, types.PersonalAccessToken{ + ID: fmt.Sprintf("pat-%d-%d", i, j), + UserID: userID, + Name: fmt.Sprintf("PAT %d for User %d", j, i), + }) + } + if err := db.Create(&pats).Error; err != nil { + b.Fatalf("create pats for user %s: %v", userID, err) + } + } + + var groups []*types.Group + for i := 0; i < numGroups; i++ { + groups = append(groups, &types.Group{ + ID: fmt.Sprintf("group-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Group %d", i), + }) + } + if err := db.Create(&groups).Error; err != nil { + b.Fatalf("create groups: %v", err) + } + + for i := 0; i < numPolicies; i++ { + policyID := fmt.Sprintf("policy-%d", i) + policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true} + if err := db.Create(&policy).Error; err != nil { + b.Fatalf("create policy %s: %v", policyID, err) + } + + var rules []*types.PolicyRule + for j := 0; j < numRulesPerPolicy; j++ { + rules = append(rules, &types.PolicyRule{ + ID: fmt.Sprintf("rule-%d-%d", i, j), + PolicyID: policyID, + Name: fmt.Sprintf("Rule %d for Policy %d", j, i), + Enabled: true, + Protocol: "all", + }) + } + if err := db.Create(&rules).Error; err != nil { + b.Fatalf("create rules for policy %s: %v", policyID, err) + } + } + + var routes []route.Route + for i := 0; i < numRoutes; i++ { + routes = append(routes, route.Route{ + ID: route.ID(fmt.Sprintf("route-%d", i)), + AccountID: accountID, + Description: fmt.Sprintf("Route %d", i), + Network: netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)), + Enabled: true, + }) + } + if err := db.Create(&routes).Error; err != nil { + b.Fatalf("create routes: %v", err) + } + + var nsGroups []nbdns.NameServerGroup + for i := 0; i < numNSGroups; i++ { + nsGroups = append(nsGroups, nbdns.NameServerGroup{ + ID: fmt.Sprintf("nsg-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("NS Group %d", i), + Description: "Benchmark NS Group", + Enabled: true, + }) + } + if err := db.Create(&nsGroups).Error; err != nil { + b.Fatalf("create nsgroups: %v", err) + } + + var postureChecks []*posture.Checks + for i := 0; i < numPostureChecks; i++ { + postureChecks = append(postureChecks, &posture.Checks{ + ID: fmt.Sprintf("pc-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Posture Check %d", i), + }) + } + if err := db.Create(&postureChecks).Error; err != nil { + b.Fatalf("create posture checks: %v", err) + } + + var networks []*networkTypes.Network + for i := 0; i < numNetworks; i++ { + networks = append(networks, &networkTypes.Network{ + ID: fmt.Sprintf("nettype-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Network Type %d", i), + }) + } + if err := db.Create(&networks).Error; err != nil { + b.Fatalf("create networks: %v", err) + } + + var networkRouters []*routerTypes.NetworkRouter + for i := 0; i < numNetworkRouters; i++ { + networkRouters = append(networkRouters, &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("router-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Peer: peers[i%numPeers].ID, + }) + } + if err := db.Create(&networkRouters).Error; err != nil { + b.Fatalf("create network routers: %v", err) + } + + var networkResources []*resourceTypes.NetworkResource + for i := 0; i < numNetworkResources; i++ { + networkResources = append(networkResources, &resourceTypes.NetworkResource{ + ID: fmt.Sprintf("resource-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Name: fmt.Sprintf("Resource %d", i), + }) + } + if err := db.Create(&networkResources).Error; err != nil { + b.Fatalf("create network resources: %v", err) + } + + onboarding := types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + } + if err := db.Create(&onboarding).Error; err != nil { + b.Fatalf("create onboarding: %v", err) + } + + return store, cleanup, accountID +} + +func BenchmarkGetAccount(b *testing.B) { + store, cleanup, accountID := setupBenchmarkDB(b) + defer cleanup() + ctx := context.Background() + b.ResetTimer() + b.ReportAllocs() + b.Run("old", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountSlow(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountSlow failed: %v", err) + } + } + }) + b.Run("gorm opt", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountGormOpt(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountFast failed: %v", err) + } + } + }) + b.Run("raw", func(b *testing.B) { + for range b.N { + _, err := store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountPureSQL failed: %v", err) + } + } + }) + store.pool.Close() +} + +func TestAccountEquivalence(t *testing.T) { + store, cleanup, accountID := setupBenchmarkDB(t) + defer cleanup() + ctx := context.Background() + + type getAccountFunc func(context.Context, string) (*types.Account, error) + + tests := []struct { + name string + expectedF getAccountFunc + actualF getAccountFunc + }{ + {"old vs new", store.GetAccountSlow, store.GetAccountGormOpt}, + {"old vs raw", store.GetAccountSlow, store.GetAccount}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected, errOld := tt.expectedF(ctx, accountID) + assert.NoError(t, errOld, "expected function should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := tt.actualF(ctx, accountID) + assert.NoError(t, errNew, "actual function should not return an error") + assert.NotNil(t, actual, "actual should not be nil") + testAccountEquivalence(t, expected, actual) + }) + } + + expected, errOld := store.GetAccountSlow(ctx, accountID) + assert.NoError(t, errOld, "GetAccountSlow should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := store.GetAccount(ctx, accountID) + assert.NoError(t, errNew, "GetAccount (new) should not return an error") + assert.NotNil(t, actual, "actual should not be nil") +} + +func testAccountEquivalence(t *testing.T, expected, actual *types.Account) { + assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal") + assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal") + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second") + assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal") + assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal") + assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal") + assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal") + assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal") + assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal") + + assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements") + for key, oldVal := range expected.SetupKeys { + newVal, ok := actual.SetupKeys[key] + assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key) + } + + assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have the same number of elements") + for key, oldVal := range expected.Peers { + newVal, ok := actual.Peers[key] + assert.True(t, ok, "Peer with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements") + for key, oldUser := range expected.Users { + newUser, ok := actual.Users[key] + assert.True(t, ok, "User with ID '%s' should exist in new account", key) + + assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key) + for patKey, oldPAT := range oldUser.PATs { + newPAT, patOk := newUser.PATs[patKey] + assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key) + assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key) + } + + oldUser.PATs = nil + newUser.PATs = nil + assert.Equal(t, *oldUser, *newUser, "User struct for ID '%s' (without PATs) should be equal", key) + } + + assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements") + for key, oldVal := range expected.Groups { + newVal, ok := actual.Groups[key] + assert.True(t, ok, "Group with ID '%s' should exist in new account", key) + sort.Strings(oldVal.Peers) + sort.Strings(newVal.Peers) + assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements") + for key, oldVal := range expected.Routes { + newVal, ok := actual.Routes[key] + assert.True(t, ok, "Route with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key) + } + + assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements") + for key, oldVal := range expected.NameServerGroups { + newVal, ok := actual.NameServerGroups[key] + assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements") + sort.Slice(expected.Policies, func(i, j int) bool { return expected.Policies[i].ID < expected.Policies[j].ID }) + sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID }) + for i := range expected.Policies { + sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID }) + sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID }) + assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID) + } + + assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements") + sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID }) + sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID }) + for i := range expected.PostureChecks { + assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID) + } + + assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements") + sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID }) + sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID }) + for i := range expected.Networks { + assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID) + } + + assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements") + sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID }) + sort.Slice(actual.NetworkRouters, func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID }) + for i := range expected.NetworkRouters { + assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID) + } + + assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements") + sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID }) + sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID }) + for i := range expected.NetworkResources { + assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID) + } +} + +func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route + } + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg + } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil + account.NameServerGroupsG = nil + + return account, nil +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 21b660d96..007e2b739 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -468,6 +468,9 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine) closeConnection := func() { cleanup() store.Close(ctx) + if store.pool != nil { + store.pool.Close() + } } return store, closeConnection, nil @@ -487,12 +490,18 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + db, err := openDBWithRetry(dsn, kind, 5) if err != nil { return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err) } dsn, cleanup, err := createRandomDB(dsn, db, kind) + + sqlDB, _ := db.DB() + if sqlDB != nil { + sqlDB.Close() + } + if err != nil { return nil, nil, err } @@ -519,12 +528,22 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv) } - db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + db, err := openDBWithRetry(dsn, kind, 5) if err != nil { return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err) } + sqlDB, err := db.DB() + if err != nil { + return nil, nil, fmt.Errorf("failed to get underlying sql.DB: %v", err) + } + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(1) + dsn, cleanup, err := createRandomDB(dsn, db, kind) + + sqlDB.Close() + if err != nil { return nil, nil, err } @@ -537,6 +556,31 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine return store, cleanup, nil } +func openDBWithRetry(dsn string, engine types.Engine, maxRetries int) (*gorm.DB, error) { + var db *gorm.DB + var err error + + for i := range maxRetries { + switch engine { + case types.PostgresStoreEngine: + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) + case types.MysqlStoreEngine: + db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + } + + if err == nil { + return db, nil + } + + if i < maxRetries-1 { + waitTime := time.Duration(100*(i+1)) * time.Millisecond + time.Sleep(waitTime) + } + } + + return nil, err +} + func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) { dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_")) @@ -544,21 +588,63 @@ func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func( return "", nil, fmt.Errorf("failed to create database: %v", err) } - var err error + originalDSN := dsn + cleanup := func() { + var dropDB *gorm.DB + var err error + switch engine { case types.PostgresStoreEngine: - err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error + dropDB, err = gorm.Open(postgres.Open(originalDSN), &gorm.Config{ + SkipDefaultTransaction: true, + PrepareStmt: false, + }) + if err != nil { + log.Errorf("failed to connect for dropping database %s: %v", dbName, err) + return + } + defer func() { + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.Close() + } + }() + + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(0) + sqlDB.SetConnMaxLifetime(time.Second) + } + + err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)).Error + case types.MysqlStoreEngine: - // err = killMySQLConnections(dsn, dbName) - err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error + dropDB, err = gorm.Open(mysql.Open(originalDSN+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{ + SkipDefaultTransaction: true, + PrepareStmt: false, + }) + if err != nil { + log.Errorf("failed to connect for dropping database %s: %v", dbName, err) + return + } + defer func() { + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.Close() + } + }() + + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(0) + sqlDB.SetConnMaxLifetime(time.Second) + } + + err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)).Error } + if err != nil { log.Errorf("failed to drop database %s: %v", dbName, err) - panic(err) } - sqlDB, _ := db.DB() - _ = sqlDB.Close() } return replaceDBName(dsn, dbName), cleanup, nil diff --git a/management/server/types/account.go b/management/server/types/account.go index 50bdc6ab3..dd6052498 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -8,6 +8,7 @@ import ( "slices" "strconv" "strings" + "sync" "time" "github.com/hashicorp/go-multierror" @@ -87,6 +88,13 @@ type Account struct { NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` + + NetworkMapCache *NetworkMapBuilder `gorm:"-"` + nmapInitOnce *sync.Once `gorm:"-"` +} + +func (a *Account) InitOnce() { + a.nmapInitOnce = &sync.Once{} } // this class is used by gorm only @@ -257,6 +265,9 @@ func (a *Account) GetPeerNetworkMap( metrics *telemetry.AccountManagerMetrics, ) *NetworkMap { start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("GetPeerNetworkMap: took %s", time.Since(start)) + }() peer := a.Peers[peerID] if peer == nil { @@ -890,6 +901,8 @@ func (a *Account) Copy() *Account { NetworkRouters: networkRouters, NetworkResources: networkResources, Onboarding: a.Onboarding, + NetworkMapCache: a.NetworkMapCache, + nmapInitOnce: a.nmapInitOnce, } } diff --git a/management/server/types/holder.go b/management/server/types/holder.go new file mode 100644 index 000000000..3996db2b6 --- /dev/null +++ b/management/server/types/holder.go @@ -0,0 +1,43 @@ +package types + +import ( + "context" + "sync" +) + +type Holder struct { + mu sync.RWMutex + accounts map[string]*Account +} + +func NewHolder() *Holder { + return &Holder{ + accounts: make(map[string]*Account), + } +} + +func (h *Holder) GetAccount(id string) *Account { + h.mu.RLock() + defer h.mu.RUnlock() + return h.accounts[id] +} + +func (h *Holder) AddAccount(account *Account) { + h.mu.Lock() + defer h.mu.Unlock() + h.accounts[account.Id] = account +} + +func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) { + h.mu.Lock() + defer h.mu.Unlock() + if acc, ok := h.accounts[id]; ok { + return acc, nil + } + account, err := accGetter(context.Background(), id) + if err != nil { + return nil, err + } + h.accounts[id] = account + return account, nil +} diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go new file mode 100644 index 000000000..c1099726f --- /dev/null +++ b/management/server/types/networkmap.go @@ -0,0 +1,58 @@ +package types + +import ( + "context" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" +) + +func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) { + if a.NetworkMapCache != nil { + return + } + a.nmapInitOnce.Do(func() { + a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers) + }) +} + +func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) { + a.initNetworkMapBuilder(validatedPeers) +} + +func (a *Account) GetPeerNetworkMapExp( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + validatedPeers map[string]struct{}, + metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + a.initNetworkMapBuilder(validatedPeers) + return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics) +} + +func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error { + if a.NetworkMapCache == nil { + return nil + } + return a.NetworkMapCache.OnPeerAddedIncremental(peerId) +} + +func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error { + if a.NetworkMapCache == nil { + return nil + } + return a.NetworkMapCache.OnPeerDeleted(peerId) +} + +func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) { + if a.NetworkMapCache == nil { + return + } + a.NetworkMapCache.UpdatePeer(peer) +} + +func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) { + a.initNetworkMapBuilder(validatedPeers) +} diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go new file mode 100644 index 000000000..d85aaabb2 --- /dev/null +++ b/management/server/types/networkmap_golden_test.go @@ -0,0 +1,1069 @@ +package types_test + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/netip" + "os" + "path/filepath" + "slices" + "sort" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// update flag is used to update the golden file. +// example: go test ./... -v -update +// var update = flag.Bool("update", false, "update golden files") + +const ( + numPeers = 100 + devGroupID = "group-dev" + opsGroupID = "group-ops" + allGroupID = "group-all" + routeID = route.ID("route-main") + routeHA1ID = route.ID("route-ha-1") + routeHA2ID = route.ID("route-ha-2") + policyIDDevOps = "policy-dev-ops" + policyIDAll = "policy-all" + policyIDPosture = "policy-posture" + policyIDDrop = "policy-drop" + postureCheckID = "posture-check-ver" + networkResourceID = "res-database" + networkID = "net-database" + networkRouterID = "router-database" + nameserverGroupID = "ns-group-main" + testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map. + expiredPeerID = "peer-98" // This peer will be online but with an expired session. + offlinePeerID = "peer-99" // This peer will be completely offline. + routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network. + testAccountID = "account-golden-test" +) + +func TestGetPeerNetworkMap_Golden(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden.json") + + t.Log("Update golden file...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from OLD method does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new.json") + + t.Log("Update golden file...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from NEW builder does not match golden file") +} + +func BenchmarkGetPeerNetworkMap(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + + b.ResetTimer() + b.Run("old builder", func(b *testing.B) { + for range b.N { + for _, peerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + b.ResetTimer() + b.Run("new builder", func(b *testing.B) { + for range b.N { + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + for _, peerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + newPeerID := "peer-new-101" + newPeerIP := net.IP{100, 64, 1, 1} + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: newPeerIP, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newPeerID] = newPeer + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = append(devGroup.Peers, newPeerID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newPeerID) + } + + validatedPeersMap[newPeerID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json") + + t.Log("Update golden file with new peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerAdded(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + newPeerID := "peer-new-101" + newPeerIP := net.IP{100, 64, 1, 1} + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: newPeerIP, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newPeerID] = newPeer + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = append(devGroup.Peers, newPeerID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newPeerID) + } + + validatedPeersMap[newPeerID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerAddedIncremental(newPeerID) + require.NoError(t, err, "error adding peer to cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json") + t.Log("Update golden file with OnPeerAdded...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + newPeerID := "peer-new-101" + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: net.IP{100, 64, 1, 1}, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + } + + account.Peers[newPeerID] = newPeer + account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID) + account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID) + validatedPeersMap[newPeerID] = struct{}{} + + b.ResetTimer() + b.Run("old builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + + b.ResetTimer() + b.Run("new builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerAddedIncremental(newPeerID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json") + + t.Log("Update golden file with new router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new router does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerAddedIncremental(newRouterID) + require.NoError(t, err, "error adding router to cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") + + t.Log("Update golden file with OnPeerAdded router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + b.ResetTimer() + b.Run("old builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + + b.ResetTimer() + b.Run("new builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerAddedIncremental(newRouterID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + deletedPeerID := "peer-25" // peer from devs group + + delete(account.Peers, deletedPeerID) + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + delete(validatedPeersMap, deletedPeerID) + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json") + + t.Log("Update golden file with deleted peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerDeleted(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + deletedPeerID := "peer-25" // devs group peer + + delete(account.Peers, deletedPeerID) + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + delete(validatedPeersMap, deletedPeerID) + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerDeleted(deletedPeerID) + require.NoError(t, err, "error deleting peer from cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json") + t.Log("Update golden file with OnPeerDeleted...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerDeleted does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + deletedRouterID := "peer-75" // router peer + + var affectedRoute *route.Route + for _, r := range account.Routes { + if r.PeerID == deletedRouterID { + affectedRoute = r + break + } + } + require.NotNil(t, affectedRoute, "Router peer should have a route") + + for _, group := range account.Groups { + group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { + return id == deletedRouterID + }) + } + + for routeID, r := range account.Routes { + if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { + delete(account.Routes, routeID) + } + } + delete(account.Peers, deletedRouterID) + delete(validatedPeersMap, deletedRouterID) + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json") + + t.Log("Update golden file with deleted peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithDeletedRouterPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + deletedRouterID := "peer-75" // router peer + + var affectedRoute *route.Route + for _, r := range account.Routes { + if r.PeerID == deletedRouterID { + affectedRoute = r + break + } + } + require.NotNil(t, affectedRoute, "Router peer should have a route") + + for _, group := range account.Groups { + group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { + return id == deletedRouterID + }) + } + for routeID, r := range account.Routes { + if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { + delete(account.Routes, routeID) + } + } + delete(account.Peers, deletedRouterID) + delete(validatedPeersMap, deletedRouterID) + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerDeleted(deletedRouterID) + require.NoError(t, err, "error deleting routing peer from cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err) + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json") + + t.Log("Update golden file with deleted router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err) + + require.JSONEq(t, string(expectedJSON), string(jsonData), + "network map after deleting router does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + + deletedPeerID := "peer-25" + + delete(account.Peers, deletedPeerID) + account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool { + return id == deletedPeerID + }) + account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool { + return id == deletedPeerID + }) + delete(validatedPeersMap, deletedPeerID) + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + b.ResetTimer() + b.Run("old builder after delete", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + + b.ResetTimer() + b.Run("new builder after delete", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerDeleted(deletedPeerID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) { + for _, peer := range networkMap.Peers { + if peer.Status != nil { + peer.Status.LastSeen = time.Time{} + } + peer.LastLogin = &time.Time{} + } + for _, peer := range networkMap.OfflinePeers { + if peer.Status != nil { + peer.Status.LastSeen = time.Time{} + } + peer.LastLogin = &time.Time{} + } + + sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID }) + sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID }) + sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID }) + + sort.Slice(networkMap.FirewallRules, func(i, j int) bool { + r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j] + if r1.PeerIP != r2.PeerIP { + return r1.PeerIP < r2.PeerIP + } + if r1.Protocol != r2.Protocol { + return r1.Protocol < r2.Protocol + } + if r1.Direction != r2.Direction { + return r1.Direction < r2.Direction + } + if r1.Action != r2.Action { + return r1.Action < r2.Action + } + return r1.Port < r2.Port + }) + + sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool { + r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j] + if r1.RouteID != r2.RouteID { + return r1.RouteID < r2.RouteID + } + if r1.Action != r2.Action { + return r1.Action < r2.Action + } + if r1.Destination != r2.Destination { + return r1.Destination < r2.Destination + } + if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 { + if r1.SourceRanges[0] != r2.SourceRanges[0] { + return r1.SourceRanges[0] < r2.SourceRanges[0] + } + } + return r1.Port < r2.Port + }) + + for _, ranges := range networkMap.RoutesFirewallRules { + sort.Slice(ranges.SourceRanges, func(i, j int) bool { + return ranges.SourceRanges[i] < ranges.SourceRanges[j] + }) + } +} + +func createTestAccountWithEntities() *types.Account { + peers := make(map[string]*nbpeer.Peer) + devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} + + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + ip := net.IP{100, 64, 0, byte(i + 1)} + wtVersion := "0.25.0" + if i%2 == 0 { + wtVersion = "0.40.0" + } + + p := &nbpeer.Peer{ + ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, + } + + if peerID == expiredPeerID { + p.LoginExpirationEnabled = true + pastTimestamp := time.Now().Add(-2 * time.Hour) + p.LastLogin = &pastTimestamp + } + + peers[peerID] = p + allGroupPeers = append(allGroupPeers, peerID) + if i < numPeers/2 { + devGroupPeers = append(devGroupPeers, peerID) + } else { + opsGroupPeers = append(opsGroupPeers, peerID) + } + + } + + groups := map[string]*types.Group{ + allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, + devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, + opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, + } + + policies := []*types.Policy{ + { + ID: policyIDAll, Name: "Default-Allow", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{allGroupID}, Destinations: []string{allGroupID}, + }}, + }, + { + ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false, + PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}}, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop, + Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, + SourcePostureChecks: []string{postureCheckID}, + Rules: []*types.PolicyRule{{ + ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID}, + }}, + }, + } + + routes := map[route.ID]*route.Route{ + routeID: { + ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), + Peer: peers["peer-75"].Key, + PeerID: "peer-75", + Description: "Route to internal resource", Enabled: true, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + }, + routeHA1ID: { + ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-80"].Key, + PeerID: "peer-80", + Description: "HA Route 1", Enabled: true, Metric: 1000, + PeerGroups: []string{allGroupID}, + Groups: []string{allGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + routeHA2ID: { + ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-90"].Key, + PeerID: "peer-90", + Description: "HA Route 2", Enabled: true, Metric: 900, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + } + + account := &types.Account{ + Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Network: &types.Network{ + Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, + }, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, + NameServerGroups: map[string]*dns.NameServerGroup{ + nameserverGroupID: { + ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, + NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{ + {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, + }}, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, + }, + Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, + NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, + }, + Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + return account +} diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go new file mode 100644 index 000000000..58f1bfa30 --- /dev/null +++ b/management/server/types/networkmapbuilder.go @@ -0,0 +1,1932 @@ +package types + +import ( + "context" + "fmt" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" +) + +const ( + allPeers = "0.0.0.0" + fw = "fw:" + rfw = "route-fw:" + nr = "network-resource-" +) + +type NetworkMapCache struct { + globalRoutes map[route.ID]*route.Route + globalRules map[string]*FirewallRule //ruleId + globalRouteRules map[string]*RouteFirewallRule //ruleId + globalPeers map[string]*nbpeer.Peer + + groupToPeers map[string][]string + peerToGroups map[string][]string + policyToRules map[string][]*PolicyRule //policyId + groupToPolicies map[string][]*Policy + groupToRoutes map[string][]*route.Route + peerToRoutes map[string][]*route.Route + + peerACLs map[string]*PeerACLView + peerRoutes map[string]*PeerRoutesView + peerDNS map[string]*nbdns.Config + + resourceRouters map[string]map[string]*routerTypes.NetworkRouter + resourcePolicies map[string][]*Policy + + globalResources map[string]*resourceTypes.NetworkResource // resourceId + + acgToRoutes map[string]map[route.ID]*RouteOwnerInfo // routeID -> owner info + noACGRoutes map[route.ID]*RouteOwnerInfo + + mu sync.RWMutex +} + +type RouteOwnerInfo struct { + PeerID string + RouteID route.ID +} + +type PeerACLView struct { + ConnectedPeerIDs []string + FirewallRuleIDs []string +} + +type PeerRoutesView struct { + OwnRouteIDs []route.ID + NetworkResourceIDs []route.ID + InheritedRouteIDs []route.ID + RouteFirewallRuleIDs []string +} + +type NetworkMapBuilder struct { + account atomic.Pointer[Account] + cache *NetworkMapCache + validatedPeers map[string]struct{} +} + +func NewNetworkMapBuilder(account *Account, validatedPeers map[string]struct{}) *NetworkMapBuilder { + builder := &NetworkMapBuilder{ + cache: &NetworkMapCache{ + globalRoutes: make(map[route.ID]*route.Route), + globalRules: make(map[string]*FirewallRule), + globalRouteRules: make(map[string]*RouteFirewallRule), + globalPeers: make(map[string]*nbpeer.Peer), + groupToPeers: make(map[string][]string), + peerToGroups: make(map[string][]string), + policyToRules: make(map[string][]*PolicyRule), + groupToPolicies: make(map[string][]*Policy), + groupToRoutes: make(map[string][]*route.Route), + peerToRoutes: make(map[string][]*route.Route), + peerACLs: make(map[string]*PeerACLView), + peerRoutes: make(map[string]*PeerRoutesView), + peerDNS: make(map[string]*nbdns.Config), + globalResources: make(map[string]*resourceTypes.NetworkResource), + acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo), + noACGRoutes: make(map[route.ID]*RouteOwnerInfo), + }, + validatedPeers: make(map[string]struct{}), + } + builder.account.Store(account) + maps.Copy(builder.validatedPeers, validatedPeers) + + builder.initialBuild(account) + + return builder +} + +func (b *NetworkMapBuilder) initialBuild(account *Account) { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + start := time.Now() + + b.buildGlobalIndexes(account) + + resourceRouters := account.GetResourceRoutersMap() + resourcePolicies := account.GetResourcePoliciesMap() + b.cache.resourceRouters = resourceRouters + b.cache.resourcePolicies = resourcePolicies + + for peerID := range account.Peers { + b.buildPeerACLView(account, peerID) + b.buildPeerRoutesView(account, peerID) + b.buildPeerDNSView(account, peerID) + } + + log.Debugf("NetworkMapBuilder: Initial build completed in %v for account %s", time.Since(start), account.Id) +} + +func (b *NetworkMapBuilder) buildGlobalIndexes(account *Account) { + clear(b.cache.globalPeers) + clear(b.cache.groupToPeers) + clear(b.cache.peerToGroups) + clear(b.cache.policyToRules) + clear(b.cache.groupToPolicies) + clear(b.cache.globalRoutes) + clear(b.cache.globalRules) + clear(b.cache.globalRouteRules) + clear(b.cache.globalResources) + clear(b.cache.groupToRoutes) + clear(b.cache.peerToRoutes) + clear(b.cache.acgToRoutes) + clear(b.cache.noACGRoutes) + + maps.Copy(b.cache.globalPeers, account.Peers) + + for groupID, group := range account.Groups { + peersCopy := make([]string, len(group.Peers)) + copy(peersCopy, group.Peers) + b.cache.groupToPeers[groupID] = peersCopy + + for _, peerID := range group.Peers { + b.cache.peerToGroups[peerID] = append(b.cache.peerToGroups[peerID], groupID) + } + } + + for _, policy := range account.Policies { + if !policy.Enabled { + continue + } + + b.cache.policyToRules[policy.ID] = policy.Rules + + affectedGroups := make(map[string]struct{}) + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + for _, groupID := range rule.Sources { + affectedGroups[groupID] = struct{}{} + } + for _, groupID := range rule.Destinations { + affectedGroups[groupID] = struct{}{} + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + groupId := rule.SourceResource.ID + affectedGroups[groupId] = struct{}{} + b.cache.peerToGroups[rule.SourceResource.ID] = append(b.cache.peerToGroups[rule.SourceResource.ID], groupId) + } + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + groupId := rule.DestinationResource.ID + affectedGroups[groupId] = struct{}{} + b.cache.peerToGroups[rule.DestinationResource.ID] = append(b.cache.peerToGroups[rule.DestinationResource.ID], groupId) + } + } + + for groupID := range affectedGroups { + b.cache.groupToPolicies[groupID] = append(b.cache.groupToPolicies[groupID], policy) + } + } + + for _, resource := range account.NetworkResources { + if !resource.Enabled { + continue + } + b.cache.globalResources[resource.ID] = resource + } + + for _, r := range account.Routes { + if !r.Enabled { + continue + } + for _, groupID := range r.PeerGroups { + b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) + } + if r.Peer != "" { + if peer, ok := b.cache.globalPeers[r.Peer]; ok { + b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) + } + } + } +} + +func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) { + peer := account.GetPeer(peerID) + if peer == nil { + return + } + + allPotentialPeers, firewallRules := b.getPeerConnectionResources(account, peer, b.validatedPeers) + + isRouter, networkResourcesRoutes, sourcePeers := b.getNetworkResourcesForPeer(account, peer) + + var emptyExpiredPeers []*nbpeer.Peer + finalAllPeers := b.addNetworksRoutingPeers( + networkResourcesRoutes, + peer, + allPotentialPeers, + emptyExpiredPeers, + isRouter, + sourcePeers, + ) + + view := &PeerACLView{ + ConnectedPeerIDs: make([]string, 0, len(finalAllPeers)), + FirewallRuleIDs: make([]string, 0, len(firewallRules)), + } + + for _, p := range finalAllPeers { + view.ConnectedPeerIDs = append(view.ConnectedPeerIDs, p.ID) + } + + for _, rule := range firewallRules { + ruleID := b.generateFirewallRuleID(rule) + view.FirewallRuleIDs = append(view.FirewallRuleIDs, ruleID) + b.cache.globalRules[ruleID] = rule + } + + b.cache.peerACLs[peerID] = view +} + +func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer, + validatedPeersMap map[string]struct{}, +) ([]*nbpeer.Peer, []*FirewallRule) { + ctx := context.Background() + + peerID := peer.ID + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + fwRules := make([]*FirewallRule, 0) + peers := make([]*nbpeer.Peer, 0) + + for _, group := range peerGroups { + policies := b.cache.groupToPolicies[group] + for _, policy := range policies { + if isValid := account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID); !isValid { + continue + } + rules := b.cache.policyToRules[policy.ID] + for _, rule := range rules { + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + peerInSources = rule.SourceResource.ID == peerID + } else { + peerInSources = b.isPeerInGroupscached(rule.Sources, peerGroupsMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + peerInDestinations = rule.DestinationResource.ID == peerID + } else { + peerInDestinations = b.isPeerInGroupscached(rule.Destinations, peerGroupsMap) + } + + if !peerInSources && !peerInDestinations { + continue + } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + peer := account.GetPeer(rule.SourceResource.ID) + if peer != nil { + sourcePeers = []*nbpeer.Peer{peer} + } + } else { + sourcePeers = b.getPeersFromGroupscached(account, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + peer := account.GetPeer(rule.DestinationResource.ID) + if peer != nil { + destinationPeers = []*nbpeer.Peer{peer} + } + } else { + destinationPeers = b.getPeersFromGroupscached(account, rule.Destinations, peerID, nil, validatedPeersMap) + } + + if rule.Bidirectional { + if peerInSources { + b.generateResourcescached( + account, rule, destinationPeers, FirewallRuleDirectionIN, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + if peerInDestinations { + b.generateResourcescached( + account, rule, sourcePeers, FirewallRuleDirectionOUT, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + } + + if peerInSources { + b.generateResourcescached( + account, rule, destinationPeers, FirewallRuleDirectionOUT, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + + if peerInDestinations { + b.generateResourcescached( + account, rule, sourcePeers, FirewallRuleDirectionIN, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + } + } + } + + return peers, fwRules +} + +func (b *NetworkMapBuilder) isPeerInGroupscached(groupIDs []string, peerGroupsMap map[string]struct{}) bool { + for _, groupID := range groupIDs { + if _, exists := peerGroupsMap[groupID]; exists { + return true + } + } + return false +} + +func (b *NetworkMapBuilder) getPeersFromGroupscached(account *Account, groupIDs []string, + excludePeerID string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, +) []*nbpeer.Peer { + ctx := context.Background() + uniquePeers := make(map[string]*nbpeer.Peer) + + for _, groupID := range groupIDs { + peerIDs := b.cache.groupToPeers[groupID] + for _, peerID := range peerIDs { + if peerID == excludePeerID { + continue + } + + if _, ok := validatedPeersMap[peerID]; !ok { + continue + } + + peer := b.cache.globalPeers[peerID] + if peer == nil { + continue + } + + if len(postureChecksIDs) > 0 { + if !account.validatePostureChecksOnPeer(ctx, postureChecksIDs, peerID) { + continue + } + } + + uniquePeers[peerID] = peer + } + } + + result := make([]*nbpeer.Peer, 0, len(uniquePeers)) + for _, peer := range uniquePeers { + result = append(result, peer) + } + + return result +} + +func (b *NetworkMapBuilder) generateResourcescached( + account *Account, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer, + peers *[]*nbpeer.Peer, rules *[]*FirewallRule, peersExists map[string]struct{}, rulesExists map[string]struct{}, +) { + isAll := false + if allGroup, err := account.GetGroupAll(); err == nil { + isAll = (len(allGroup.Peers) - 1) == len(groupPeers) + } + + for _, peer := range groupPeers { + if peer == nil { + continue + } + if _, ok := peersExists[peer.ID]; !ok { + *peers = append(*peers, peer) + peersExists[peer.ID] = struct{}{} + } + + fr := FirewallRule{ + PolicyID: rule.ID, + PeerIP: peer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + if isAll { + fr.PeerIP = allPeers + } + + var s strings.Builder + s.WriteString(rule.ID) + s.WriteString(fr.PeerIP) + s.WriteString(strconv.Itoa(direction)) + s.WriteString(fr.Protocol) + s.WriteString(fr.Action) + s.WriteString(strings.Join(rule.Ports, ",")) + + ruleID := s.String() + + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { + *rules = append(*rules, &fr) + continue + } + + *rules = append(*rules, expandPortsAndRanges(fr, rule, targetPeer)...) + } +} + +func (b *NetworkMapBuilder) getNetworkResourcesForPeer(account *Account, peer *nbpeer.Peer) (bool, []*route.Route, map[string]struct{}) { + ctx := context.Background() + peerID := peer.ID + + var isRoutingPeer bool + var routes []*route.Route + allSourcePeers := make(map[string]struct{}) + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + for _, resource := range b.cache.globalResources { + + networkRoutingPeers := b.cache.resourceRouters[resource.NetworkID] + resourcePolicies := b.cache.resourcePolicies[resource.ID] + if len(resourcePolicies) == 0 { + continue + } + + isRouterForThisResource := false + + if networkRoutingPeers != nil { + if router, ok := networkRoutingPeers[peerID]; ok && router.Enabled { + isRoutingPeer = true + isRouterForThisResource = true + if rt := b.createNetworkResourceRoutes(resource, peerID, router, resourcePolicies); rt != nil { + routes = append(routes, rt) + } + } + } + + hasAccessAsClient := false + if !isRouterForThisResource { + for _, policy := range resourcePolicies { + if b.isPeerInGroupscached(policy.SourceGroups(), peerGroupsMap) { + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + hasAccessAsClient = true + break + } + } + } + } + + if hasAccessAsClient && networkRoutingPeers != nil { + for routerPeerID, router := range networkRoutingPeers { + if router.Enabled { + if rt := b.createNetworkResourceRoutes(resource, routerPeerID, router, resourcePolicies); rt != nil { + routes = append(routes, rt) + } + } + } + } + + if isRouterForThisResource { + for _, policy := range resourcePolicies { + var peersWithAccess []*nbpeer.Peer + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peersWithAccess = []*nbpeer.Peer{peer} + } else { + peersWithAccess = b.getPeersFromGroupscached(account, policy.SourceGroups(), "", policy.SourcePostureChecks, b.validatedPeers) + } + for _, p := range peersWithAccess { + allSourcePeers[p.ID] = struct{}{} + } + } + } + } + + return isRoutingPeer, routes, allSourcePeers +} + +func (b *NetworkMapBuilder) createNetworkResourceRoutes( + resource *resourceTypes.NetworkResource, routerPeerID string, + router *routerTypes.NetworkRouter, resourcePolicies []*Policy, +) *route.Route { + if len(resourcePolicies) > 0 { + peer := b.cache.globalPeers[routerPeerID] + if peer != nil { + return resource.ToRoute(peer, router) + } + } + return nil +} + +func (b *NetworkMapBuilder) addNetworksRoutingPeers( + networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, + expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers map[string]struct{}, +) []*nbpeer.Peer { + + networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) + for _, r := range networkResourcesRoutes { + networkRoutesPeers[r.PeerID] = struct{}{} + } + + delete(sourcePeers, peer.ID) + delete(networkRoutesPeers, peer.ID) + + for _, existingPeer := range peersToConnect { + delete(sourcePeers, existingPeer.ID) + delete(networkRoutesPeers, existingPeer.ID) + } + for _, expPeer := range expiredPeers { + delete(sourcePeers, expPeer.ID) + delete(networkRoutesPeers, expPeer.ID) + } + + missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) + if isRouter { + for p := range sourcePeers { + missingPeers[p] = struct{}{} + } + } + for p := range networkRoutesPeers { + missingPeers[p] = struct{}{} + } + + for p := range missingPeers { + if missingPeer := b.cache.globalPeers[p]; missingPeer != nil { + peersToConnect = append(peersToConnect, missingPeer) + } + } + + return peersToConnect +} + +func (b *NetworkMapBuilder) buildPeerRoutesView(account *Account, peerID string) { + ctx := context.Background() + peer := account.GetPeer(peerID) + if peer == nil { + return + } + resourcePolicies := b.cache.resourcePolicies + + view := &PeerRoutesView{ + OwnRouteIDs: make([]route.ID, 0), + NetworkResourceIDs: make([]route.ID, 0), + RouteFirewallRuleIDs: make([]string, 0), + } + + enabledRoutes, disabledRoutes := b.getRoutingPeerRoutes(peerID) + for _, rt := range enabledRoutes { + if rt.PeerID != "" && rt.PeerID != peerID { + if b.cache.globalPeers[rt.PeerID] == nil { + continue + } + } + + view.OwnRouteIDs = append(view.OwnRouteIDs, rt.ID) + b.cache.globalRoutes[rt.ID] = rt + } + + aclView := b.cache.peerACLs[peerID] + if aclView != nil { + peerRoutesMembership := make(LookupMap) + for _, r := range append(enabledRoutes, disabledRoutes...) { + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(LookupMap) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + for _, aclPeerID := range aclView.ConnectedPeerIDs { + if aclPeerID == peerID { + continue + } + activeRoutes, _ := b.getRoutingPeerRoutes(aclPeerID) + groupFilteredRoutes := account.filterRoutesByGroups(activeRoutes, peerGroupsMap) + haFilteredRoutes := account.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + + for _, inheritedRoute := range haFilteredRoutes { + view.InheritedRouteIDs = append(view.InheritedRouteIDs, inheritedRoute.ID) + b.cache.globalRoutes[inheritedRoute.ID] = inheritedRoute + } + } + } + + _, networkResourcesRoutes, _ := b.getNetworkResourcesForPeer(account, peer) + + for _, rt := range networkResourcesRoutes { + view.NetworkResourceIDs = append(view.NetworkResourceIDs, rt.ID) + b.cache.globalRoutes[rt.ID] = rt + } + + allRoutes := slices.Concat(enabledRoutes, networkResourcesRoutes) + b.updateACGIndexForPeer(peerID, allRoutes) + + routeFirewallRules := b.getPeerRoutesFirewallRules(account, peerID, b.validatedPeers) + for _, rule := range routeFirewallRules { + ruleID := b.generateRouteFirewallRuleID(rule) + view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) + b.cache.globalRouteRules[ruleID] = rule + } + + if len(networkResourcesRoutes) > 0 { + networkResourceFirewallRules := account.GetPeerNetworkResourceFirewallRules(ctx, peer, b.validatedPeers, networkResourcesRoutes, resourcePolicies) + for _, rule := range networkResourceFirewallRules { + ruleID := b.generateRouteFirewallRuleID(rule) + view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) + b.cache.globalRouteRules[ruleID] = rule + } + } + + b.cache.peerRoutes[peerID] = view +} + +func (b *NetworkMapBuilder) updateACGIndexForPeer(peerID string, routes []*route.Route) { + for acg, routeMap := range b.cache.acgToRoutes { + for routeID, info := range routeMap { + if info.PeerID == peerID { + delete(routeMap, routeID) + } + } + if len(routeMap) == 0 { + delete(b.cache.acgToRoutes, acg) + } + } + + for routeID, info := range b.cache.noACGRoutes { + if info.PeerID == peerID { + delete(b.cache.noACGRoutes, routeID) + } + } + + for _, rt := range routes { + if !rt.Enabled { + continue + } + + if len(rt.AccessControlGroups) == 0 { + b.cache.noACGRoutes[rt.ID] = &RouteOwnerInfo{ + PeerID: peerID, + RouteID: rt.ID, + } + } else { + for _, acg := range rt.AccessControlGroups { + if b.cache.acgToRoutes[acg] == nil { + b.cache.acgToRoutes[acg] = make(map[route.ID]*RouteOwnerInfo) + } + + b.cache.acgToRoutes[acg][rt.ID] = &RouteOwnerInfo{ + PeerID: peerID, + RouteID: rt.ID, + } + } + } + } +} + +func (b *NetworkMapBuilder) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + peer := b.cache.globalPeers[peerID] + if peer == nil { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route, id string) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + if r.Enabled { + // maybe here is some mess - here we store peer key (see comment below) + r.Peer = peer.Key + enabledRoutes = append(enabledRoutes, r) + return + } + disabledRoutes = append(disabledRoutes, r) + } + + peerGroups := b.cache.peerToGroups[peerID] + for _, groupID := range peerGroups { + groupRoutes := b.cache.groupToRoutes[groupID] + for _, r := range groupRoutes { + newPeerRoute := r.Copy() + // and here we store peer ID - this logic is taken from original account.getRoutingPeerRoutes + newPeerRoute.Peer = peerID + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + peerID) + takeRoute(newPeerRoute, peerID) + } + } + for _, r := range b.cache.peerToRoutes[peerID] { + takeRoute(r.Copy(), peerID) + } + return enabledRoutes, disabledRoutes +} + +func (b *NetworkMapBuilder) getPeerRoutesFirewallRules(account *Account, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + enabledRoutes, _ := b.getRoutingPeerRoutes(peerID) + for _, route := range enabledRoutes { + if len(route.AccessControlGroups) == 0 { + defaultPermit := getDefaultPermit(route) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + distributionPeers := b.getDistributionGroupsPeers(route) + + for _, accessGroup := range route.AccessControlGroups { + policies := b.getAllRoutePoliciesFromGroups([]string{accessGroup}) + + rules := b.getRouteFirewallRules(peerID, policies, route, validatedPeersMap, distributionPeers, account) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (b *NetworkMapBuilder) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + groupPeers := b.cache.groupToPeers[id] + if groupPeers == nil { + continue + } + + for _, pID := range groupPeers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func (b *NetworkMapBuilder) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy { + routePolicies := make(map[string]*Policy) + + for _, groupID := range accessControlGroups { + candidatePolicies := b.cache.groupToPolicies[groupID] + + for _, policy := range candidatePolicies { + if _, found := routePolicies[policy.ID]; found { + continue + } + policyRules := b.cache.policyToRules[policy.ID] + for _, rule := range policyRules { + if slices.Contains(rule.Destinations, groupID) { + routePolicies[policy.ID] = policy + break + } + } + } + } + + return maps.Values(routePolicies) +} + +func (b *NetworkMapBuilder) getRouteFirewallRules( + peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, + distributionPeers map[string]struct{}, account *Account, +) []*RouteFirewallRule { + ctx := context.Background() + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := b.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap, account) + + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (b *NetworkMapBuilder) getRulePeers( + rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, + validatedPeersMap map[string]struct{}, account *Account, +) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + + for _, id := range rule.Sources { + groupPeers := b.cache.groupToPeers[id] + if groupPeers == nil { + continue + } + + for _, pID := range groupPeers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + + if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := b.cache.globalPeers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (b *NetworkMapBuilder) buildPeerDNSView(account *Account, peerID string) { + peerGroups := b.cache.peerToGroups[peerID] + checkGroups := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + checkGroups[groupID] = struct{}{} + } + + dnsManagementStatus := b.getPeerDNSManagementStatus(account, checkGroups) + dnsConfig := &nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + dnsConfig.NameServerGroups = b.getPeerNSGroups(account, peerID, checkGroups) + } + + b.cache.peerDNS[peerID] = dnsConfig +} + +func (b *NetworkMapBuilder) getPeerDNSManagementStatus(account *Account, checkGroups map[string]struct{}) bool { + + enabled := true + for _, groupID := range account.DNSSettings.DisabledManagementGroups { + _, found := checkGroups[groupID] + if found { + enabled = false + break + } + } + return enabled +} + +func (b *NetworkMapBuilder) getPeerNSGroups(account *Account, peerID string, checkGroups map[string]struct{}) []*nbdns.NameServerGroup { + var peerNSGroups []*nbdns.NameServerGroup + + for _, nsGroup := range account.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + _, found := checkGroups[gID] + if found { + peer := b.cache.globalPeers[peerID] + if !peerIsNameserver(peer, nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + break + } + } + } + } + + return peerNSGroups +} + +func (b *NetworkMapBuilder) UpdateAccountPointer(account *Account) { + b.account.Store(account) +} + +func (b *NetworkMapBuilder) GetPeerNetworkMap( + ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, + validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + start := time.Now() + account := b.account.Load() + + peer := account.GetPeer(peerID) + if peer == nil { + return &NetworkMap{Network: account.Network.Copy()} + } + + b.cache.mu.RLock() + defer b.cache.mu.RUnlock() + + aclView := b.cache.peerACLs[peerID] + routesView := b.cache.peerRoutes[peerID] + dnsConfig := b.cache.peerDNS[peerID] + + if aclView == nil || routesView == nil || dnsConfig == nil { + return &NetworkMap{Network: account.Network.Copy()} + } + + nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers) + + if metrics != nil { + objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache", + account.Id, objectCount) + } + } + + return nm +} + +func (b *NetworkMapBuilder) assembleNetworkMap( + account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, + dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{}, +) *NetworkMap { + + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + + for _, peerID := range aclView.ConnectedPeerIDs { + if _, ok := validatedPeers[peerID]; !ok { + continue + } + + peer := b.cache.globalPeers[peerID] + if peer == nil { + continue + } + + expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration) + if account.Settings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, peer) + } else { + peersToConnect = append(peersToConnect, peer) + } + } + + var routes []*route.Route + allRouteIDs := slices.Concat(routesView.OwnRouteIDs, routesView.NetworkResourceIDs, routesView.InheritedRouteIDs) + + for _, routeID := range allRouteIDs { + if route := b.cache.globalRoutes[routeID]; route != nil { + routes = append(routes, route) + } + } + + var firewallRules []*FirewallRule + for _, ruleID := range aclView.FirewallRuleIDs { + if rule := b.cache.globalRules[ruleID]; rule != nil { + firewallRules = append(firewallRules, rule) + } + } + + var routesFirewallRules []*RouteFirewallRule + for _, ruleID := range routesView.RouteFirewallRuleIDs { + if rule := b.cache.globalRouteRules[ruleID]; rule != nil { + routesFirewallRules = append(routesFirewallRules, rule) + } + } + + finalDNSConfig := *dnsConfig + if finalDNSConfig.ServiceEnable && customZone.Domain != "" { + var zones []nbdns.CustomZone + records := filterZoneRecordsForPeers(peer, customZone, peersToConnect, expiredPeers) + zones = append(zones, nbdns.CustomZone{ + Domain: customZone.Domain, + Records: records, + }) + finalDNSConfig.CustomZones = zones + } + + return &NetworkMap{ + Peers: peersToConnect, + Network: account.Network.Copy(), + Routes: routes, + DNSConfig: finalDNSConfig, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: routesFirewallRules, + } +} + +func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string { + var s strings.Builder + s.WriteString(fw) + s.WriteString(rule.PolicyID) + s.WriteRune(':') + s.WriteString(rule.PeerIP) + s.WriteRune(':') + s.WriteString(strconv.Itoa(rule.Direction)) + s.WriteRune(':') + s.WriteString(rule.Protocol) + s.WriteRune(':') + s.WriteString(rule.Action) + s.WriteRune(':') + s.WriteString(rule.Port) + s.WriteRune(':') + s.WriteString(strconv.Itoa(int(rule.PortRange.Start))) + s.WriteRune('-') + s.WriteString(strconv.Itoa(int(rule.PortRange.End))) + return s.String() +} + +func (b *NetworkMapBuilder) generateRouteFirewallRuleID(rule *RouteFirewallRule) string { + var s strings.Builder + s.WriteString(rfw) + s.WriteString(string(rule.RouteID)) + s.WriteRune(':') + s.WriteString(rule.Destination) + s.WriteRune(':') + s.WriteString(rule.Action) + s.WriteRune(':') + s.WriteString(strings.Join(rule.SourceRanges, ",")) + s.WriteRune(':') + s.WriteString(rule.Protocol) + s.WriteRune(':') + s.WriteString(strconv.Itoa(int(rule.Port))) + return s.String() +} + +func (b *NetworkMapBuilder) isPeerInGroups(groupIDs []string, peerGroups []string) bool { + for _, groupID := range groupIDs { + if slices.Contains(peerGroups, groupID) { + return true + } + } + return false +} + +func (b *NetworkMapBuilder) isPeerRouter(account *Account, peerID string) bool { + for _, r := range account.Routes { + if !r.Enabled { + continue + } + + if r.PeerID == peerID { + return true + } + + if peer := b.cache.globalPeers[peerID]; peer != nil { + if r.Peer == peer.Key && r.PeerID == "" { + return true + } + } + } + + routers := account.GetResourceRoutersMap() + for _, networkRouters := range routers { + if router, exists := networkRouters[peerID]; exists && router.Enabled { + return true + } + } + + return false +} + +type ViewDelta struct { + AddedPeerIDs []string + RemovedPeerIDs []string + AddedRuleIDs []string + RemovedRuleIDs []string +} + +func (b *NetworkMapBuilder) OnPeerAddedIncremental(peerID string) error { + tt := time.Now() + account := b.account.Load() + peer := account.GetPeer(peerID) + if peer == nil { + return fmt.Errorf("peer %s not found in account", peerID) + } + + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + log.Debugf("NetworkMapBuilder: Adding peer %s (IP: %s) to cache", peerID, peer.IP.String()) + + b.validatedPeers[peerID] = struct{}{} + + b.cache.globalPeers[peerID] = peer + + peerGroups := b.updateIndexesForNewPeer(account, peerID) + + b.buildPeerACLView(account, peerID) + b.buildPeerRoutesView(account, peerID) + b.buildPeerDNSView(account, peerID) + + log.Debugf("NetworkMapBuilder: Adding peer %s to cache, views took %s", peerID, time.Since(tt)) + + b.incrementalUpdateAffectedPeers(account, peerID, peerGroups) + + log.Debugf("NetworkMapBuilder: Added peer %s to cache, took %s", peerID, time.Since(tt)) + + return nil +} + +func (b *NetworkMapBuilder) updateIndexesForNewPeer(account *Account, peerID string) []string { + peerGroups := make([]string, 0) + + for groupID, group := range account.Groups { + if slices.Contains(group.Peers, peerID) { + if !slices.Contains(b.cache.groupToPeers[groupID], peerID) { + b.cache.groupToPeers[groupID] = append(b.cache.groupToPeers[groupID], peerID) + } + peerGroups = append(peerGroups, groupID) + } + } + + b.cache.peerToGroups[peerID] = peerGroups + + for _, r := range account.Routes { + if !r.Enabled || b.cache.globalRoutes[r.ID] != nil { + continue + } + for _, groupID := range r.PeerGroups { + if !slices.Contains(b.cache.groupToRoutes[groupID], r) { + b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) + } + } + if r.Peer != "" { + if peer, ok := b.cache.globalPeers[r.Peer]; ok { + if !slices.Contains(b.cache.peerToRoutes[peer.ID], r) { + b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) + } + } + } + b.cache.globalRoutes[r.ID] = r + } + + return peerGroups +} + +func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, newPeerID string, peerGroups []string) { + updates := b.calculateIncrementalUpdates(account, newPeerID, peerGroups) + + if b.isPeerRouter(account, newPeerID) { + affectedByRoutes := b.findPeersAffectedByNewRouter(account, newPeerID, peerGroups) + for affectedPeerID := range affectedByRoutes { + if affectedPeerID == newPeerID { + continue + } + if _, exists := updates[affectedPeerID]; !exists { + updates[affectedPeerID] = &PeerUpdateDelta{ + PeerID: affectedPeerID, + RebuildRoutesView: true, + } + } else { + updates[affectedPeerID].RebuildRoutesView = true + } + } + } + + for affectedPeerID, delta := range updates { + b.applyDeltaToPeer(account, affectedPeerID, delta) + } +} + +func (b *NetworkMapBuilder) findPeersAffectedByNewRouter(account *Account, newRouterID string, routerGroups []string) map[string]struct{} { + affected := make(map[string]struct{}) + enabledRoutes, _ := b.getRoutingPeerRoutes(newRouterID) + + for _, route := range enabledRoutes { + for _, distGroupID := range route.Groups { + if peers := b.cache.groupToPeers[distGroupID]; peers != nil { + for _, peerID := range peers { + if peerID != newRouterID { + affected[peerID] = struct{}{} + } + } + } + } + + for _, peerGroupID := range route.PeerGroups { + if peers := b.cache.groupToPeers[peerGroupID]; peers != nil { + for _, peerID := range peers { + if peerID != newRouterID { + affected[peerID] = struct{}{} + } + } + } + } + } + + for _, route := range account.Routes { + if !route.Enabled { + continue + } + + routerInPeerGroups := false + for _, peerGroupID := range route.PeerGroups { + if slices.Contains(routerGroups, peerGroupID) { + routerInPeerGroups = true + break + } + } + + if routerInPeerGroups { + for _, distGroupID := range route.Groups { + if peers := b.cache.groupToPeers[distGroupID]; peers != nil { + for _, peerID := range peers { + affected[peerID] = struct{}{} + } + } + } + } + } + + return affected +} + +func (b *NetworkMapBuilder) calculateIncrementalUpdates(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta { + updates := make(map[string]*PeerUpdateDelta) + ctx := context.Background() + + groupAllLn := 0 + if allGroup, err := account.GetGroupAll(); err == nil { + groupAllLn = len(allGroup.Peers) - 1 + } + + newPeer := b.cache.globalPeers[newPeerID] + if newPeer == nil { + return updates + } + + for _, policy := range account.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + peerInSources := b.isPeerInGroups(rule.Sources, peerGroups) + peerInDestinations := b.isPeerInGroups(rule.Destinations, peerGroups) + + if peerInSources { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + + if peerInDestinations { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + + if rule.Bidirectional { + if peerInSources { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + if peerInDestinations { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + } + } + } + + b.calculateRouteFirewallUpdates(newPeerID, newPeer, peerGroups, updates) + + b.calculateNetworkResourceFirewallUpdates(ctx, account, newPeerID, newPeer, peerGroups, updates) + + b.calculateNewRouterNetworkResourceUpdates(ctx, account, newPeerID, updates) + + return updates +} + +func (b *NetworkMapBuilder) calculateNewRouterNetworkResourceUpdates( + ctx context.Context, account *Account, newPeerID string, + updates map[string]*PeerUpdateDelta, +) { + resourceRouters := b.cache.resourceRouters + + for networkID, routers := range resourceRouters { + router, isRouter := routers[newPeerID] + if !isRouter || !router.Enabled { + continue + } + + for _, resource := range b.cache.globalResources { + if resource.NetworkID != networkID { + continue + } + + policies := b.cache.resourcePolicies[resource.ID] + if len(policies) == 0 { + continue + } + + peersWithAccess := make(map[string]struct{}) + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + sourceGroups := policy.SourceGroups() + for _, sourceGroup := range sourceGroups { + groupPeers := b.cache.groupToPeers[sourceGroup] + for _, peerID := range groupPeers { + if peerID == newPeerID { + continue + } + + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + peersWithAccess[peerID] = struct{}{} + } + } + } + } + + for peerID := range peersWithAccess { + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + } + updates[peerID] = delta + } + + if delta.AddConnectedPeer == "" { + delta.AddConnectedPeer = newPeerID + } + + delta.RebuildRoutesView = true + } + } + } +} + +func (b *NetworkMapBuilder) calculateRouteFirewallUpdates( + newPeerID string, newPeer *nbpeer.Peer, + peerGroups []string, updates map[string]*PeerUpdateDelta, +) { + processedPeerRoutes := make(map[string]map[route.ID]struct{}) + + for routeID, info := range b.cache.noACGRoutes { + if info.PeerID == newPeerID { + continue + } + + b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) + + if processedPeerRoutes[info.PeerID] == nil { + processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) + } + processedPeerRoutes[info.PeerID][routeID] = struct{}{} + } + + for _, acg := range peerGroups { + routeInfos := b.cache.acgToRoutes[acg] + if routeInfos == nil { + continue + } + + for routeID, info := range routeInfos { + if info.PeerID == newPeerID { + continue + } + + if processedRoutes, exists := processedPeerRoutes[info.PeerID]; exists { + if _, processed := processedRoutes[routeID]; processed { + continue + } + } + + b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) + + if processedPeerRoutes[info.PeerID] == nil { + processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) + } + processedPeerRoutes[info.PeerID][routeID] = struct{}{} + } + } +} + +func (b *NetworkMapBuilder) addRouteFirewallUpdate( + updates map[string]*PeerUpdateDelta, peerID string, + routeID string, sourceIP string, +) { + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + UpdateRouteFirewallRules: make([]*RouteFirewallRuleUpdate, 0), + } + updates[peerID] = delta + } + + for _, existing := range delta.UpdateRouteFirewallRules { + if existing.RuleID == routeID && existing.AddSourceIP == sourceIP { + return + } + } + + delta.UpdateRouteFirewallRules = append(delta.UpdateRouteFirewallRules, &RouteFirewallRuleUpdate{ + RuleID: routeID, + AddSourceIP: sourceIP, + }) +} + +func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates( + ctx context.Context, account *Account, newPeerID string, + newPeer *nbpeer.Peer, peerGroups []string, updates map[string]*PeerUpdateDelta, +) { + for _, resource := range b.cache.globalResources { + resourcePolicies := b.cache.resourcePolicies + resourceRouters := b.cache.resourceRouters + + policies := resourcePolicies[resource.ID] + peerHasAccess := false + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + sourceGroups := policy.SourceGroups() + for _, sourceGroup := range sourceGroups { + if slices.Contains(peerGroups, sourceGroup) { + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, newPeerID) { + peerHasAccess = true + break + } + } + } + + if peerHasAccess { + break + } + } + + if !peerHasAccess { + continue + } + + networkRouters := resourceRouters[resource.NetworkID] + for routerPeerID, router := range networkRouters { + if !router.Enabled || routerPeerID == newPeerID { + continue + } + + delta := updates[routerPeerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: routerPeerID, + } + updates[routerPeerID] = delta + } + + if delta.AddConnectedPeer == "" { + delta.AddConnectedPeer = newPeerID + } + + delta.RebuildRoutesView = true + } + } +} + +type PeerUpdateDelta struct { + PeerID string + AddConnectedPeer string + AddFirewallRules []*FirewallRuleDelta + AddRoutes []route.ID + UpdateRouteFirewallRules []*RouteFirewallRuleUpdate + UpdateDNS bool + RebuildRoutesView bool +} +type FirewallRuleDelta struct { + Rule *FirewallRule + RuleID string + Direction int +} + +type RouteFirewallRuleUpdate struct { + RuleID string + AddSourceIP string +} + +func (b *NetworkMapBuilder) addUpdateForPeersInGroups( + updates map[string]*PeerUpdateDelta, groupIDs []string, newPeerID string, + rule *PolicyRule, direction int, allGroupLn int, +) { + for _, groupID := range groupIDs { + peers := b.cache.groupToPeers[groupID] + cnt := 0 + for _, peerID := range peers { + if peerID == newPeerID { + continue + } + if _, ok := b.validatedPeers[peerID]; !ok { + continue + } + cnt++ + } + all := false + if allGroupLn > 0 && cnt == allGroupLn { + all = true + } + newPeer := b.cache.globalPeers[newPeerID] + fr := &FirewallRule{ + PolicyID: rule.ID, + PeerIP: newPeer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + for _, peerID := range peers { + if peerID == newPeerID { + continue + } + if _, ok := b.validatedPeers[peerID]; !ok { + continue + } + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + AddConnectedPeer: newPeerID, + AddFirewallRules: make([]*FirewallRuleDelta, 0), + } + updates[peerID] = delta + } + + if all { + fr.PeerIP = allPeers + } + + if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 { + expandedRules := expandPortsAndRanges(*fr, rule, b.cache.globalPeers[peerID]) + for _, expandedRule := range expandedRules { + ruleID := b.generateFirewallRuleID(expandedRule) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: expandedRule, + RuleID: ruleID, + Direction: direction, + }) + } + } else { + ruleID := b.generateFirewallRuleID(fr) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: fr, + RuleID: ruleID, + Direction: direction, + }) + } + } + } +} + +func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) { + if delta.AddConnectedPeer != "" || len(delta.AddFirewallRules) > 0 { + if aclView := b.cache.peerACLs[peerID]; aclView != nil { + if delta.AddConnectedPeer != "" && !slices.Contains(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) { + aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) + } + + for _, ruleDelta := range delta.AddFirewallRules { + b.cache.globalRules[ruleDelta.RuleID] = ruleDelta.Rule + + if !slices.Contains(aclView.FirewallRuleIDs, ruleDelta.RuleID) { + aclView.FirewallRuleIDs = append(aclView.FirewallRuleIDs, ruleDelta.RuleID) + } + } + } + } + + if delta.RebuildRoutesView { + b.buildPeerRoutesView(account, peerID) + } else if len(delta.UpdateRouteFirewallRules) > 0 { + if routesView := b.cache.peerRoutes[peerID]; routesView != nil { + b.updateRouteFirewallRules(routesView, delta.UpdateRouteFirewallRules) + } + } + + if delta.UpdateDNS { + b.buildPeerDNSView(account, peerID) + } +} + +func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView, updates []*RouteFirewallRuleUpdate) { + for _, update := range updates { + for _, ruleID := range routesView.RouteFirewallRuleIDs { + rule := b.cache.globalRouteRules[ruleID] + if rule == nil { + continue + } + + if string(rule.RouteID) == update.RuleID { + sourceIP := update.AddSourceIP + + if strings.Contains(sourceIP, ":") { + sourceIP += "/128" // IPv6 + } else { + sourceIP += "/32" // IPv4 + } + + if !slices.Contains(rule.SourceRanges, sourceIP) { + rule.SourceRanges = append(rule.SourceRanges, sourceIP) + } + break + } + } + } +} + +func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + account := b.account.Load() + + deletedPeer := b.cache.globalPeers[peerID] + if deletedPeer == nil { + return fmt.Errorf("peer %s not found in cache", peerID) + } + + deletedPeerKey := deletedPeer.Key + peerGroups := b.cache.peerToGroups[peerID] + peerIP := deletedPeer.IP.String() + + log.Debugf("NetworkMapBuilder: Deleting peer %s (IP: %s) from cache", peerID, peerIP) + + delete(b.validatedPeers, peerID) + + routesToDelete := []route.ID{} + + for routeID, r := range account.Routes { + if r.Peer != deletedPeerKey && r.PeerID != peerID { + continue + } + if len(r.PeerGroups) == 0 { + routesToDelete = append(routesToDelete, routeID) + continue + } + newPeerAssigned := false + for _, groupID := range r.PeerGroups { + candidatePeerIDs := b.cache.groupToPeers[groupID] + for _, candidatePeerID := range candidatePeerIDs { + if candidatePeerID == peerID { + continue + } + if candidatePeer := b.cache.globalPeers[candidatePeerID]; candidatePeer != nil { + r.Peer = candidatePeer.Key + r.PeerID = candidatePeerID + newPeerAssigned = true + break + } + } + if newPeerAssigned { + break + } + } + + if !newPeerAssigned { + routesToDelete = append(routesToDelete, routeID) + } + } + + for _, routeID := range routesToDelete { + delete(account.Routes, routeID) + } + + delete(b.cache.peerACLs, peerID) + delete(b.cache.peerRoutes, peerID) + delete(b.cache.peerDNS, peerID) + + delete(b.cache.globalPeers, peerID) + + for acg, routeMap := range b.cache.acgToRoutes { + for routeID, info := range routeMap { + if info.PeerID == peerID { + delete(routeMap, routeID) + } + } + if len(routeMap) == 0 { + delete(b.cache.acgToRoutes, acg) + } + } + + for _, groupID := range peerGroups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + b.cache.groupToPeers[groupID] = slices.DeleteFunc(peers, func(id string) bool { + return id == peerID + }) + } + } + delete(b.cache.peerToGroups, peerID) + + affectedPeers := make(map[string]struct{}) + + for _, r := range account.Routes { + for _, groupID := range r.Groups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + for _, p := range peers { + affectedPeers[p] = struct{}{} + } + } + } + + for _, groupID := range r.PeerGroups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + for _, p := range peers { + affectedPeers[p] = struct{}{} + } + } + } + } + + for affectedPeerID := range affectedPeers { + if affectedPeerID == peerID { + continue + } + b.buildPeerRoutesView(account, affectedPeerID) + } + + peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP) + for affectedPeerID, updates := range peerDeletionUpdates { + b.applyDeletionUpdates(affectedPeerID, updates) + } + + b.cleanupUnusedRules() + + log.Debugf("NetworkMapBuilder: Deleted peer %s, affected %d other peers", peerID, len(affectedPeers)) + + return nil +} + +func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL( + deletedPeerID string, + peerIP string, +) map[string]*PeerDeletionUpdate { + + affected := make(map[string]*PeerDeletionUpdate) + + for peerID, aclView := range b.cache.peerACLs { + if peerID == deletedPeerID { + continue + } + + if !slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) { + continue + } + if affected[peerID] == nil { + affected[peerID] = &PeerDeletionUpdate{ + RemovePeerID: deletedPeerID, + PeerIP: peerIP, + } + } + + for _, ruleID := range aclView.FirewallRuleIDs { + if rule := b.cache.globalRules[ruleID]; rule != nil && rule.PeerIP == peerIP { + affected[peerID].RemoveFirewallRuleIDs = append( + affected[peerID].RemoveFirewallRuleIDs, + ruleID, + ) + } + } + } + + return affected +} + +type PeerDeletionUpdate struct { + RemovePeerID string + RemoveFirewallRuleIDs []string + RemoveRouteIDs []route.ID + RemoveFromSourceRanges bool + PeerIP string +} + +func (b *NetworkMapBuilder) applyDeletionUpdates(peerID string, updates *PeerDeletionUpdate) { + if aclView := b.cache.peerACLs[peerID]; aclView != nil { + aclView.ConnectedPeerIDs = slices.DeleteFunc(aclView.ConnectedPeerIDs, func(id string) bool { + return id == updates.RemovePeerID + }) + + if len(updates.RemoveFirewallRuleIDs) > 0 { + aclView.FirewallRuleIDs = slices.DeleteFunc(aclView.FirewallRuleIDs, func(ruleID string) bool { + return slices.Contains(updates.RemoveFirewallRuleIDs, ruleID) + }) + } + } + + if routesView := b.cache.peerRoutes[peerID]; routesView != nil { + if len(updates.RemoveRouteIDs) > 0 { + routesView.NetworkResourceIDs = slices.DeleteFunc(routesView.NetworkResourceIDs, func(routeID route.ID) bool { + return slices.Contains(updates.RemoveRouteIDs, routeID) + }) + } + + if updates.RemoveFromSourceRanges { + b.removeIPFromRouteFirewallRules(routesView, updates.PeerIP) + } + } +} + +func (b *NetworkMapBuilder) removeIPFromRouteFirewallRules(routesView *PeerRoutesView, peerIP string) { + sourceIPv4 := peerIP + "/32" + sourceIPv6 := peerIP + "/128" + + rulesToRemove := []string{} + + for _, ruleID := range routesView.RouteFirewallRuleIDs { + if rule := b.cache.globalRouteRules[ruleID]; rule != nil { + rule.SourceRanges = slices.DeleteFunc(rule.SourceRanges, func(source string) bool { + return source == sourceIPv4 || source == sourceIPv6 || source == peerIP + }) + + if len(rule.SourceRanges) == 0 { + rulesToRemove = append(rulesToRemove, ruleID) + } + } + } + + if len(rulesToRemove) > 0 { + routesView.RouteFirewallRuleIDs = slices.DeleteFunc(routesView.RouteFirewallRuleIDs, func(ruleID string) bool { + return slices.Contains(rulesToRemove, ruleID) + }) + } +} + +func (b *NetworkMapBuilder) cleanupUnusedRules() { + usedFirewallRules := make(map[string]struct{}) + usedRouteRules := make(map[string]struct{}) + usedRoutes := make(map[route.ID]struct{}) + + for _, aclView := range b.cache.peerACLs { + for _, ruleID := range aclView.FirewallRuleIDs { + usedFirewallRules[ruleID] = struct{}{} + } + } + + for _, routesView := range b.cache.peerRoutes { + for _, ruleID := range routesView.RouteFirewallRuleIDs { + usedRouteRules[ruleID] = struct{}{} + } + + for _, routeID := range routesView.OwnRouteIDs { + usedRoutes[routeID] = struct{}{} + } + for _, routeID := range routesView.NetworkResourceIDs { + usedRoutes[routeID] = struct{}{} + } + } + + for ruleID := range b.cache.globalRules { + if _, used := usedFirewallRules[ruleID]; !used { + delete(b.cache.globalRules, ruleID) + } + } + + for ruleID := range b.cache.globalRouteRules { + if _, used := usedRouteRules[ruleID]; !used { + delete(b.cache.globalRouteRules, ruleID) + } + } + + for routeID := range b.cache.globalRoutes { + if _, used := usedRoutes[routeID]; !used { + delete(b.cache.globalRoutes, routeID) + } + } +} + +func (b *NetworkMapBuilder) UpdatePeer(peer *nbpeer.Peer) { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + peerStored, ok := b.cache.globalPeers[peer.ID] + if !ok { + return + } + *peerStored = *peer +} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index da12f1b70..adf64592a 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -7,16 +7,14 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" ) const channelBufferSize = 100 type UpdateMessage struct { - Update *proto.SyncResponse - NetworkMap *types.NetworkMap + Update *proto.SyncResponse } type PeersUpdateManager struct { diff --git a/management/server/user.go b/management/server/user.go index 25c87df9c..66bea314f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -991,6 +991,10 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou peer.UserID, peer.ID, accountID, activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) + + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } } if len(peerIDs) != 0 { diff --git a/route/route.go b/route/route.go index 08a2d37dc..c724e7c7d 100644 --- a/route/route.go +++ b/route/route.go @@ -124,6 +124,7 @@ func (r *Route) EventMeta() map[string]any { func (r *Route) Copy() *Route { route := &Route{ ID: r.ID, + AccountID: r.AccountID, Description: r.Description, NetID: r.NetID, Network: r.Network, From 48475ddc058f71fccba80d622178d3f5abab79f0 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 7 Nov 2025 15:50:18 +0100 Subject: [PATCH 063/120] [management] add pat rate limiting (#4741) --- go.mod | 2 +- management/server/http/handler.go | 42 ++- .../server/http/middleware/auth_middleware.go | 20 +- .../http/middleware/auth_middleware_test.go | 285 +++++++++++++++++- .../server/http/middleware/rate_limiter.go | 146 +++++++++ shared/management/http/util/util.go | 2 + shared/management/status/error.go | 3 + 7 files changed, 496 insertions(+), 4 deletions(-) create mode 100644 management/server/http/middleware/rate_limiter.go diff --git a/go.mod b/go.mod index 68a12908d..7b9bae321 100644 --- a/go.mod +++ b/go.mod @@ -108,6 +108,7 @@ require ( golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.16.0 golang.org/x/term v0.33.0 + golang.org/x/time v0.12.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -245,7 +246,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/text v0.27.0 // indirect - golang.org/x/time v0.12.0 // indirect golang.org/x/tools v0.34.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 3d4de31d0..4d2c224b4 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -4,9 +4,13 @@ import ( "context" "fmt" "net/http" + "os" + "strconv" + "time" "github.com/gorilla/mux" "github.com/rs/cors" + log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" @@ -38,7 +42,12 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" ) -const apiPrefix = "/api" +const ( + apiPrefix = "/api" + rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED" + rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST" + rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM" +) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. func NewAPIHandler( @@ -58,11 +67,42 @@ func NewAPIHandler( settingsManager settings.Manager, ) (http.Handler, error) { + var rateLimitingConfig *middleware.RateLimiterConfig + if os.Getenv(rateLimitingEnabledKey) == "true" { + rpm := 6 + if v := os.Getenv(rateLimitingRPMKey); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm) + } else { + rpm = value + } + } + + burst := 500 + if v := os.Getenv(rateLimitingBurstKey); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst) + } else { + burst = value + } + } + + rateLimitingConfig = &middleware.RateLimiterConfig{ + RequestsPerMinute: float64(rpm), + Burst: burst, + CleanupInterval: 6 * time.Hour, + LimiterTTL: 24 * time.Hour, + } + } + authMiddleware := middleware.NewAuthMiddleware( authManager, accountManager.GetAccountIDFromUserAuth, accountManager.SyncUserJWTGroups, accountManager.GetUserFromUserAuth, + rateLimitingConfig, ) corsMiddleware := cors.AllowAll() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 6091a4c31..bce917a25 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -29,6 +29,7 @@ type AuthMiddleware struct { ensureAccount EnsureAccountFunc getUserFromUserAuth GetUserFromUserAuthFunc syncUserJWTGroups SyncUserJWTGroupsFunc + rateLimiter *APIRateLimiter } // NewAuthMiddleware instance constructor @@ -37,12 +38,19 @@ func NewAuthMiddleware( ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, + rateLimiterConfig *RateLimiterConfig, ) *AuthMiddleware { + var rateLimiter *APIRateLimiter + if rateLimiterConfig != nil { + rateLimiter = NewAPIRateLimiter(rateLimiterConfig) + } + return &AuthMiddleware{ authManager: authManager, ensureAccount: ensureAccount, syncUserJWTGroups: syncUserJWTGroups, getUserFromUserAuth: getUserFromUserAuth, + rateLimiter: rateLimiter, } } @@ -76,7 +84,11 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { request, err := m.checkPATFromRequest(r, auth) if err != nil { log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) - util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) + // Check if it's a status error, otherwise default to Unauthorized + if _, ok := status.FromError(err); !ok { + err = status.Errorf(status.Unauthorized, "token invalid") + } + util.WriteError(r.Context(), err, w) return } h.ServeHTTP(w, request) @@ -145,6 +157,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h return r, fmt.Errorf("error extracting token: %w", err) } + if m.rateLimiter != nil { + if !m.rateLimiter.Allow(token) { + return r, status.Errorf(status.TooManyRequests, "too many requests") + } + } + ctx := r.Context() user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) if err != nil { diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index d815f5422..d1bd9959f 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -27,7 +27,9 @@ const ( domainCategory = "domainCategory" userID = "userID" tokenID = "tokenID" + tokenID2 = "tokenID2" PAT = "nbp_PAT" + PAT2 = "nbp_PAT2" JWT = "JWT" wrongToken = "wrongToken" ) @@ -49,6 +51,15 @@ var testAccount = &types.Account{ CreatedAt: time.Now().UTC(), LastUsed: util.ToPtr(time.Now().UTC()), }, + tokenID2: { + ID: tokenID2, + Name: "My second token", + HashedToken: "someHash2", + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), + CreatedBy: userID, + CreatedAt: time.Now().UTC(), + LastUsed: util.ToPtr(time.Now().UTC()), + }, }, }, }, @@ -58,6 +69,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use if token == PAT { return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil } + if token == PAT2 { + return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil + } return nil, nil, "", "", fmt.Errorf("PAT invalid") } @@ -81,7 +95,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA } func mockMarkPATUsed(_ context.Context, token string) error { - if token == tokenID { + if token == tokenID || token == tokenID2 { return nil } return fmt.Errorf("Should never get reached") @@ -192,6 +206,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { return &types.User{}, nil }, + nil, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -221,6 +236,273 @@ func TestAuthMiddleware_Handler(t *testing.T) { } } +func TestAuthMiddleware_RateLimiting(t *testing.T) { + mockAuth := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups, + MarkPATUsedFunc: mockMarkPATUsed, + GetPATInfoFunc: mockGetAccountInfoFromPAT, + } + + t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) { + // Configure rate limiter: 10 requests per minute with burst of 5 + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 10, + Burst: 5, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make burst requests - all should succeed + successCount := 0 + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + if rec.Code == http.StatusOK { + successCount++ + } + } + + assert.Equal(t, 5, successCount, "All burst requests should succeed") + + // The 6th request should fail (exceeded burst) + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited") + }) + + t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) { + // Configure very low rate limit: 1 request per minute + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request should succeed + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed") + + // Second request should fail (rate limited) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited") + }) + + t.Run("Bearer Token Not Rate Limited", func(t *testing.T) { + // Configure strict rate limit + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make multiple requests with Bearer token - all should succeed + successCount := 0 + for i := 0; i < 10; i++ { + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Bearer "+JWT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + if rec.Code == http.StatusOK { + successCount++ + } + } + + assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)") + }) + + t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) { + // Configure rate limiter + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Use first PAT token + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed") + + // Second request with same token should fail + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited") + + // Use second PAT token - should succeed because it has independent rate limit + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT2) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)") + + // Second request with PAT2 should also be rate limited + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT2) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited") + + // JWT should still work (not rate limited) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Bearer "+JWT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)") + }) + + t.Run("Rate Limiter Cleanup", func(t *testing.T) { + // Configure rate limiter with short cleanup interval and TTL for testing + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: 100 * time.Millisecond, + LimiterTTL: 200 * time.Millisecond, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request - should succeed + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed") + + // Second request immediately - should fail (burst exhausted) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited") + + // Wait for limiter to be cleaned up (TTL + cleanup interval + buffer) + time.Sleep(400 * time.Millisecond) + + // After cleanup, the limiter should be removed and recreated with full burst capacity + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "Request after cleanup should succeed (new limiter with full burst)") + + // Verify it's a fresh limiter by checking burst is reset + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again") + }) +} + func TestAuthMiddleware_Handler_Child(t *testing.T) { tt := []struct { name string @@ -297,6 +579,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { return &types.User{}, nil }, + nil, ) for _, tc := range tt { diff --git a/management/server/http/middleware/rate_limiter.go b/management/server/http/middleware/rate_limiter.go new file mode 100644 index 000000000..a6266d4f3 --- /dev/null +++ b/management/server/http/middleware/rate_limiter.go @@ -0,0 +1,146 @@ +package middleware + +import ( + "context" + "sync" + "time" + + "golang.org/x/time/rate" +) + +// RateLimiterConfig holds configuration for the API rate limiter +type RateLimiterConfig struct { + // RequestsPerMinute defines the rate at which tokens are replenished + RequestsPerMinute float64 + // Burst defines the maximum number of requests that can be made in a burst + Burst int + // CleanupInterval defines how often to clean up old limiters (how often garbage collection runs) + CleanupInterval time.Duration + // LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal) + LimiterTTL time.Duration +} + +// DefaultRateLimiterConfig returns a default configuration +func DefaultRateLimiterConfig() *RateLimiterConfig { + return &RateLimiterConfig{ + RequestsPerMinute: 100, + Burst: 120, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } +} + +// limiterEntry holds a rate limiter and its last access time +type limiterEntry struct { + limiter *rate.Limiter + lastAccess time.Time +} + +// APIRateLimiter manages rate limiting for API tokens +type APIRateLimiter struct { + config *RateLimiterConfig + limiters map[string]*limiterEntry + mu sync.RWMutex + stopChan chan struct{} +} + +// NewAPIRateLimiter creates a new API rate limiter with the given configuration +func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter { + if config == nil { + config = DefaultRateLimiterConfig() + } + + rl := &APIRateLimiter{ + config: config, + limiters: make(map[string]*limiterEntry), + stopChan: make(chan struct{}), + } + + go rl.cleanupLoop() + + return rl +} + +// Allow checks if a request for the given key (token) is allowed +func (rl *APIRateLimiter) Allow(key string) bool { + limiter := rl.getLimiter(key) + return limiter.Allow() +} + +// Wait blocks until the rate limiter allows another request for the given key +// Returns an error if the context is canceled +func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error { + limiter := rl.getLimiter(key) + return limiter.Wait(ctx) +} + +// getLimiter retrieves or creates a rate limiter for the given key +func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter { + rl.mu.RLock() + entry, exists := rl.limiters[key] + rl.mu.RUnlock() + + if exists { + rl.mu.Lock() + entry.lastAccess = time.Now() + rl.mu.Unlock() + return entry.limiter + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + if entry, exists := rl.limiters[key]; exists { + entry.lastAccess = time.Now() + return entry.limiter + } + + requestsPerSecond := rl.config.RequestsPerMinute / 60.0 + limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst) + rl.limiters[key] = &limiterEntry{ + limiter: limiter, + lastAccess: time.Now(), + } + + return limiter +} + +// cleanupLoop periodically removes old limiters that haven't been used recently +func (rl *APIRateLimiter) cleanupLoop() { + ticker := time.NewTicker(rl.config.CleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + rl.cleanup() + case <-rl.stopChan: + return + } + } +} + +// cleanup removes limiters that haven't been used within the TTL period +func (rl *APIRateLimiter) cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + for key, entry := range rl.limiters { + if now.Sub(entry.lastAccess) > rl.config.LimiterTTL { + delete(rl.limiters, key) + } + } +} + +// Stop stops the cleanup goroutine +func (rl *APIRateLimiter) Stop() { + close(rl.stopChan) +} + +// Reset removes the rate limiter for a specific key +func (rl *APIRateLimiter) Reset(key string) { + rl.mu.Lock() + defer rl.mu.Unlock() + delete(rl.limiters, key) +} diff --git a/shared/management/http/util/util.go b/shared/management/http/util/util.go index 3ae321023..0a29469da 100644 --- a/shared/management/http/util/util.go +++ b/shared/management/http/util/util.go @@ -106,6 +106,8 @@ func WriteError(ctx context.Context, err error, w http.ResponseWriter) { httpStatus = http.StatusUnauthorized case status.BadRequest: httpStatus = http.StatusBadRequest + case status.TooManyRequests: + httpStatus = http.StatusTooManyRequests default: } msg = strings.ToLower(err.Error()) diff --git a/shared/management/status/error.go b/shared/management/status/error.go index 1e914babb..09676847e 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -37,6 +37,9 @@ const ( // Unauthenticated indicates that user is not authenticated due to absence of valid credentials Unauthenticated Type = 10 + + // TooManyRequests indicates that the user has sent too many requests in a given amount of time (rate limiting) + TooManyRequests Type = 11 ) // Type is a type of the Error From 98ddac07bfb4159a6eeeb2178604116290147a4d Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 7 Nov 2025 15:50:58 +0100 Subject: [PATCH 064/120] [management] remove toAll firewall rule (#4725) --- management/server/policy_test.go | 152 ++++++++++++++++++++++++++++- management/server/types/account.go | 11 --- 2 files changed, 148 insertions(+), 15 deletions(-) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 4a08f4c33..97ebbcf5a 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -266,7 +266,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { expectedFirewallRules := []*types.FirewallRule{ { - PeerIP: "0.0.0.0", + PeerIP: "100.65.14.88", Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", @@ -274,7 +274,103 @@ func TestAccount_getPeersByPolicy(t *testing.T) { PolicyID: "RuleDefault", }, { - PeerIP: "0.0.0.0", + PeerIP: "100.65.14.88", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.254.139", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.254.139", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.32.206", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.32.206", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.250.202", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.250.202", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.13.186", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.13.186", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.29.55", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.29.55", Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", @@ -833,10 +929,58 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // We expect a single permissive firewall rule which all outgoing connections peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) - assert.Len(t, firewallRules, 1) + assert.Len(t, firewallRules, 7) expectedFirewallRules := []*types.FirewallRule{ { - PeerIP: "0.0.0.0", + PeerIP: "100.65.80.39", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.14.88", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.32.206", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.13.186", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.29.55", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.21.56", Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", diff --git a/management/server/types/account.go b/management/server/types/account.go index dd6052498..3818a84ce 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1062,14 +1062,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer rules := make([]*FirewallRule, 0) peers := make([]*nbpeer.Peer, 0) - all, err := a.GetGroupAll() - if err != nil { - log.WithContext(ctx).Errorf("failed to get group all: %v", err) - all = &Group{} - } - return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { - isAll := (len(all.Peers) - 1) == len(groupPeers) for _, peer := range groupPeers { if peer == nil { continue @@ -1088,10 +1081,6 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer Protocol: string(rule.Protocol), } - if isAll { - fr.PeerIP = "0.0.0.0" - } - ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") if _, ok := rulesExists[ruleID]; ok { From dbfc8a52c932a4a246c07938eeabfaef78f66c3b Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 7 Nov 2025 16:03:14 +0100 Subject: [PATCH 065/120] [management] remove GLOBAL when disabling foreign keys on mysql (#4615) --- management/server/store/sql_store.go | 60 +++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index d83d160c3..3146254a3 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -85,12 +85,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met conns = runtime.NumCPU() } - switch storeEngine { - case types.MysqlStoreEngine: - if err := db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS = 0").Error; err != nil { - return nil, err - } - case types.SqliteStoreEngine: + if storeEngine == types.SqliteStoreEngine { if err == nil { log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1") } @@ -175,7 +170,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro group.StoreGroupPeers() } - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) if result.Error != nil { return result.Error @@ -270,7 +265,7 @@ func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) error { start := time.Now() - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) if result.Error != nil { return result.Error @@ -324,7 +319,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. peerCopy := peer.Copy() peerCopy.AccountID = accountID - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { // check if peer exists before saving var peerID string result := tx.Model(&nbpeer.Peer{}).Select("id").Take(&peerID, accountAndIDQueryCondition, accountID, peer.ID) @@ -613,7 +608,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre } func (s *SqlStore) DeleteUser(ctx context.Context, accountID, userID string) error { - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { result := tx.Delete(&types.PersonalAccessToken{}, "user_id = ?", userID) if result.Error != nil { return result.Error @@ -2932,6 +2927,16 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor if tx.Error != nil { return tx.Error } + + // For MySQL, disable FK checks within this transaction to avoid deadlocks + // This is session-scoped and doesn't require SUPER privileges + if s.storeEngine == types.MysqlStoreEngine { + if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to disable FK checks: %w", err) + } + } + repo := s.withTx(tx) err := operation(repo) if err != nil { @@ -2939,6 +2944,14 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor return err } + // Re-enable FK checks before commit (optional, as transaction end resets it) + if s.storeEngine == types.MysqlStoreEngine { + if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to re-enable FK checks: %w", err) + } + } + err = tx.Commit().Error log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime)) @@ -2956,6 +2969,31 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store { } } +// transaction wraps a GORM transaction with MySQL-specific FK checks handling +// Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora +func (s *SqlStore) transaction(fn func(*gorm.DB) error) error { + return s.db.Transaction(func(tx *gorm.DB) error { + // For MySQL, disable FK checks within this transaction to avoid deadlocks + // This is session-scoped and doesn't require SUPER privileges + if s.storeEngine == types.MysqlStoreEngine { + if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil { + return fmt.Errorf("failed to disable FK checks: %w", err) + } + } + + err := fn(tx) + + // Re-enable FK checks before commit (optional, as transaction end resets it) + if s.storeEngine == types.MysqlStoreEngine && err == nil { + if fkErr := tx.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; fkErr != nil { + return fmt.Errorf("failed to re-enable FK checks: %w", fkErr) + } + } + + return err + }) +} + func (s *SqlStore) GetDB() *gorm.DB { return s.db } @@ -3212,7 +3250,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error { } func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.transaction(func(tx *gorm.DB) error { if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil { return fmt.Errorf("delete policy rules: %w", err) } From 7df49e249dc035ece05d2514144ddac67674201f Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 7 Nov 2025 20:14:52 +0100 Subject: [PATCH 066/120] [management ] remove timing logs (#4761) --- management/server/account.go | 5 ----- management/server/grpcserver.go | 10 --------- management/server/peer.go | 33 ---------------------------- management/server/store/sql_store.go | 13 ----------- management/server/types/account.go | 4 ---- 5 files changed, 65 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 0aecbd586..f5a5c7b7a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1672,11 +1672,6 @@ func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) boo } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start)) - }() - peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 0a5236cb3..d3d94443a 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -775,11 +775,6 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("toSyncResponse: took %s", time.Since(start)) - }() - response := &proto.SyncResponse{ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), NetworkMap: &proto.NetworkMap{ @@ -844,11 +839,6 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("sendInitialSync: took %s", time.Since(start)) - }() - var err error var turnToken *Token diff --git a/management/server/peer.go b/management/server/peer.go index 80ab7fc69..4c605b5eb 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -106,11 +106,6 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc // MarkPeerConnected marks peer as connected (true) or disconnected (false) func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("MarkPeerConnected: took %v", time.Since(start)) - }() - var peer *nbpeer.Peer var settings *types.Settings var expired bool @@ -744,11 +739,6 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("SyncPeer: took %v", time.Since(start)) - }() - var peer *nbpeer.Peer var peerNotValid bool var isStatusChanged bool @@ -822,10 +812,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("handlePeerNotFound: took %s", time.Since(start)) - }() if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. @@ -847,11 +833,6 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("LoginPeer: took %s", time.Since(start)) - }() - accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { return am.handlePeerLoginNotFound(ctx, login, err) @@ -965,11 +946,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer // getPeerPostureChecks returns the posture checks for the peer. func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("getPostureChecks: took %s", time.Since(start)) - }() - policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err @@ -1058,11 +1034,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co } func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start)) - }() - if isRequiresApproval { network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -1611,10 +1582,6 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p // getPeerGroupIDs returns the IDs of the groups that the peer is part of. func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("getPeerGroupIDs: took %s", time.Since(start)) - }() return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 3146254a3..75f2c3ae7 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -311,10 +311,6 @@ func (s *SqlStore) GetInstallationID() string { } func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("SavePeer: took %s", time.Since(start)) - }() // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. peerCopy := peer.Copy() peerCopy.AccountID = accountID @@ -2156,10 +2152,6 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("GetAccountNetwork: took %s", time.Since(start)) - }() ctx, cancel := getDebuggingCtx(ctx) defer cancel() @@ -2201,11 +2193,6 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking } func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("getAccountSettings: took %s", time.Since(start)) - }() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) diff --git a/management/server/types/account.go b/management/server/types/account.go index 3818a84ce..8797e1fa3 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -265,10 +265,6 @@ func (a *Account) GetPeerNetworkMap( metrics *telemetry.AccountManagerMetrics, ) *NetworkMap { start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("GetPeerNetworkMap: took %s", time.Since(start)) - }() - peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ From 07cf9d5895404b1a4e04fd8de17918d5c0edf8ce Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 8 Nov 2025 10:54:37 +0100 Subject: [PATCH 067/120] [client] Create networkd.conf.d if it doesn't exist (#4764) --- client/cmd/service_installer.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 2a87e538d..f6828d96a 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -259,6 +259,7 @@ func isServiceRunning() (bool, error) { } const ( + networkdConf = "/etc/systemd/networkd.conf" networkdConfDir = "/etc/systemd/networkd.conf.d" networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf" networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing @@ -273,12 +274,16 @@ ManageForeignRoutingPolicyRules=no // configureSystemdNetworkd creates a drop-in configuration file to prevent // systemd-networkd from removing NetBird's routes and policy rules. func configureSystemdNetworkd() error { - parentDir := filepath.Dir(networkdConfDir) - if _, err := os.Stat(parentDir); os.IsNotExist(err) { - log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration") + if _, err := os.Stat(networkdConf); os.IsNotExist(err) { + log.Debug("systemd-networkd not in use, skipping configuration") return nil } + // nolint:gosec // standard networkd permissions + if err := os.MkdirAll(networkdConfDir, 0755); err != nil { + return fmt.Errorf("create networkd.conf.d directory: %w", err) + } + // nolint:gosec // standard networkd permissions if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil { return fmt.Errorf("write networkd configuration: %w", err) From 56f169eedeb4f6a3b8d759bf7f735789322c8bc1 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Mon, 10 Nov 2025 23:43:08 +0100 Subject: [PATCH 068/120] [management] fix pg db deadlock after app panic (#4772) --- management/server/store/sql_store.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 75f2c3ae7..94b7fc1cc 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2914,6 +2914,23 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor if tx.Error != nil { return tx.Error } + defer func() { + if r := recover(); r != nil { + tx.Rollback() + panic(r) + } + }() + + if s.storeEngine == types.PostgresStoreEngine { + if err := tx.Exec("SET LOCAL statement_timeout = '1min'").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to set statement timeout: %w", err) + } + if err := tx.Exec("SET LOCAL lock_timeout = '1min'").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to set lock timeout: %w", err) + } + } // For MySQL, disable FK checks within this transaction to avoid deadlocks // This is session-scoped and doesn't require SUPER privileges From c28275611b82de1fa04bc022458ae97dd325d7fc Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 11 Nov 2025 13:59:32 +0100 Subject: [PATCH 069/120] Fix agent reference (#4776) --- client/internal/peer/worker_ice.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 3675f0157..5d8ebfe45 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { if isController(w.config) { - return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } else { return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } From cc97cffff1491d84d9cad39bae53a64998b6857b Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 13 Nov 2025 12:09:46 +0100 Subject: [PATCH 070/120] [management] move network map logic into new design (#4774) --- client/cmd/testutil_test.go | 15 +- client/internal/engine_test.go | 13 +- client/server/server_test.go | 13 +- go.mod | 2 +- .../controller/cache/dns_config_cache.go | 31 + .../network_map/controller/controller.go | 784 ++++++++++++++++++ .../network_map/controller/controller_test.go | 244 ++++++ .../network_map/controller/metrics.go | 15 + .../network_map/controller/repository.go | 39 + .../controllers/network_map/interface.go | 39 + .../controllers/network_map/interface_mock.go | 225 +++++ .../controllers/network_map/network_map.go | 1 + .../controllers/network_map/update_channel.go | 13 + .../update_channel}/updatechannel.go | 26 +- .../update_channel}/updatechannel_test.go | 7 +- .../controllers/network_map/update_message.go | 9 + management/internals/server/boot.go | 6 +- management/internals/server/controllers.go | 28 +- management/internals/server/modules.go | 3 +- .../internals/shared/grpc/conversion.go | 352 ++++++++ .../internals/shared/grpc/conversion_test.go | 150 ++++ .../shared/grpc}/loginfilter.go | 2 +- .../shared/grpc}/loginfilter_test.go | 2 +- .../shared/grpc/server.go} | 249 +----- .../internals/shared/grpc/server_test.go | 106 +++ .../shared/grpc}/token_mgr.go | 11 +- .../shared/grpc}/token_mgr_test.go | 12 +- management/server/account.go | 127 +-- management/server/account/manager.go | 11 +- management/server/account/request_buffer.go | 11 + management/server/account_test.go | 151 ++-- management/server/dns.go | 130 --- management/server/dns_test.go | 257 +----- management/server/event_test.go | 2 +- management/server/group.go | 26 +- management/server/group_test.go | 20 +- management/server/holder.go | 39 - management/server/http/handler.go | 4 +- .../http/handlers/peers/peers_handler.go | 23 +- .../http/handlers/peers/peers_handler_test.go | 29 +- .../testing/testing_tools/channel/channel.go | 20 +- management/server/management_proto_test.go | 142 +--- management/server/management_test.go | 21 +- management/server/mock_server/account_mock.go | 12 +- management/server/nameserver.go | 9 - management/server/nameserver_test.go | 16 +- management/server/networkmap.go | 80 -- management/server/networks/manager.go | 3 - .../server/networks/resources/manager.go | 9 - management/server/networks/routers/manager.go | 9 - management/server/peer.go | 446 +--------- management/server/peer_test.go | 235 ++---- management/server/policy.go | 35 - management/server/policy_test.go | 6 +- management/server/posture_checks.go | 72 -- management/server/posture_checks_test.go | 14 +- management/server/route.go | 101 --- management/server/route_test.go | 33 +- management/server/setupkey_test.go | 14 +- management/server/user.go | 23 +- management/server/user_test.go | 16 +- shared/management/client/client_test.go | 14 +- 62 files changed, 2568 insertions(+), 1989 deletions(-) create mode 100644 management/internals/controllers/network_map/controller/cache/dns_config_cache.go create mode 100644 management/internals/controllers/network_map/controller/controller.go create mode 100644 management/internals/controllers/network_map/controller/controller_test.go create mode 100644 management/internals/controllers/network_map/controller/metrics.go create mode 100644 management/internals/controllers/network_map/controller/repository.go create mode 100644 management/internals/controllers/network_map/interface.go create mode 100644 management/internals/controllers/network_map/interface_mock.go create mode 100644 management/internals/controllers/network_map/network_map.go create mode 100644 management/internals/controllers/network_map/update_channel.go rename management/{server => internals/controllers/network_map/update_channel}/updatechannel.go (87%) rename management/{server => internals/controllers/network_map/update_channel}/updatechannel_test.go (89%) create mode 100644 management/internals/controllers/network_map/update_message.go create mode 100644 management/internals/shared/grpc/conversion.go create mode 100644 management/internals/shared/grpc/conversion_test.go rename management/{server => internals/shared/grpc}/loginfilter.go (99%) rename management/{server => internals/shared/grpc}/loginfilter_test.go (99%) rename management/{server/grpcserver.go => internals/shared/grpc/server.go} (76%) create mode 100644 management/internals/shared/grpc/server_test.go rename management/{server => internals/shared/grpc}/token_mgr.go (93%) rename management/{server => internals/shared/grpc}/token_mgr_test.go (94%) create mode 100644 management/server/account/request_buffer.go delete mode 100644 management/server/holder.go delete mode 100644 management/server/networkmap.go diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index bd3209605..78bb0476b 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -12,6 +12,9 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" clientProto "github.com/netbirdio/netbird/client/proto" client "github.com/netbirdio/netbird/client/server" @@ -84,7 +87,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp } t.Cleanup(cleanUp) - peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, nil @@ -110,13 +112,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp Return(&types.Settings{}, nil). AnyTimes() - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } - secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 2f1098100..15ac0a947 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -26,6 +26,9 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" @@ -1556,7 +1559,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } t.Cleanup(cleanUp) - peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, "", err @@ -1584,13 +1586,16 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri groupsManager := groups.NewManagerMock() - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) + networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, "", err } - secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index e0a4805f6..ae5f759ee 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -14,6 +14,9 @@ import ( "go.opentelemetry.io/otel" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" @@ -290,7 +293,6 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } t.Cleanup(cleanUp) - peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, "", err @@ -311,13 +313,16 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) + peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) + networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { return nil, "", err } - secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/go.mod b/go.mod index 7b9bae321..1e3177d7e 100644 --- a/go.mod +++ b/go.mod @@ -99,6 +99,7 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/metric v1.35.0 go.opentelemetry.io/otel/sdk/metric v1.35.0 + go.uber.org/mock v0.5.0 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 @@ -242,7 +243,6 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect - go.uber.org/mock v0.5.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/text v0.27.0 // indirect diff --git a/management/internals/controllers/network_map/controller/cache/dns_config_cache.go b/management/internals/controllers/network_map/controller/cache/dns_config_cache.go new file mode 100644 index 000000000..8cc634ef4 --- /dev/null +++ b/management/internals/controllers/network_map/controller/cache/dns_config_cache.go @@ -0,0 +1,31 @@ +package cache + +import ( + "sync" + + "github.com/netbirdio/netbird/shared/management/proto" +) + +// DNSConfigCache is a thread-safe cache for DNS configuration components +type DNSConfigCache struct { + NameServerGroups sync.Map +} + +// GetNameServerGroup retrieves a cached name server group +func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { + if c == nil { + return nil, false + } + if value, ok := c.NameServerGroups.Load(key); ok { + return value.(*proto.NameServerGroup), true + } + return nil, false +} + +// SetNameServerGroup stores a name server group in the cache +func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) { + if c == nil { + return + } + c.NameServerGroups.Store(key, value) +} diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go new file mode 100644 index 000000000..ad25494c7 --- /dev/null +++ b/management/internals/controllers/network_map/controller/controller.go @@ -0,0 +1,784 @@ +package controller + +import ( + "context" + "errors" + "fmt" + "os" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + "golang.org/x/mod/semver" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/util" +) + +type Controller struct { + repo Repository + metrics *metrics + // This should not be here, but we need to maintain it for the time being + accountManagerMetrics *telemetry.AccountManagerMetrics + peersUpdateManager network_map.PeersUpdateManager + settingsManager settings.Manager + + accountUpdateLocks sync.Map + sendAccountUpdateLocks sync.Map + updateAccountPeersBufferInterval atomic.Int64 + // dnsDomain is used for peer resolution. This is appended to the peer's name + dnsDomain string + + requestBuffer account.RequestBuffer + + proxyController port_forwarding.Controller + + integratedPeerValidator integrated_validator.IntegratedValidator + + holder *types.Holder + + expNewNetworkMap bool + expNewNetworkMapAIDs map[string]struct{} +} + +type bufferUpdate struct { + mu sync.Mutex + next *time.Timer + update atomic.Bool +} + +var _ network_map.Controller = (*Controller)(nil) + +func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller) *Controller { + nMetrics, err := newMetrics(metrics.UpdateChannelMetrics()) + if err != nil { + log.Fatal(fmt.Errorf("error creating metrics: %w", err)) + } + + newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder)) + if err != nil { + log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err) + newNetworkMapBuilder = false + } + + ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",") + expIDs := make(map[string]struct{}, len(ids)) + for _, id := range ids { + expIDs[id] = struct{}{} + } + + return &Controller{ + repo: newRepository(store), + metrics: nMetrics, + accountManagerMetrics: metrics.AccountManagerMetrics(), + peersUpdateManager: peersUpdateManager, + requestBuffer: requestBuffer, + integratedPeerValidator: integratedPeerValidator, + settingsManager: settingsManager, + dnsDomain: dnsDomain, + + proxyController: proxyController, + + holder: types.NewHolder(), + expNewNetworkMap: newNetworkMapBuilder, + expNewNetworkMapAIDs: expIDs, + } +} + +func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { + log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) + var ( + account *types.Account + err error + ) + if c.experimentalNetworkMap(accountID) { + account = c.getAccountFromHolderOrInit(accountID) + } else { + account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get account: %v", err) + } + } + + globalStart := time.Now() + + hasPeersConnected := false + for _, peer := range account.Peers { + if c.peersUpdateManager.HasChannel(peer.ID) { + hasPeersConnected = true + break + } + + } + + if !hasPeersConnected { + return nil + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return fmt.Errorf("failed to get validate peers: %v", err) + } + + var wg sync.WaitGroup + semaphore := make(chan struct{}, 10) + + dnsCache := &cache.DNSConfigCache{} + dnsDomain := c.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + if c.experimentalNetworkMap(accountID) { + c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) + } + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return fmt.Errorf("failed to get proxy network maps: %v", err) + } + + extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get flow enabled status: %v", err) + } + + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + for _, peer := range account.Peers { + if !c.peersUpdateManager.HasChannel(peer.ID) { + log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) + continue + } + + wg.Add(1) + semaphore <- struct{}{} + go func(p *nbpeer.Peer) { + defer wg.Done() + defer func() { <-semaphore }() + + start := time.Now() + + postureChecks, err := c.getPeerPostureChecks(account, p.ID) + if err != nil { + log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err) + return + } + + c.metrics.CountCalcPostureChecksDuration(time.Since(start)) + start = time.Now() + + var remotePeerNetworkMap *types.NetworkMap + + if c.experimentalNetworkMap(accountID) { + remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) + } + + c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + remotePeerNetworkMap.Merge(proxyNetworkMap) + } + + peerGroups := account.GetPeerGroups(p.ID) + start = time.Now() + update := grpc.ToSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) + c.metrics.CountToSyncResponseDuration(time.Since(start)) + + c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update}) + }(peer) + } + + wg.Wait() + if c.accountManagerMetrics != nil { + c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart)) + } + + return nil +} + +func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error { + log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName()) + + bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) + b := bufUpd.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return nil + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + _ = c.sendUpdateAccountPeers(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + if b.next == nil { + b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { + _ = c.sendUpdateAccountPeers(ctx, accountID) + }) + return + } + b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load())) + }() + + return nil +} + +// UpdatePeers updates all peers that belong to an account. +// Should be called when changes have to be synced to peers. +func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error { + if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return fmt.Errorf("recalculate network map cache: %v", err) + } + + return c.sendUpdateAccountPeers(ctx, accountID) +} + +func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error { + if !c.peersUpdateManager.HasChannel(peerId) { + return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId) + } + + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err) + } + + peer := account.GetPeer(peerId) + if peer == nil { + return fmt.Errorf("peer %s doesn't exists in account %s", peerId, accountId) + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return fmt.Errorf("failed to get validated peers: %v", err) + } + + dnsCache := &cache.DNSConfigCache{} + dnsDomain := c.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + postureChecks, err := c.getPeerPostureChecks(account, peerId) + if err != nil { + log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err) + return fmt.Errorf("failed to get posture checks for peer %s: %v", peerId, err) + } + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return err + } + + var remotePeerNetworkMap *types.NetworkMap + + if c.experimentalNetworkMap(accountId) { + remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) + } + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + remotePeerNetworkMap.Merge(proxyNetworkMap) + } + + extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID) + if err != nil { + return fmt.Errorf("failed to get extra settings: %v", err) + } + + peerGroups := account.GetPeerGroups(peerId) + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + update := grpc.ToSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) + c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update}) + + return nil +} + +func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { + log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) + + bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) + b := bufUpd.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return nil + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + _ = c.UpdateAccountPeers(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + if b.next == nil { + b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { + _ = c.UpdateAccountPeers(ctx, accountID) + }) + return + } + b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load())) + }() + + return nil +} + +func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error { + network, err := c.repo.GetAccountNetwork(ctx, accountId) + if err != nil { + return err + } + + peers, err := c.repo.GetAccountPeers(ctx, accountId) + if err != nil { + return err + } + + dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion) + c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + NetworkMap: &proto.NetworkMap{ + Serial: network.CurrentSerial(), + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + FirewallRules: []*proto.FirewallRule{}, + FirewallRulesIsEmpty: true, + DNSConfig: &proto.DNSConfig{ + ForwarderPort: dnsFwdPort, + }, + }, + }, + }) + c.peersUpdateManager.CloseChannel(ctx, peerId) + return nil +} + +func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + if isRequiresApproval { + network, err := c.repo.GetAccountNetwork(ctx, accountID) + if err != nil { + return nil, nil, nil, 0, err + } + + emptyMap := &types.NetworkMap{ + Network: network.Copy(), + } + return peer, emptyMap, nil, 0, nil + } + + var ( + account *types.Account + err error + ) + if c.experimentalNetworkMap(accountID) { + account = c.getAccountFromHolderOrInit(accountID) + } else { + account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, 0, err + } + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return nil, nil, nil, 0, err + } + + startPosture := time.Now() + postureChecks, err := c.getPeerPostureChecks(account, peer.ID) + if err != nil { + return nil, nil, nil, 0, err + } + log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture)) + + customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings)) + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return nil, nil, nil, 0, err + } + + var networkMap *types.NetworkMap + + if c.experimentalNetworkMap(accountID) { + networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics) + } + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + networkMap.Merge(proxyNetworkMap) + } + + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + return peer, networkMap, postureChecks, dnsFwdPort, nil +} + +func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { + c.enrichAccountFromHolder(account) + account.InitNetworkMapBuilderIfNeeded(validatedPeers) +} + +func (c *Controller) getPeerNetworkMapExp( + ctx context.Context, + accountId string, + peerId string, + validatedPeers map[string]struct{}, + customZone nbdns.CustomZone, + metrics *telemetry.AccountManagerMetrics, +) *types.NetworkMap { + account := c.getAccountFromHolderOrInit(accountId) + if account == nil { + log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) + return &types.NetworkMap{ + Network: &types.Network{}, + } + } + return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) +} + +func (c *Controller) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error { + c.enrichAccountFromHolder(account) + return account.OnPeerAddedUpdNetworkMapCache(peerId) +} + +func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { + c.enrichAccountFromHolder(account) + return account.OnPeerDeletedUpdNetworkMapCache(peerId) +} + +func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { + account := c.getAccountFromHolder(accountId) + if account == nil { + return + } + account.UpdatePeerInNetworkMapCache(peer) +} + +func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { + account.RecalculateNetworkMapCache(validatedPeers) + c.updateAccountInHolder(account) +} + +func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { + if c.experimentalNetworkMap(accountId) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + return err + } + validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) + return err + } + c.recalculateNetworkMapCache(account, validatedPeers) + } + return nil +} + +func (c *Controller) experimentalNetworkMap(accountId string) bool { + _, ok := c.expNewNetworkMapAIDs[accountId] + return c.expNewNetworkMap || ok +} + +func (c *Controller) enrichAccountFromHolder(account *types.Account) { + a := c.holder.GetAccount(account.Id) + if a == nil { + c.holder.AddAccount(account) + return + } + account.NetworkMapCache = a.NetworkMapCache + if account.NetworkMapCache == nil { + return + } + account.NetworkMapCache.UpdateAccountPointer(account) + c.holder.AddAccount(account) +} + +func (c *Controller) getAccountFromHolder(accountID string) *types.Account { + return c.holder.GetAccount(accountID) +} + +func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account { + a := c.holder.GetAccount(accountID) + if a != nil { + return a + } + account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure) + if err != nil { + return nil + } + return account +} + +func (c *Controller) updateAccountInHolder(account *types.Account) { + c.holder.AddAccount(account) +} + +// GetDNSDomain returns the configured dnsDomain +func (c *Controller) GetDNSDomain(settings *types.Settings) string { + if settings == nil { + return c.dnsDomain + } + if settings.DNSDomain == "" { + return c.dnsDomain + } + + return settings.DNSDomain +} + +// getPeerPostureChecks returns the posture checks applied for a given peer. +func (c *Controller) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) { + peerPostureChecks := make(map[string]*posture.Checks) + + if len(account.PostureChecks) == 0 { + return nil, nil + } + + for _, policy := range account.Policies { + if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { + continue + } + + if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil { + return nil, err + } + } + + return maps.Values(peerPostureChecks), nil +} + +func (c *Controller) StartWarmup(ctx context.Context) { + var initialInterval int64 + intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS") + interval, err := strconv.Atoi(intervalStr) + if err != nil { + initialInterval = 1 + log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err) + } else { + initialInterval = int64(interval) * 10 + go func() { + startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S") + startupPeriod, err := strconv.Atoi(startupPeriodStr) + if err != nil { + startupPeriod = 1 + log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err) + } + time.Sleep(time.Duration(startupPeriod) * time.Second) + c.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond)) + log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval) + }() + } + c.updateAccountPeersBufferInterval.Store(int64(time.Duration(initialInterval) * time.Millisecond)) + log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval) + +} + +// computeForwarderPort checks if all peers in the account have updated to a specific version or newer. +// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. +func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { + if len(peers) == 0 { + return int64(network_map.OldForwarderPort) + } + + reqVer := semver.Canonical(requiredVersion) + + // Check if all peers have the required version or newer + for _, peer := range peers { + + // Development version is always supported + if peer.Meta.WtVersion == "development" { + continue + } + peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) + if peerVersion == "" { + // If any peer doesn't have version info, return 0 + return int64(network_map.OldForwarderPort) + } + + // Compare versions + if semver.Compare(peerVersion, reqVer) < 0 { + return int64(network_map.OldForwarderPort) + } + } + + // All peers have the required version or newer + return int64(network_map.DnsForwarderPort) +} + +// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. +func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) + if err != nil { + return err + } + + if !isInGroup { + return nil + } + + for _, sourcePostureCheckID := range policy.SourcePostureChecks { + postureCheck := account.GetPostureChecks(sourcePostureCheckID) + if postureCheck == nil { + return errors.New("failed to add policy posture checks: posture checks not found") + } + peerPostureChecks[sourcePostureCheckID] = postureCheck + } + + return nil +} + +// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. +func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) { + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + for _, sourceGroup := range rule.Sources { + group := account.GetGroup(sourceGroup) + if group == nil { + return false, fmt.Errorf("failed to check peer in policy source group: group not found") + } + + if slices.Contains(group.Peers, peerID) { + return true, nil + } + } + } + + return false, nil +} + +func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) { + c.UpdatePeerInNetworkMapCache(accountId, peer) + _ = c.bufferSendUpdateAccountPeers(context.Background(), accountId) +} + +func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error { + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } + + err = c.onPeerAddedUpdNetworkMapCache(account, peerID) + if err != nil { + return err + } + } + return c.bufferSendUpdateAccountPeers(ctx, accountID) +} + +func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error { + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } + err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) + if err != nil { + return err + } + } + + return c.bufferSendUpdateAccountPeers(ctx, accountID) +} + +// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) +func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { + account, err := c.repo.GetAccountByPeerID(ctx, peerID) + if err != nil { + return nil, err + } + + peer := account.GetPeer(peerID) + if peer == nil { + return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) + } + + groups := make(map[string][]string) + for groupID, group := range account.Groups { + groups[groupID] = group.Peers + } + + validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return nil, err + } + customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings)) + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return nil, err + } + + var networkMap *types.NetworkMap + + if c.experimentalNetworkMap(peer.AccountID) { + networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + } + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + networkMap.Merge(proxyNetworkMap) + } + + return networkMap, nil +} + +func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) { + c.peersUpdateManager.CloseChannels(ctx, peerIDs) +} + +func (c *Controller) IsConnected(peerID string) bool { + return c.peersUpdateManager.HasChannel(peerID) +} diff --git a/management/internals/controllers/network_map/controller/controller_test.go b/management/internals/controllers/network_map/controller/controller_test.go new file mode 100644 index 000000000..baaffe677 --- /dev/null +++ b/management/internals/controllers/network_map/controller/controller_test.go @@ -0,0 +1,244 @@ +package controller + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/server/mock_server" + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func TestComputeForwarderPort(t *testing.T) { + // Test with empty peers list + peers := []*nbpeer.Peer{} + result := computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for empty peers list, got %d", network_map.OldForwarderPort, result) + } + + // Test with peers that have old versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.26.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with old versions, got %d", network_map.OldForwarderPort, result) + } + + // Test with peers that have new versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.DnsForwarderPort) { + t.Errorf("Expected %d for peers with new versions, got %d", network_map.DnsForwarderPort, result) + } + + // Test with peers that have mixed versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with mixed versions, got %d", network_map.OldForwarderPort, result) + } + + // Test with peers that have empty version + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with empty version, got %d", network_map.OldForwarderPort, result) + } + + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "development", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result == int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with dev version, got %d", network_map.DnsForwarderPort, result) + } + + // Test with peers that have unknown version string + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "unknown", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result) + } +} + +func TestBufferUpdateAccountPeers(t *testing.T) { + const ( + peersCount = 1000 + updateAccountInterval = 50 * time.Millisecond + ) + + var ( + deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32 + uapLastRun, dpLastRun atomic.Int64 + + totalNewRuns, totalOldRuns int + ) + + uap := func(ctx context.Context, accountID string) { + updatePeersDeleted.Store(deletedPeers.Load()) + updatePeersRuns.Add(1) + uapLastRun.Store(time.Now().UnixMilli()) + time.Sleep(100 * time.Millisecond) + } + + t.Run("new approach", func(t *testing.T) { + updatePeersRuns.Store(0) + updatePeersDeleted.Store(0) + deletedPeers.Store(0) + + var mustore sync.Map + bufupd := func(ctx context.Context, accountID string) { + mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{}) + b := mu.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + uap(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + b.next = time.AfterFunc(updateAccountInterval, func() { + uap(ctx, accountID) + }) + }() + } + dp := func(ctx context.Context, accountID, peerID, userID string) error { + deletedPeers.Add(1) + dpLastRun.Store(time.Now().UnixMilli()) + time.Sleep(10 * time.Millisecond) + bufupd(ctx, accountID) + return nil + } + + am := mock_server.MockAccountManager{ + UpdateAccountPeersFunc: uap, + BufferUpdateAccountPeersFunc: bufupd, + DeletePeerFunc: dp, + } + empty := "" + for range peersCount { + //nolint + am.DeletePeer(context.Background(), empty, empty, empty) + } + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") + assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") + assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") + + totalNewRuns = int(updatePeersRuns.Load()) + }) + + t.Run("old approach", func(t *testing.T) { + updatePeersRuns.Store(0) + updatePeersDeleted.Store(0) + deletedPeers.Store(0) + + var mustore sync.Map + bufupd := func(ctx context.Context, accountID string) { + mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{}) + b := mu.(*sync.Mutex) + + if !b.TryLock() { + return + } + + go func() { + time.Sleep(updateAccountInterval) + b.Unlock() + uap(ctx, accountID) + }() + } + dp := func(ctx context.Context, accountID, peerID, userID string) error { + deletedPeers.Add(1) + dpLastRun.Store(time.Now().UnixMilli()) + time.Sleep(10 * time.Millisecond) + bufupd(ctx, accountID) + return nil + } + + am := mock_server.MockAccountManager{ + UpdateAccountPeersFunc: uap, + BufferUpdateAccountPeersFunc: bufupd, + DeletePeerFunc: dp, + } + empty := "" + for range peersCount { + //nolint + am.DeletePeer(context.Background(), empty, empty, empty) + } + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") + assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") + assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") + + totalOldRuns = int(updatePeersRuns.Load()) + }) + assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) + t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) +} diff --git a/management/internals/controllers/network_map/controller/metrics.go b/management/internals/controllers/network_map/controller/metrics.go new file mode 100644 index 000000000..5832d2130 --- /dev/null +++ b/management/internals/controllers/network_map/controller/metrics.go @@ -0,0 +1,15 @@ +package controller + +import ( + "github.com/netbirdio/netbird/management/server/telemetry" +) + +type metrics struct { + *telemetry.UpdateChannelMetrics +} + +func newMetrics(updateChannelMetrics *telemetry.UpdateChannelMetrics) (*metrics, error) { + return &metrics{ + updateChannelMetrics, + }, nil +} diff --git a/management/internals/controllers/network_map/controller/repository.go b/management/internals/controllers/network_map/controller/repository.go new file mode 100644 index 000000000..44144263b --- /dev/null +++ b/management/internals/controllers/network_map/controller/repository.go @@ -0,0 +1,39 @@ +package controller + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Repository interface { + GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) + GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) + GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) +} + +type repository struct { + store store.Store +} + +var _ Repository = (*repository)(nil) + +func newRepository(s store.Store) Repository { + return &repository{ + store: s, + } +} + +func (r *repository) GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) { + return r.store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) +} + +func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) { + return r.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") +} + +func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { + return r.store.GetAccountByPeerID(ctx, peerID) +} diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go new file mode 100644 index 000000000..6f893ce79 --- /dev/null +++ b/management/internals/controllers/network_map/interface.go @@ -0,0 +1,39 @@ +package network_map + +//go:generate go run go.uber.org/mock/mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod + +import ( + "context" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" +) + +const ( + EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" + EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" + + DnsForwarderPort = nbdns.ForwarderServerPort + OldForwarderPort = nbdns.ForwarderClientPort + DnsForwarderPortMinVersion = "v0.59.0" +) + +type Controller interface { + UpdateAccountPeers(ctx context.Context, accountID string) error + UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error + BufferUpdateAccountPeers(ctx context.Context, accountID string) error + GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + GetDNSDomain(settings *types.Settings) string + StartWarmup(context.Context) + GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) + + DeletePeer(ctx context.Context, accountId string, peerId string) error + + OnPeerUpdated(accountId string, peer *nbpeer.Peer) + OnPeerAdded(ctx context.Context, accountID string, peerID string) error + OnPeerDeleted(ctx context.Context, accountID string, peerID string) error + DisconnectPeers(ctx context.Context, peerIDs []string) + IsConnected(peerID string) bool +} diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go new file mode 100644 index 000000000..aaa093e47 --- /dev/null +++ b/management/internals/controllers/network_map/interface_mock.go @@ -0,0 +1,225 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./interface.go +// +// Generated by this command: +// +// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod +// + +// Package network_map is a generated GoMock package. +package network_map + +import ( + context "context" + reflect "reflect" + + peer "github.com/netbirdio/netbird/management/server/peer" + posture "github.com/netbirdio/netbird/management/server/posture" + types "github.com/netbirdio/netbird/management/server/types" + gomock "go.uber.org/mock/gomock" +) + +// MockController is a mock of Controller interface. +type MockController struct { + ctrl *gomock.Controller + recorder *MockControllerMockRecorder + isgomock struct{} +} + +// MockControllerMockRecorder is the mock recorder for MockController. +type MockControllerMockRecorder struct { + mock *MockController +} + +// NewMockController creates a new mock instance. +func NewMockController(ctrl *gomock.Controller) *MockController { + mock := &MockController{ctrl: ctrl} + mock.recorder = &MockControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockController) EXPECT() *MockControllerMockRecorder { + return m.recorder +} + +// BufferUpdateAccountPeers mocks base method. +func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers. +func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID) +} + +// DeletePeer mocks base method. +func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePeer indicates an expected call of DeletePeer. +func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId) +} + +// DisconnectPeers mocks base method. +func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs) +} + +// DisconnectPeers indicates an expected call of DisconnectPeers. +func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs) +} + +// GetDNSDomain mocks base method. +func (m *MockController) GetDNSDomain(settings *types.Settings) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDNSDomain", settings) + ret0, _ := ret[0].(string) + return ret0 +} + +// GetDNSDomain indicates an expected call of GetDNSDomain. +func (mr *MockControllerMockRecorder) GetDNSDomain(settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSDomain", reflect.TypeOf((*MockController)(nil).GetDNSDomain), settings) +} + +// GetNetworkMap mocks base method. +func (m *MockController) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNetworkMap", ctx, peerID) + ret0, _ := ret[0].(*types.NetworkMap) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNetworkMap indicates an expected call of GetNetworkMap. +func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkMap", reflect.TypeOf((*MockController)(nil).GetNetworkMap), ctx, peerID) +} + +// GetValidatedPeerWithMap mocks base method. +func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p) + ret0, _ := ret[0].(*peer.Peer) + ret1, _ := ret[1].(*types.NetworkMap) + ret2, _ := ret[2].([]*posture.Checks) + ret3, _ := ret[3].(int64) + ret4, _ := ret[4].(error) + return ret0, ret1, ret2, ret3, ret4 +} + +// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap. +func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p) +} + +// IsConnected mocks base method. +func (m *MockController) IsConnected(peerID string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsConnected", peerID) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsConnected indicates an expected call of IsConnected. +func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID) +} + +// OnPeerAdded mocks base method. +func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnPeerAdded indicates an expected call of OnPeerAdded. +func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID) +} + +// OnPeerDeleted mocks base method. +func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnPeerDeleted indicates an expected call of OnPeerDeleted. +func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID) +} + +// OnPeerUpdated mocks base method. +func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnPeerUpdated", accountId, peer) +} + +// OnPeerUpdated indicates an expected call of OnPeerUpdated. +func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer) +} + +// StartWarmup mocks base method. +func (m *MockController) StartWarmup(arg0 context.Context) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartWarmup", arg0) +} + +// StartWarmup indicates an expected call of StartWarmup. +func (mr *MockControllerMockRecorder) StartWarmup(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWarmup", reflect.TypeOf((*MockController)(nil).StartWarmup), arg0) +} + +// UpdateAccountPeer mocks base method. +func (m *MockController) UpdateAccountPeer(ctx context.Context, accountId, peerId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAccountPeer", ctx, accountId, peerId) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAccountPeer indicates an expected call of UpdateAccountPeer. +func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeer", reflect.TypeOf((*MockController)(nil).UpdateAccountPeer), ctx, accountId, peerId) +} + +// UpdateAccountPeers mocks base method. +func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAccountPeers indicates an expected call of UpdateAccountPeers. +func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID) +} diff --git a/management/internals/controllers/network_map/network_map.go b/management/internals/controllers/network_map/network_map.go new file mode 100644 index 000000000..e915c2193 --- /dev/null +++ b/management/internals/controllers/network_map/network_map.go @@ -0,0 +1 @@ +package network_map diff --git a/management/internals/controllers/network_map/update_channel.go b/management/internals/controllers/network_map/update_channel.go new file mode 100644 index 000000000..0b085b85f --- /dev/null +++ b/management/internals/controllers/network_map/update_channel.go @@ -0,0 +1,13 @@ +package network_map + +import "context" + +type PeersUpdateManager interface { + SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) + CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage + CloseChannel(ctx context.Context, peerID string) + CountStreams() int + HasChannel(peerID string) bool + CloseChannels(ctx context.Context, peerIDs []string) + GetAllConnectedPeers() map[string]struct{} +} diff --git a/management/server/updatechannel.go b/management/internals/controllers/network_map/update_channel/updatechannel.go similarity index 87% rename from management/server/updatechannel.go rename to management/internals/controllers/network_map/update_channel/updatechannel.go index adf64592a..5f7db5300 100644 --- a/management/server/updatechannel.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel.go @@ -1,4 +1,4 @@ -package server +package update_channel import ( "context" @@ -7,36 +7,34 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/shared/management/proto" ) const channelBufferSize = 100 -type UpdateMessage struct { - Update *proto.SyncResponse -} - type PeersUpdateManager struct { // peerChannels is an update channel indexed by Peer.ID - peerChannels map[string]chan *UpdateMessage + peerChannels map[string]chan *network_map.UpdateMessage // channelsMux keeps the mutex to access peerChannels channelsMux *sync.RWMutex // metrics provides method to collect application metrics metrics telemetry.AppMetrics } +var _ network_map.PeersUpdateManager = (*PeersUpdateManager)(nil) + // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), + peerChannels: make(map[string]chan *network_map.UpdateMessage), channelsMux: &sync.RWMutex{}, metrics: metrics, } } // SendUpdate sends update message to the peer's channel -func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) { +func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *network_map.UpdateMessage) { start := time.Now() var found, dropped bool @@ -64,7 +62,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } // CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. -func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage { +func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *network_map.UpdateMessage { start := time.Now() closed := false @@ -83,7 +81,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c close(channel) } // mbragin: todo shouldn't it be more? or configurable? - channel := make(chan *UpdateMessage, channelBufferSize) + channel := make(chan *network_map.UpdateMessage, channelBufferSize) p.peerChannels[peerID] = channel log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID) @@ -174,3 +172,9 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool { return ok } + +func (p *PeersUpdateManager) CountStreams() int { + p.channelsMux.RLock() + defer p.channelsMux.RUnlock() + return len(p.peerChannels) +} diff --git a/management/server/updatechannel_test.go b/management/internals/controllers/network_map/update_channel/updatechannel_test.go similarity index 89% rename from management/server/updatechannel_test.go rename to management/internals/controllers/network_map/update_channel/updatechannel_test.go index 0dc86563d..afc1e2c32 100644 --- a/management/server/updatechannel_test.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel_test.go @@ -1,10 +1,11 @@ -package server +package update_channel import ( "context" "testing" "time" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -24,7 +25,7 @@ func TestCreateChannel(t *testing.T) { func TestSendUpdate(t *testing.T) { peer := "test-sendupdate" peersUpdater := NewPeersUpdateManager(nil) - update1 := &UpdateMessage{Update: &proto.SyncResponse{ + update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ NetworkMap: &proto.NetworkMap{ Serial: 0, }, @@ -44,7 +45,7 @@ func TestSendUpdate(t *testing.T) { peersUpdater.SendUpdate(context.Background(), peer, update1) } - update2 := &UpdateMessage{Update: &proto.SyncResponse{ + update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ NetworkMap: &proto.NetworkMap{ Serial: 10, }, diff --git a/management/internals/controllers/network_map/update_message.go b/management/internals/controllers/network_map/update_message.go new file mode 100644 index 000000000..33643bcbd --- /dev/null +++ b/management/internals/controllers/network_map/update_message.go @@ -0,0 +1,9 @@ +package network_map + +import ( + "github.com/netbirdio/netbird/shared/management/proto" +) + +type UpdateMessage struct { + Update *proto.SyncResponse +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 16e93a549..eadd16c2d 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -22,7 +22,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" @@ -93,7 +93,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager()) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { } gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator()) + srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController()) if err != nil { log.Fatalf("failed to create management server: %v", err) } diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index ddd81daa2..b61e33688 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -6,6 +6,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" @@ -14,9 +18,9 @@ import ( "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" ) -func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { - return Create(s, func() *server.PeersUpdateManager { - return server.NewPeersUpdateManager(s.Metrics()) +func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager { + return Create(s, func() *update_channel.PeersUpdateManager { + return update_channel.NewPeersUpdateManager(s.Metrics()) }) } @@ -40,9 +44,9 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller { }) } -func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager { - return Create(s, func() *server.TimeBasedAuthSecretsManager { - return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager()) +func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager { + return Create(s, func() *grpc.TimeBasedAuthSecretsManager { + return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager()) }) } @@ -63,3 +67,15 @@ func (s *BaseServer) EphemeralManager() ephemeral.Manager { return manager.NewEphemeralManager(s.Store(), s.AccountManager()) }) } + +func (s *BaseServer) NetworkMapController() network_map.Controller { + return Create(s, func() *nmapcontroller.Controller { + return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController()) + }) +} + +func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { + return Create(s, func() *server.AccountRequestBuffer { + return server.NewAccountRequestBuffer(context.Background(), s.Store()) + }) +} diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 209a20065..409bdaaba 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -66,8 +66,7 @@ func (s *BaseServer) PeersManager() peers.Manager { func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, - s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) if err != nil { log.Fatalf("failed to create account manager: %v", err) } diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go new file mode 100644 index 000000000..9a4681eae --- /dev/null +++ b/management/internals/shared/grpc/conversion.go @@ -0,0 +1,352 @@ +package grpc + +import ( + "context" + "fmt" + + integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/proto" +) + +func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { + if config == nil { + return nil + } + + var stuns []*proto.HostConfig + for _, stun := range config.Stuns { + stuns = append(stuns, &proto.HostConfig{ + Uri: stun.URI, + Protocol: ToResponseProto(stun.Proto), + }) + } + + var turns []*proto.ProtectedHostConfig + if config.TURNConfig != nil { + for _, turn := range config.TURNConfig.Turns { + var username string + var password string + if turnCredentials != nil { + username = turnCredentials.Payload + password = turnCredentials.Signature + } else { + username = turn.Username + password = turn.Password + } + turns = append(turns, &proto.ProtectedHostConfig{ + HostConfig: &proto.HostConfig{ + Uri: turn.URI, + Protocol: ToResponseProto(turn.Proto), + }, + User: username, + Password: password, + }) + } + } + + var relayCfg *proto.RelayConfig + if config.Relay != nil && len(config.Relay.Addresses) > 0 { + relayCfg = &proto.RelayConfig{ + Urls: config.Relay.Addresses, + } + + if relayToken != nil { + relayCfg.TokenPayload = relayToken.Payload + relayCfg.TokenSignature = relayToken.Signature + } + } + + var signalCfg *proto.HostConfig + if config.Signal != nil { + signalCfg = &proto.HostConfig{ + Uri: config.Signal.URI, + Protocol: ToResponseProto(config.Signal.Proto), + } + } + + nbConfig := &proto.NetbirdConfig{ + Stuns: stuns, + Turns: turns, + Signal: signalCfg, + Relay: relayCfg, + } + + return nbConfig +} + +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig { + netmask, _ := network.Net.Mask.Size() + fqdn := peer.FQDN(dnsName) + return &proto.PeerConfig{ + Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network + SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, + Fqdn: fqdn, + RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, + LazyConnectionEnabled: settings.LazyConnectionEnabled, + } +} + +func ToSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { + response := &proto.SyncResponse{ + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), + NetworkMap: &proto.NetworkMap{ + Serial: networkMap.Network.CurrentSerial(), + Routes: toProtocolRoutes(networkMap.Routes), + DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), + }, + Checks: toProtocolChecks(ctx, checks), + } + + nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) + extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings) + response.NetbirdConfig = extendedConfig + + response.NetworkMap.PeerConfig = response.PeerConfig + + remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) + remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) + response.RemotePeers = remotePeers + response.NetworkMap.RemotePeers = remotePeers + response.RemotePeersIsEmpty = len(remotePeers) == 0 + response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty + + response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) + + firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) + response.NetworkMap.FirewallRules = firewallRules + response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 + + routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) + response.NetworkMap.RoutesFirewallRules = routesFirewallRules + response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 + + if networkMap.ForwardingRules != nil { + forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules)) + for _, rule := range networkMap.ForwardingRules { + forwardingRules = append(forwardingRules, rule.ToProto()) + } + response.NetworkMap.ForwardingRules = forwardingRules + } + + return response +} + +func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { + for _, rPeer := range peers { + dst = append(dst, &proto.RemotePeerConfig{ + WgPubKey: rPeer.Key, + AllowedIps: []string{rPeer.IP.String() + "/32"}, + SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, + Fqdn: rPeer.FQDN(dnsName), + AgentVersion: rPeer.Meta.WtVersion, + }) + } + return dst +} + +// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache +func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig { + protoUpdate := &proto.DNSConfig{ + ServiceEnable: update.ServiceEnable, + CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), + NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), + ForwarderPort: forwardPort, + } + + for _, zone := range update.CustomZones { + protoZone := convertToProtoCustomZone(zone) + protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) + } + + for _, nsGroup := range update.NameServerGroups { + cacheKey := nsGroup.ID + if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists { + protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup) + } else { + protoGroup := convertToProtoNameServerGroup(nsGroup) + cache.SetNameServerGroup(cacheKey, protoGroup) + protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup) + } + } + + return protoUpdate +} + +func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol { + switch configProto { + case nbconfig.UDP: + return proto.HostConfig_UDP + case nbconfig.DTLS: + return proto.HostConfig_DTLS + case nbconfig.HTTP: + return proto.HostConfig_HTTP + case nbconfig.HTTPS: + return proto.HostConfig_HTTPS + case nbconfig.TCP: + return proto.HostConfig_TCP + default: + panic(fmt.Errorf("unexpected config protocol type %v", configProto)) + } +} + +func toProtocolRoutes(routes []*route.Route) []*proto.Route { + protoRoutes := make([]*proto.Route, 0, len(routes)) + for _, r := range routes { + protoRoutes = append(protoRoutes, toProtocolRoute(r)) + } + return protoRoutes +} + +func toProtocolRoute(route *route.Route) *proto.Route { + return &proto.Route{ + ID: string(route.ID), + NetID: string(route.NetID), + Network: route.Network.String(), + Domains: route.Domains.ToPunycodeList(), + NetworkType: int64(route.NetworkType), + Peer: route.Peer, + Metric: int64(route.Metric), + Masquerade: route.Masquerade, + KeepRoute: route.KeepRoute, + SkipAutoApply: route.SkipAutoApply, + } +} + +// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. +func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + + fwRule := &proto.FirewallRule{ + PolicyID: []byte(rule.PolicyID), + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, + } + + if shouldUsePortRange(fwRule) { + fwRule.PortInfo = rule.PortRange.ToProto() + } + + result[i] = fwRule + } + return result +} + +// getProtoDirection converts the direction to proto.RuleDirection. +func getProtoDirection(direction int) proto.RuleDirection { + if direction == types.FirewallRuleDirectionOUT { + return proto.RuleDirection_OUT + } + return proto.RuleDirection_IN +} + +func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { + result := make([]*proto.RouteFirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + result[i] = &proto.RouteFirewallRule{ + SourceRanges: rule.SourceRanges, + Action: getProtoAction(rule.Action), + Destination: rule.Destination, + Protocol: getProtoProtocol(rule.Protocol), + PortInfo: getProtoPortInfo(rule), + IsDynamic: rule.IsDynamic, + Domains: rule.Domains.ToPunycodeList(), + PolicyID: []byte(rule.PolicyID), + RouteID: string(rule.RouteID), + } + } + + return result +} + +// getProtoAction converts the action to proto.RuleAction. +func getProtoAction(action string) proto.RuleAction { + if action == string(types.PolicyTrafficActionDrop) { + return proto.RuleAction_DROP + } + return proto.RuleAction_ACCEPT +} + +// getProtoProtocol converts the protocol to proto.RuleProtocol. +func getProtoProtocol(protocol string) proto.RuleProtocol { + switch types.PolicyRuleProtocolType(protocol) { + case types.PolicyRuleProtocolALL: + return proto.RuleProtocol_ALL + case types.PolicyRuleProtocolTCP: + return proto.RuleProtocol_TCP + case types.PolicyRuleProtocolUDP: + return proto.RuleProtocol_UDP + case types.PolicyRuleProtocolICMP: + return proto.RuleProtocol_ICMP + default: + return proto.RuleProtocol_UNKNOWN + } +} + +// getProtoPortInfo converts the port info to proto.PortInfo. +func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { + var portInfo proto.PortInfo + if rule.Port != 0 { + portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} + } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { + portInfo.PortSelection = &proto.PortInfo_Range_{ + Range: &proto.PortInfo_Range{ + Start: uint32(portRange.Start), + End: uint32(portRange.End), + }, + } + } + return &portInfo +} + +func shouldUsePortRange(rule *proto.FirewallRule) bool { + return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP) +} + +// Helper function to convert nbdns.CustomZone to proto.CustomZone +func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone { + protoZone := &proto.CustomZone{ + Domain: zone.Domain, + Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), + } + for _, record := range zone.Records { + protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ + Name: record.Name, + Type: int64(record.Type), + Class: record.Class, + TTL: int64(record.TTL), + RData: record.RData, + }) + } + return protoZone +} + +// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup +func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup { + protoGroup := &proto.NameServerGroup{ + Primary: nsGroup.Primary, + Domains: nsGroup.Domains, + SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, + NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)), + } + for _, ns := range nsGroup.NameServers { + protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{ + IP: ns.IP.String(), + Port: int64(ns.Port), + NSType: int64(ns.NSType), + }) + } + return protoGroup +} diff --git a/management/internals/shared/grpc/conversion_test.go b/management/internals/shared/grpc/conversion_test.go new file mode 100644 index 000000000..701271345 --- /dev/null +++ b/management/internals/shared/grpc/conversion_test.go @@ -0,0 +1,150 @@ +package grpc + +import ( + "fmt" + "net/netip" + "reflect" + "testing" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" +) + +func TestToProtocolDNSConfigWithCache(t *testing.T) { + var cache cache.DNSConfigCache + + // Create two different configs + config1 := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "example.com", + Records: []nbdns.SimpleRecord{ + {Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"}, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + ID: "group1", + Name: "Group 1", + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.8.8"), Port: 53}, + }, + }, + }, + } + + config2 := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "example.org", + Records: []nbdns.SimpleRecord{ + {Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"}, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + ID: "group2", + Name: "Group 2", + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.4.4"), Port: 53}, + }, + }, + }, + } + + // First run with config1 + result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort)) + + // Second run with config2 + result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort)) + + // Third run with config1 again + result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort)) + + // Verify that result1 and result3 are identical + if !reflect.DeepEqual(result1, result3) { + t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3) + } + + // Verify that result2 is different from result1 and result3 + if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) { + t.Errorf("Results should be different for different inputs") + } + + if _, exists := cache.GetNameServerGroup("group1"); !exists { + t.Errorf("Cache should contain name server group 'group1'") + } + + if _, exists := cache.GetNameServerGroup("group2"); !exists { + t.Errorf("Cache should contain name server group 'group2'") + } +} + +func BenchmarkToProtocolDNSConfig(b *testing.B) { + sizes := []int{10, 100, 1000} + + for _, size := range sizes { + testData := generateTestData(size) + + b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) { + cache := &cache.DNSConfigCache{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort)) + } + }) + + b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + cache := &cache.DNSConfigCache{} + toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort)) + } + }) + } +} + +func generateTestData(size int) nbdns.Config { + config := nbdns.Config{ + ServiceEnable: true, + CustomZones: make([]nbdns.CustomZone, size), + NameServerGroups: make([]*nbdns.NameServerGroup, size), + } + + for i := 0; i < size; i++ { + config.CustomZones[i] = nbdns.CustomZone{ + Domain: fmt.Sprintf("domain%d.com", i), + Records: []nbdns.SimpleRecord{ + { + Name: fmt.Sprintf("record%d", i), + Type: 1, + Class: "IN", + TTL: 3600, + RData: "192.168.1.1", + }, + }, + } + + config.NameServerGroups[i] = &nbdns.NameServerGroup{ + ID: fmt.Sprintf("group%d", i), + Primary: i == 0, + Domains: []string{fmt.Sprintf("domain%d.com", i)}, + SearchDomainsEnabled: true, + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + Port: 53, + NSType: 1, + }, + }, + } + } + + return config +} diff --git a/management/server/loginfilter.go b/management/internals/shared/grpc/loginfilter.go similarity index 99% rename from management/server/loginfilter.go rename to management/internals/shared/grpc/loginfilter.go index 8604af6e2..59f69dd90 100644 --- a/management/server/loginfilter.go +++ b/management/internals/shared/grpc/loginfilter.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "hash/fnv" diff --git a/management/server/loginfilter_test.go b/management/internals/shared/grpc/loginfilter_test.go similarity index 99% rename from management/server/loginfilter_test.go rename to management/internals/shared/grpc/loginfilter_test.go index 65782dd9d..8b26e14ab 100644 --- a/management/server/loginfilter_test.go +++ b/management/internals/shared/grpc/loginfilter_test.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "hash/fnv" diff --git a/management/server/grpcserver.go b/management/internals/shared/grpc/server.go similarity index 76% rename from management/server/grpcserver.go rename to management/internals/shared/grpc/server.go index d3d94443a..08a840316 100644 --- a/management/server/grpcserver.go +++ b/management/internals/shared/grpc/server.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "context" @@ -22,7 +22,7 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/grpc/status" - integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/peers/ephemeral" @@ -51,13 +51,13 @@ const ( defaultSyncLim = 1000 ) -// GRPCServer an instance of a Management gRPC API server -type GRPCServer struct { +// Server an instance of a Management gRPC API server +type Server struct { accountManager account.Manager settingsManager settings.Manager wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peersUpdateManager *PeersUpdateManager + peersUpdateManager network_map.PeersUpdateManager config *nbconfig.Config secretsManager SecretsManager appMetrics telemetry.AppMetrics @@ -69,23 +69,27 @@ type GRPCServer struct { blockPeersWithSameConfig bool integratedPeerValidator integrated_validator.IntegratedValidator + loginFilter *loginFilter + + networkMapController network_map.Controller + syncSem atomic.Int32 syncLim int32 } // NewServer creates a new Management server func NewServer( - ctx context.Context, config *nbconfig.Config, accountManager account.Manager, settingsManager settings.Manager, - peersUpdateManager *PeersUpdateManager, + peersUpdateManager network_map.PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, ephemeralManager ephemeral.Manager, authManager auth.Manager, integratedPeerValidator integrated_validator.IntegratedValidator, -) (*GRPCServer, error) { + networkMapController network_map.Controller, +) (*Server, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, err @@ -94,7 +98,7 @@ func NewServer( if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { - return int64(len(peersUpdateManager.peerChannels)) + return int64(peersUpdateManager.CountStreams()) }) if err != nil { return nil, err @@ -115,7 +119,7 @@ func NewServer( } } - return &GRPCServer{ + return &Server{ wgKey: key, // peerKey -> event channel peersUpdateManager: peersUpdateManager, @@ -129,12 +133,15 @@ func NewServer( logBlockedPeers: logBlockedPeers, blockPeersWithSameConfig: blockPeersWithSameConfig, integratedPeerValidator: integratedPeerValidator, + networkMapController: networkMapController, + + loginFilter: newLoginFilter(), syncLim: syncLim, }, nil } -func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { +func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { ip := "" p, ok := peer.FromContext(ctx) if ok { @@ -171,7 +178,7 @@ func getRealIP(ctx context.Context) net.IP { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) -func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { if s.syncSem.Load() >= s.syncLim { return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later") } @@ -191,7 +198,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi sRealIP := realIP.String() peerMeta := extractPeerMeta(ctx, syncReq.GetMeta()) metahashed := metaHash(peerMeta, sRealIP) - if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() } @@ -245,35 +252,29 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) } - peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) + metahash := metaHash(peerMeta, realIP.String()) + s.loginFilter.addLogin(peerKey.String(), metahash) + + peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) s.syncSem.Add(-1) return mapError(ctx, err) } - log.WithContext(ctx).Debugf("Sync: SyncAndMarkPeer since start %v", time.Since(reqStart)) - - err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv) + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) return err } - log.WithContext(ctx).Debugf("Sync: sendInitialSync since start %v", time.Since(reqStart)) updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) - log.WithContext(ctx).Debugf("Sync: CreateChannel since start %v", time.Since(reqStart)) - s.ephemeralManager.OnPeerConnected(ctx, peer) - log.WithContext(ctx).Debugf("Sync: OnPeerConnected since start %v", time.Since(reqStart)) - s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) - log.WithContext(ctx).Debugf("Sync: SetupRefresh since start %v", time.Since(reqStart)) - if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) } @@ -281,15 +282,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi unlock() unlock = nil - log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart)) - s.syncSem.Add(-1) return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } // handleUpdates sends updates to the connected peer until the updates channel is closed. -func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) for { select { @@ -323,7 +322,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. -func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) if err != nil { s.cancelPeerRoutines(ctx, accountID, peer) @@ -341,7 +340,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w return nil } -func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { +func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { unlock := s.acquirePeerLockByUID(ctx, peer.Key) defer unlock() @@ -356,7 +355,7 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) } -func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { +func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) { if s.authManager == nil { return "", status.Errorf(codes.Internal, "missing auth manager") } @@ -390,7 +389,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string return userAuth.UserId, nil } -func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { +func (s *Server) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID) start := time.Now() @@ -498,7 +497,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee } } -func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { +func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) @@ -517,7 +516,7 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa // In case it is, the login is successful // In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. // In case of the successful registration login is also successful -func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { reqStart := time.Now() realIP := getRealIP(ctx) sRealIP := realIP.String() @@ -531,7 +530,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p peerMeta := extractPeerMeta(ctx, loginReq.GetMeta()) metahashed := metaHash(peerMeta, sRealIP) - if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if s.logBlockedPeers { log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) } @@ -628,7 +627,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p }, nil } -func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { +func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { var relayToken *Token var err error if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { @@ -647,7 +646,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings), + PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings), Checks: toProtocolChecks(ctx, postureChecks), } @@ -659,7 +658,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer // // The user ID can be empty if the token is not provided, which is acceptable if the peer is already // registered or if it uses a setup key to register. -func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { +func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { userID := "" if loginReq.GetJwtToken() != "" { var err error @@ -679,166 +678,13 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR return userID, nil } -func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol { - switch configProto { - case nbconfig.UDP: - return proto.HostConfig_UDP - case nbconfig.DTLS: - return proto.HostConfig_DTLS - case nbconfig.HTTP: - return proto.HostConfig_HTTP - case nbconfig.HTTPS: - return proto.HostConfig_HTTPS - case nbconfig.TCP: - return proto.HostConfig_TCP - default: - panic(fmt.Errorf("unexpected config protocol type %v", configProto)) - } -} - -func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { - if config == nil { - return nil - } - - var stuns []*proto.HostConfig - for _, stun := range config.Stuns { - stuns = append(stuns, &proto.HostConfig{ - Uri: stun.URI, - Protocol: ToResponseProto(stun.Proto), - }) - } - - var turns []*proto.ProtectedHostConfig - if config.TURNConfig != nil { - for _, turn := range config.TURNConfig.Turns { - var username string - var password string - if turnCredentials != nil { - username = turnCredentials.Payload - password = turnCredentials.Signature - } else { - username = turn.Username - password = turn.Password - } - turns = append(turns, &proto.ProtectedHostConfig{ - HostConfig: &proto.HostConfig{ - Uri: turn.URI, - Protocol: ToResponseProto(turn.Proto), - }, - User: username, - Password: password, - }) - } - } - - var relayCfg *proto.RelayConfig - if config.Relay != nil && len(config.Relay.Addresses) > 0 { - relayCfg = &proto.RelayConfig{ - Urls: config.Relay.Addresses, - } - - if relayToken != nil { - relayCfg.TokenPayload = relayToken.Payload - relayCfg.TokenSignature = relayToken.Signature - } - } - - var signalCfg *proto.HostConfig - if config.Signal != nil { - signalCfg = &proto.HostConfig{ - Uri: config.Signal.URI, - Protocol: ToResponseProto(config.Signal.Proto), - } - } - - nbConfig := &proto.NetbirdConfig{ - Stuns: stuns, - Turns: turns, - Signal: signalCfg, - Relay: relayCfg, - } - - return nbConfig -} - -func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig { - netmask, _ := network.Net.Mask.Size() - fqdn := peer.FQDN(dnsName) - return &proto.PeerConfig{ - Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network - SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, - Fqdn: fqdn, - RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, - LazyConnectionEnabled: settings.LazyConnectionEnabled, - } -} - -func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { - response := &proto.SyncResponse{ - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), - NetworkMap: &proto.NetworkMap{ - Serial: networkMap.Network.CurrentSerial(), - Routes: toProtocolRoutes(networkMap.Routes), - DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), - }, - Checks: toProtocolChecks(ctx, checks), - } - - nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) - extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings) - response.NetbirdConfig = extendedConfig - - response.NetworkMap.PeerConfig = response.PeerConfig - - remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) - remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) - response.RemotePeers = remotePeers - response.NetworkMap.RemotePeers = remotePeers - response.RemotePeersIsEmpty = len(remotePeers) == 0 - response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty - - response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) - - firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) - response.NetworkMap.FirewallRules = firewallRules - response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 - - routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) - response.NetworkMap.RoutesFirewallRules = routesFirewallRules - response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 - - if networkMap.ForwardingRules != nil { - forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules)) - for _, rule := range networkMap.ForwardingRules { - forwardingRules = append(forwardingRules, rule.ToProto()) - } - response.NetworkMap.ForwardingRules = forwardingRules - } - - return response -} - -func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { - for _, rPeer := range peers { - dst = append(dst, &proto.RemotePeerConfig{ - WgPubKey: rPeer.Key, - AllowedIps: []string{rPeer.IP.String() + "/32"}, - SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, - Fqdn: rPeer.FQDN(dnsName), - AgentVersion: rPeer.Meta.WtVersion, - }) - } - return dst -} - // IsHealthy indicates whether the service is healthy -func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) { +func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) { return &proto.Empty{}, nil } // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization -func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { +func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error { var err error var turnToken *Token @@ -862,19 +708,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "error handling request") } - peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID) + peerGroups, err := s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, peer.AccountID, peer.ID) if err != nil { return status.Errorf(codes.Internal, "failed to get peer groups %s", err) } - // Get all peers in the account for forwarder port computation - allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "") - if err != nil { - return fmt.Errorf("get account peers: %w", err) - } - dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion) - - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) + plainResp := ToSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { @@ -899,7 +738,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p // GetDeviceAuthorizationFlow returns a device authorization flow information // This is used for initiating an Oauth 2 device authorization grant flow // which will be used by our clients to Login -func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey) start := time.Now() defer func() { @@ -957,7 +796,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. // GetPKCEAuthorizationFlow returns a pkce authorization flow information // This is used for initiating an Oauth 2 pkce authorization grant flow // which will be used by our clients to Login -func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey) start := time.Now() defer func() { @@ -1012,7 +851,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En // SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected, // peer's under the same account of any updates. -func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { +func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { realIP := getRealIP(ctx) log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String()) @@ -1037,7 +876,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) return &proto.Empty{}, nil } -func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { +func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey) start := time.Now() diff --git a/management/internals/shared/grpc/server_test.go b/management/internals/shared/grpc/server_test.go new file mode 100644 index 000000000..9867b38e3 --- /dev/null +++ b/management/internals/shared/grpc/server_test.go @@ -0,0 +1,106 @@ +package grpc + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/server/config" + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" +) + +func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { + testingServerKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err) + } + + testingClientKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err) + } + + testCases := []struct { + name string + inputFlow *config.DeviceAuthorizationFlow + expectedFlow *mgmtProto.DeviceAuthorizationFlow + expectedErrFunc require.ErrorAssertionFunc + expectedErrMSG string + expectedComparisonFunc require.ComparisonAssertionFunc + expectedComparisonMSG string + }{ + { + name: "Testing No Device Flow Config", + inputFlow: nil, + expectedErrFunc: require.Error, + expectedErrMSG: "should return error", + }, + { + name: "Testing Invalid Device Flow Provider Config", + inputFlow: &config.DeviceAuthorizationFlow{ + Provider: "NoNe", + ProviderConfig: config.ProviderConfig{ + ClientID: "test", + }, + }, + expectedErrFunc: require.Error, + expectedErrMSG: "should return error", + }, + { + name: "Testing Full Device Flow Config", + inputFlow: &config.DeviceAuthorizationFlow{ + Provider: "hosted", + ProviderConfig: config.ProviderConfig{ + ClientID: "test", + }, + }, + expectedFlow: &mgmtProto.DeviceAuthorizationFlow{ + Provider: 0, + ProviderConfig: &mgmtProto.ProviderConfig{ + ClientID: "test", + }, + }, + expectedErrFunc: require.NoError, + expectedErrMSG: "should not return error", + expectedComparisonFunc: require.Equal, + expectedComparisonMSG: "should match", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + mgmtServer := &Server{ + wgKey: testingServerKey, + config: &config.Config{ + DeviceAuthorizationFlow: testCase.inputFlow, + }, + } + + message := &mgmtProto.DeviceAuthorizationFlowRequest{} + + encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message) + require.NoError(t, err, "should be able to encrypt message") + + resp, err := mgmtServer.GetDeviceAuthorizationFlow( + context.TODO(), + &mgmtProto.EncryptedMessage{ + WgPubKey: testingClientKey.PublicKey().String(), + Body: encryptedMSG, + }, + ) + testCase.expectedErrFunc(t, err, testCase.expectedErrMSG) + if testCase.expectedComparisonFunc != nil { + flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{} + + err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp) + require.NoError(t, err, "should be able to decrypt") + + testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG) + testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG) + } + }) + } +} diff --git a/management/server/token_mgr.go b/management/internals/shared/grpc/token_mgr.go similarity index 93% rename from management/server/token_mgr.go rename to management/internals/shared/grpc/token_mgr.go index f9293e7a8..e9770db41 100644 --- a/management/server/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "context" @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/settings" @@ -37,7 +38,7 @@ type TimeBasedAuthSecretsManager struct { relayCfg *nbconfig.Relay turnHmacToken *auth.TimedHMAC relayHmacToken *authv2.Generator - updateManager *PeersUpdateManager + updateManager network_map.PeersUpdateManager settingsManager settings.Manager groupsManager groups.Manager turnCancelMap map[string]chan struct{} @@ -46,7 +47,7 @@ type TimeBasedAuthSecretsManager struct { type Token auth.Token -func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { +func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { mgr := &TimeBasedAuthSecretsManager{ updateManager: updateManager, turnCfg: turnCfg, @@ -227,7 +228,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) } func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) { @@ -251,7 +252,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) } func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) { diff --git a/management/server/token_mgr_test.go b/management/internals/shared/grpc/token_mgr_test.go similarity index 94% rename from management/server/token_mgr_test.go rename to management/internals/shared/grpc/token_mgr_test.go index 5c956dc31..06d28d05b 100644 --- a/management/server/token_mgr_test.go +++ b/management/internals/shared/grpc/token_mgr_test.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "context" @@ -13,6 +13,8 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/settings" @@ -31,7 +33,7 @@ var TurnTestHost = &config.Host{ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { ttl := util.Duration{Duration: time.Hour} secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) + peersManager := update_channel.NewPeersUpdateManager(nil) rc := &config.Relay{ Addresses: []string{"localhost:0"}, @@ -80,7 +82,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { ttl := util.Duration{Duration: 2 * time.Second} secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) + peersManager := update_channel.NewPeersUpdateManager(nil) peer := "some_peer" updateChannel := peersManager.CreateChannel(context.Background(), peer) @@ -116,7 +118,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { t.Errorf("expecting peer to be present in the relay cancel map, got not present") } - var updates []*UpdateMessage + var updates []*network_map.UpdateMessage loop: for timeout := time.After(5 * time.Second); ; { @@ -185,7 +187,7 @@ loop: func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { ttl := util.Duration{Duration: time.Hour} secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) + peersManager := update_channel.NewPeersUpdateManager(nil) peer := "some_peer" rc := &config.Relay{ diff --git a/management/server/account.go b/management/server/account.go index f5a5c7b7a..a4b2a752b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -11,10 +11,8 @@ import ( "reflect" "regexp" "slices" - "strconv" "strings" "sync" - "sync/atomic" "time" cacheStore "github.com/eko/gocache/lib/v4/store" @@ -26,6 +24,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter/hook" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" @@ -53,9 +52,6 @@ const ( peerSchedulerRetryInterval = 3 * time.Second emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" - - envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" - envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" ) type userLoggedInOnce bool @@ -71,7 +67,7 @@ type DefaultAccountManager struct { cacheMux sync.Mutex // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded cacheLoading map[string]chan struct{} - peersUpdateManager *PeersUpdateManager + networkMapController network_map.Controller idpManager idp.Manager cacheManager *nbcache.AccountUserDataCache externalCacheManager nbcache.UserDataCache @@ -91,8 +87,7 @@ type DefaultAccountManager struct { singleAccountMode bool // singleAccountModeDomain is a domain to use in singleAccountMode setup singleAccountModeDomain string - // dnsDomain is used for peer resolution. This is appended to the peer's name - dnsDomain string + peerLoginExpiry Scheduler peerInactivityExpiry Scheduler @@ -106,19 +101,11 @@ type DefaultAccountManager struct { permissionsManager permissions.Manager - accountUpdateLocks sync.Map - updateAccountPeersBufferInterval atomic.Int64 - - loginFilter *loginFilter - disableDefaultPolicy bool - - holder *types.Holder - - expNewNetworkMap bool - expNewNetworkMapAIDs map[string]struct{} } +var _ account.Manager = (*DefaultAccountManager)(nil) + func isUniqueConstraintError(err error) bool { switch { case strings.Contains(err.Error(), "(SQLSTATE 23505)"), @@ -185,10 +172,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups [] func BuildManager( ctx context.Context, store store.Store, - peersUpdateManager *PeersUpdateManager, + networkMapController network_map.Controller, idpManager idp.Manager, singleAccountModeDomain string, - dnsDomain string, eventStore activity.Store, geo geolocation.Geolocation, userDeleteFromIDPEnabled bool, @@ -204,27 +190,14 @@ func BuildManager( log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start)) }() - newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder)) - if err != nil { - log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err) - newNetworkMapBuilder = false - } - - ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",") - expIDs := make(map[string]struct{}, len(ids)) - for _, id := range ids { - expIDs[id] = struct{}{} - } - am := &DefaultAccountManager{ Store: store, geo: geo, - peersUpdateManager: peersUpdateManager, + networkMapController: networkMapController, idpManager: idpManager, ctx: context.Background(), cacheMux: sync.Mutex{}, cacheLoading: map[string]chan struct{}{}, - dnsDomain: dnsDomain, eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), peerInactivityExpiry: NewDefaultScheduler(), @@ -235,15 +208,10 @@ func BuildManager( proxyController: proxyController, settingsManager: settingsManager, permissionsManager: permissionsManager, - loginFilter: newLoginFilter(), disableDefaultPolicy: disableDefaultPolicy, - holder: types.NewHolder(), - - expNewNetworkMap: newNetworkMapBuilder, - expNewNetworkMapAIDs: expIDs, } - am.startWarmup(ctx) + am.networkMapController.StartWarmup(ctx) accountsCounter, err := store.GetAccountsCounter(ctx) if err != nil { @@ -291,32 +259,6 @@ func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) { am.ephemeralManager = em } -func (am *DefaultAccountManager) startWarmup(ctx context.Context) { - var initialInterval int64 - intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS") - interval, err := strconv.Atoi(intervalStr) - if err != nil { - initialInterval = 1 - log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err) - } else { - initialInterval = int64(interval) * 10 - go func() { - startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S") - startupPeriod, err := strconv.Atoi(startupPeriodStr) - if err != nil { - startupPeriod = 1 - log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err) - } - time.Sleep(time.Duration(startupPeriod) * time.Second) - am.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond)) - log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval) - }() - } - am.updateAccountPeersBufferInterval.Store(initialInterval) - log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval) - -} - func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager { return am.externalCacheManager } @@ -419,9 +361,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } go am.UpdateAccountPeers(ctx, accountID) } @@ -1504,10 +1443,6 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } if removedGroupAffectsPeers || newGroupsAffectsPeers { - if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil { - return err - } - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) } @@ -1667,14 +1602,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } -func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool { - return am.loginFilter.allowLogin(wgPubKey, metahash) -} - -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { - return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) + return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) } err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) @@ -1682,10 +1613,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } - metahash := metaHash(meta, realIP.String()) - am.loginFilter.addLogin(peerPubKey, metahash) - - return peer, netMap, postureChecks, nil + return peer, netMap, postureChecks, dnsfwdPort, nil } func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { @@ -1702,41 +1630,19 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st return err } - _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) + _, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) if err != nil { - return mapError(ctx, err) + return err } return nil } -// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers() -func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { - return am.peersUpdateManager.GetAllConnectedPeers(), nil -} - -// HasConnectedChannel returns true if peers has channel in update manager, otherwise false -func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool { - return am.peersUpdateManager.HasChannel(peerID) -} - var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) func isDomainValid(domain string) bool { return invalidDomainRegexp.MatchString(domain) } -// GetDNSDomain returns the configured dnsDomain -func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string { - if settings == nil { - return am.dnsDomain - } - if settings.DNSDomain == "" { - return am.dnsDomain - } - - return settings.DNSDomain -} - func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) { peers := []*nbpeer.Peer{} log.WithContext(ctx).Debugf("invalidating peers %v for account %s", peerIDs, accountID) @@ -2159,8 +2065,7 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us if err != nil { return err } - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.OnPeerUpdated(peer.AccountID, peer) } return nil } @@ -2208,7 +2113,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti if err != nil { return fmt.Errorf("get account settings: %w", err) } - dnsDomain := am.GetDNSDomain(settings) + dnsDomain := am.networkMapController.GetDNSDomain(settings) eventMeta := peer.EventMeta(dnsDomain) oldIP := peer.IP.String() diff --git a/management/server/account/manager.go b/management/server/account/manager.go index db377865a..7c174a481 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -89,7 +89,6 @@ type Manager interface { SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - GetDNSDomain(settings *types.Settings) string StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) @@ -97,10 +96,8 @@ type Manager interface { GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) - LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API - GetAllConnectedPeers() (map[string]struct{}, error) - HasConnectedChannel(peerID string) bool + LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) @@ -110,7 +107,7 @@ type Manager interface { UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) @@ -127,6 +124,4 @@ type Manager interface { GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) SetEphemeralManager(em ephemeral.Manager) - AllowSync(string, uint64) bool - RecalculateNetworkMapCache(ctx context.Context, accountId string) error } diff --git a/management/server/account/request_buffer.go b/management/server/account/request_buffer.go new file mode 100644 index 000000000..eced1929f --- /dev/null +++ b/management/server/account/request_buffer.go @@ -0,0 +1,11 @@ +package account + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/types" +) + +type RequestBuffer interface { + GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) +} diff --git a/management/server/account_test.go b/management/server/account_test.go index 200ba6b98..ee9950796 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -22,6 +22,9 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" @@ -406,7 +409,7 @@ func TestNewAccount(t *testing.T) { } func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -603,7 +606,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain) @@ -644,7 +647,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { userId := "user-id" domain := "test.domain" _ = newAccountWithId(context.Background(), "", userId, domain, false) - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") @@ -705,7 +708,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { } func TestAccountManager_PrivateAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -731,7 +734,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) { } func TestAccountManager_SetOrUpdateDomain(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -768,7 +771,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { } func TestAccountManager_GetAccountByUserID(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -805,7 +808,7 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string) } func TestAccountManager_GetAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -843,7 +846,7 @@ func TestAccountManager_GetAccount(t *testing.T) { } func TestAccountManager_DeleteAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -924,7 +927,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { DomainCategory: types.PublicCategory, } - am, err := createManager(b) + am, _, err := createManager(b) if err != nil { b.Fatal(err) return @@ -1016,7 +1019,7 @@ func genUsers(p string, n int) map[string]*types.User { } func TestAccountManager_AddPeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1086,7 +1089,7 @@ func TestAccountManager_AddPeer(t *testing.T) { } func TestAccountManager_AddPeerWithUserID(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1155,7 +1158,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { } func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_SaveGroup(t) } @@ -1164,7 +1167,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { } func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) group := types.Group{ ID: "groupA", @@ -1190,8 +1193,8 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { }, true) require.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) wg := sync.WaitGroup{} wg.Add(1) @@ -1215,7 +1218,7 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { } func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_DeletePolicy(t) } @@ -1224,10 +1227,10 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { } func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { - manager, account, peer1, _, _ := setupNetworkMapTest(t) + manager, updateManager, account, peer1, _, _ := setupNetworkMapTest(t) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) // Ensure that we do not receive an update message before the policy is deleted time.Sleep(time.Second) @@ -1258,7 +1261,7 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { } func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_SavePolicy(t) } @@ -1267,7 +1270,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { - manager, account, peer1, peer2, _ := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, _ := setupNetworkMapTest(t) group := types.Group{ AccountID: account.Id, @@ -1280,8 +1283,8 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { return } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) wg := sync.WaitGroup{} wg.Add(1) @@ -1316,7 +1319,7 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_DeletePeer(t) } @@ -1325,7 +1328,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { } func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { - manager, account, peer1, _, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, _, peer3 := setupNetworkMapTest(t) group := types.Group{ ID: "groupA", @@ -1354,8 +1357,11 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + // We need to sleep to wait for the buffer peer update + time.Sleep(300 * time.Millisecond) + + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) wg := sync.WaitGroup{} wg.Add(1) @@ -1378,7 +1384,7 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { } func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_DeleteGroup(t) } @@ -1387,10 +1393,10 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { } func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", @@ -1457,7 +1463,7 @@ func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { } func TestAccountManager_DeletePeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1538,7 +1544,7 @@ func getEvent(t *testing.T, accountID string, manager nbAccount.Manager, eventTy } func TestGetUsersFromAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -1837,7 +1843,7 @@ func hasNilField(x interface{}) error { } func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1852,7 +1858,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { } func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1908,7 +1914,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { } func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1951,7 +1957,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. } func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -2013,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test } func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -2677,7 +2683,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { func TestAccount_SetJWTGroups(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", "postgres") - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") // create a new account @@ -2919,18 +2925,18 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { // Fatalf(format string, args ...interface{}) // } -func createManager(t testing.TB) (*DefaultAccountManager, error) { +func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) { t.Helper() store, err := createStore(t) if err != nil { - return nil, err + return nil, nil, err } eventStore := &activity.InMemoryEventStore{} metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) if err != nil { - return nil, err + return nil, nil, err } ctrl := gomock.NewController(t) @@ -2948,12 +2954,17 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) { permissionsManager := permissions.NewManager(store) - manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + manager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { - return nil, err + return nil, nil, err } - return manager, nil + return manager, updateManager, nil } func createStore(t testing.TB) (store.Store, error) { @@ -2982,10 +2993,10 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { } } -func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { t.Helper() - manager, err := createManager(t) + manager, updateManager, err := createManager(t) if err != nil { t.Fatal(err) } @@ -3026,10 +3037,10 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, peer2 := getPeer(manager, setupKey) peer3 := getPeer(manager, setupKey) - return manager, account, peer1, peer2, peer3 + return manager, updateManager, account, peer1, peer2, peer3 } -func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { +func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { case msg := <-updateMessage: @@ -3039,7 +3050,7 @@ func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessag } } -func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { +func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { @@ -3077,7 +3088,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -3086,16 +3097,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { if err != nil { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { - _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) + _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) assert.NoError(b, err) } @@ -3140,7 +3149,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -3149,11 +3158,10 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { if err != nil { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels b.ResetTimer() start := time.Now() @@ -3210,7 +3218,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -3219,11 +3227,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { if err != nil { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels b.ResetTimer() start := time.Now() @@ -3282,7 +3289,7 @@ func TestMain(m *testing.M) { } func Test_GetCreateAccountByPrivateDomain(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -3328,7 +3335,7 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) { } func Test_UpdateToPrimaryAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -3358,7 +3365,7 @@ func Test_UpdateToPrimaryAccount(t *testing.T) { } func TestDefaultAccountManager_IsCacheCold(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) t.Run("memory cache", func(t *testing.T) { @@ -3408,7 +3415,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { } func TestPropagateUserGroupMemberships(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) ctx := context.Background() @@ -3525,7 +3532,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { } func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") @@ -3557,7 +3564,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { } func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") @@ -3596,7 +3603,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { } func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -3663,7 +3670,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { } func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -3709,7 +3716,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { } func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/management/server/dns.go b/management/server/dns.go index decc5175d..baf6debc3 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -3,54 +3,23 @@ package server import ( "context" "slices" - "sync" log "github.com/sirupsen/logrus" - "golang.org/x/mod/semver" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) const ( dnsForwarderPort = nbdns.ForwarderServerPort - oldForwarderPort = nbdns.ForwarderClientPort ) -const dnsForwarderPortMinVersion = "v0.59.0" - -// DNSConfigCache is a thread-safe cache for DNS configuration components -type DNSConfigCache struct { - NameServerGroups sync.Map -} - -// GetNameServerGroup retrieves a cached name server group -func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { - if c == nil { - return nil, false - } - if value, ok := c.NameServerGroups.Load(key); ok { - return value.(*proto.NameServerGroup), true - } - return nil, false -} - -// SetNameServerGroup stores a name server group in the cache -func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) { - if c == nil { - return - } - c.NameServerGroups.Store(key, value) -} - // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) @@ -117,9 +86,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -194,99 +160,3 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID return validateGroups(settings.DisabledManagementGroups, groups) } - -// computeForwarderPort checks if all peers in the account have updated to a specific version or newer. -// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. -func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { - if len(peers) == 0 { - return int64(oldForwarderPort) - } - - reqVer := semver.Canonical(requiredVersion) - - // Check if all peers have the required version or newer - for _, peer := range peers { - - // Development version is always supported - if peer.Meta.WtVersion == "development" { - continue - } - peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) - if peerVersion == "" { - // If any peer doesn't have version info, return 0 - return int64(oldForwarderPort) - } - - // Compare versions - if semver.Compare(peerVersion, reqVer) < 0 { - return int64(oldForwarderPort) - } - } - - // All peers have the required version or newer - return int64(dnsForwarderPort) -} - -// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache -func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig { - protoUpdate := &proto.DNSConfig{ - ServiceEnable: update.ServiceEnable, - CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), - NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), - ForwarderPort: forwardPort, - } - - for _, zone := range update.CustomZones { - protoZone := convertToProtoCustomZone(zone) - protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) - } - - for _, nsGroup := range update.NameServerGroups { - cacheKey := nsGroup.ID - if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists { - protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup) - } else { - protoGroup := convertToProtoNameServerGroup(nsGroup) - cache.SetNameServerGroup(cacheKey, protoGroup) - protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup) - } - } - - return protoUpdate -} - -// Helper function to convert nbdns.CustomZone to proto.CustomZone -func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone { - protoZone := &proto.CustomZone{ - Domain: zone.Domain, - Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), - } - for _, record := range zone.Records { - protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ - Name: record.Name, - Type: int64(record.Type), - Class: record.Class, - TTL: int64(record.TTL), - RData: record.RData, - }) - } - return protoZone -} - -// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup -func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup { - protoGroup := &proto.NameServerGroup{ - Primary: nsGroup.Primary, - Domains: nsGroup.Domains, - SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, - NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)), - } - for _, ns := range nsGroup.NameServers { - protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{ - IP: ns.IP.String(), - Port: int64(ns.Port), - NSType: int64(ns.NSType), - }) - } - return protoGroup -} diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 96f73a390..356a2f640 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -2,9 +2,7 @@ package server import ( "context" - "fmt" "net/netip" - "reflect" "testing" "time" @@ -12,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -218,7 +218,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { // return empty extra settings for expected calls to UpdateAccountPeers settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock()) + + return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createDNSStore(t *testing.T) (store.Store, error) { @@ -344,247 +350,8 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account return am.Store.GetAccount(context.Background(), account.Id) } -func generateTestData(size int) nbdns.Config { - config := nbdns.Config{ - ServiceEnable: true, - CustomZones: make([]nbdns.CustomZone, size), - NameServerGroups: make([]*nbdns.NameServerGroup, size), - } - - for i := 0; i < size; i++ { - config.CustomZones[i] = nbdns.CustomZone{ - Domain: fmt.Sprintf("domain%d.com", i), - Records: []nbdns.SimpleRecord{ - { - Name: fmt.Sprintf("record%d", i), - Type: 1, - Class: "IN", - TTL: 3600, - RData: "192.168.1.1", - }, - }, - } - - config.NameServerGroups[i] = &nbdns.NameServerGroup{ - ID: fmt.Sprintf("group%d", i), - Primary: i == 0, - Domains: []string{fmt.Sprintf("domain%d.com", i)}, - SearchDomainsEnabled: true, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("8.8.8.8"), - Port: 53, - NSType: 1, - }, - }, - } - } - - return config -} - -func BenchmarkToProtocolDNSConfig(b *testing.B) { - sizes := []int{10, 100, 1000} - - for _, size := range sizes { - testData := generateTestData(size) - - b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) { - cache := &DNSConfigCache{} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) - } - }) - - b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) - } - }) - } -} - -func TestToProtocolDNSConfigWithCache(t *testing.T) { - var cache DNSConfigCache - - // Create two different configs - config1 := nbdns.Config{ - ServiceEnable: true, - CustomZones: []nbdns.CustomZone{ - { - Domain: "example.com", - Records: []nbdns.SimpleRecord{ - {Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"}, - }, - }, - }, - NameServerGroups: []*nbdns.NameServerGroup{ - { - ID: "group1", - Name: "Group 1", - NameServers: []nbdns.NameServer{ - {IP: netip.MustParseAddr("8.8.8.8"), Port: 53}, - }, - }, - }, - } - - config2 := nbdns.Config{ - ServiceEnable: true, - CustomZones: []nbdns.CustomZone{ - { - Domain: "example.org", - Records: []nbdns.SimpleRecord{ - {Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"}, - }, - }, - }, - NameServerGroups: []*nbdns.NameServerGroup{ - { - ID: "group2", - Name: "Group 2", - NameServers: []nbdns.NameServer{ - {IP: netip.MustParseAddr("8.8.4.4"), Port: 53}, - }, - }, - }, - } - - // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) - - // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort)) - - // Third run with config1 again - result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) - - // Verify that result1 and result3 are identical - if !reflect.DeepEqual(result1, result3) { - t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3) - } - - // Verify that result2 is different from result1 and result3 - if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) { - t.Errorf("Results should be different for different inputs") - } - - if _, exists := cache.GetNameServerGroup("group1"); !exists { - t.Errorf("Cache should contain name server group 'group1'") - } - - if _, exists := cache.GetNameServerGroup("group2"); !exists { - t.Errorf("Cache should contain name server group 'group2'") - } -} - -func TestComputeForwarderPort(t *testing.T) { - // Test with empty peers list - peers := []*nbpeer.Peer{} - result := computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) - } - - // Test with peers that have old versions - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.57.0", - }, - }, - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.26.0", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) - } - - // Test with peers that have new versions - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.59.0", - }, - }, - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.59.0", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(dnsForwarderPort) { - t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) - } - - // Test with peers that have mixed versions - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.59.0", - }, - }, - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.57.0", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) - } - - // Test with peers that have empty version - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) - } - - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "development", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result == int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) - } - - // Test with peers that have unknown version string - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "unknown", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) - } -} - func TestDNSAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{ { @@ -600,9 +367,9 @@ func TestDNSAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates diff --git a/management/server/event_test.go b/management/server/event_test.go index 8c56fd3f6..420e69866 100644 --- a/management/server/event_test.go +++ b/management/server/event_test.go @@ -28,7 +28,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac } func TestDefaultAccountManager_GetEvents(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { return } diff --git a/management/server/group.go b/management/server/group.go index 3cf9290a2..84e641f26 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -114,9 +114,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -185,9 +182,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -256,9 +250,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -327,9 +318,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -376,7 +364,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) return nil } - dnsDomain := am.GetDNSDomain(settings) + dnsDomain := am.networkMapController.GetDNSDomain(settings) for _, peerID := range addedPeers { peer, ok := peers[peerID] @@ -493,9 +481,6 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -534,9 +519,6 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -565,9 +547,6 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -606,9 +585,6 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/group_test.go b/management/server/group_test.go index 31ff29cbc..4935dac5d 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -37,7 +37,7 @@ const ( ) func TestDefaultAccountManager_CreateGroup(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) if err != nil { t.Error("failed to create account manager") } @@ -74,7 +74,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { } func TestDefaultAccountManager_DeleteGroup(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) if err != nil { t.Fatalf("failed to create account manager: %s", err) } @@ -156,7 +156,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) { } func TestDefaultAccountManager_DeleteGroups(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) assert.NoError(t, err, "Failed to create account manager") manager, account, err := initTestGroupAccount(am) @@ -408,7 +408,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t } func TestGroupAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) g := []*types.Group{ { @@ -442,9 +442,9 @@ func TestGroupAccountPeersUpdate(t *testing.T) { assert.NoError(t, err) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Saving a group that is not linked to any resource should not update account peers @@ -748,7 +748,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { } func Test_AddPeerToGroup(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -805,7 +805,7 @@ func Test_AddPeerToGroup(t *testing.T) { } func Test_AddPeerToAll(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -862,7 +862,7 @@ func Test_AddPeerToAll(t *testing.T) { } func Test_AddPeerAndAddToAll(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -942,7 +942,7 @@ func uint32ToIP(n uint32) net.IP { } func Test_IncrementNetworkSerial(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return diff --git a/management/server/holder.go b/management/server/holder.go deleted file mode 100644 index e8a26e1d0..000000000 --- a/management/server/holder.go +++ /dev/null @@ -1,39 +0,0 @@ -package server - -import ( - "github.com/netbirdio/netbird/management/server/types" -) - -func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) { - a := am.holder.GetAccount(account.Id) - if a == nil { - am.holder.AddAccount(account) - return - } - account.NetworkMapCache = a.NetworkMapCache - if account.NetworkMapCache == nil { - return - } - account.NetworkMapCache.UpdateAccountPointer(account) - am.holder.AddAccount(account) -} - -func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account { - return am.holder.GetAccount(accountID) -} - -func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account { - a := am.holder.GetAccount(accountID) - if a != nil { - return a - } - account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure) - if err != nil { - return nil - } - return account -} - -func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) { - am.holder.AddAccount(account) -} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 4d2c224b4..c1a8c5885 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -13,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/settings" @@ -65,6 +66,7 @@ func NewAPIHandler( permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, + networkMapController network_map.Controller, ) (http.Handler, error) { var rateLimitingConfig *middleware.RateLimiterConfig @@ -120,7 +122,7 @@ func NewAPIHandler( } accounts.AddEndpoints(accountManager, settingsManager, router) - peers.AddEndpoints(accountManager, router) + peers.AddEndpoints(accountManager, router, networkMapController) users.AddEndpoints(accountManager, router) setup_keys.AddEndpoints(accountManager, router) policies.AddEndpoints(accountManager, LocationManager, router) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index df89c616c..c4c5ae165 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcontext "github.com/netbirdio/netbird/management/server/context" @@ -23,11 +24,12 @@ import ( // Handler is a handler that returns peers of the account type Handler struct { - accountManager account.Manager + accountManager account.Manager + networkMapController network_map.Controller } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { - peersHandler := NewHandler(accountManager) +func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) { + peersHandler := NewHandler(accountManager, networkMapController) router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") @@ -36,9 +38,10 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) { } // NewHandler creates a new peers Handler -func NewHandler(accountManager account.Manager) *Handler { +func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler { return &Handler{ - accountManager: accountManager, + accountManager: accountManager, + networkMapController: networkMapController, } } @@ -47,7 +50,7 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { if peer.Status.Connected { // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected // This may happen after server restart when not all peers are yet connected - if !h.accountManager.HasConnectedChannel(peer.ID) { + if !h.networkMapController.IsConnected(peer.ID) { peerToReturn.Status.Connected = false } } @@ -73,7 +76,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, return } - dnsDomain := h.accountManager.GetDNSDomain(settings) + dnsDomain := h.networkMapController.GetDNSDomain(settings) grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) @@ -139,7 +142,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri util.WriteError(ctx, err, w) return } - dnsDomain := h.accountManager.GetDNSDomain(settings) + dnsDomain := h.networkMapController.GetDNSDomain(settings) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) if err != nil { @@ -227,7 +230,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - dnsDomain := h.accountManager.GetDNSDomain(settings) + dnsDomain := h.networkMapController.GetDNSDomain(settings) grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) @@ -317,7 +320,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { return } - dnsDomain := h.accountManager.GetDNSDomain(account.Settings) + dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 94564113f..7a5a6d911 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -14,12 +14,14 @@ import ( "time" "github.com/gorilla/mux" + "go.uber.org/mock/gomock" "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,7 +38,7 @@ const ( serviceUser = "service_user" ) -func initTestMetaData(peers ...*nbpeer.Peer) *Handler { +func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { peersMap := make(map[string]*nbpeer.Peer) for _, peer := range peers { @@ -99,6 +101,22 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { }, } + ctrl := gomock.NewController(t) + + networkMapController := network_map.NewMockController(ctrl) + networkMapController.EXPECT(). + GetDNSDomain(gomock.Any()). + Return("domain"). + AnyTimes() + networkMapController.EXPECT(). + IsConnected(noUpdateChannelTestPeerID). + Return(false). + AnyTimes() + networkMapController.EXPECT(). + IsConnected(gomock.Any()). + Return(true). + AnyTimes() + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { @@ -187,6 +205,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { return account.Settings, nil }, }, + networkMapController: networkMapController, } } @@ -270,7 +289,7 @@ func TestGetPeers(t *testing.T) { rr := httptest.NewRecorder() - p := initTestMetaData(peer, peer1) + p := initTestMetaData(t, peer, peer1) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -374,7 +393,7 @@ func TestGetAccessiblePeers(t *testing.T) { UserID: regularUser, } - p := initTestMetaData(peer1, peer2, peer3) + p := initTestMetaData(t, peer1, peer2, peer3) tt := []struct { name string @@ -477,7 +496,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) { }, } - p := initTestMetaData(testPeer) + p := initTestMetaData(t, testPeer) tt := []struct { name string diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index bdf56db6e..ab3f5437a 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -10,6 +10,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" @@ -31,7 +35,7 @@ import ( "github.com/netbirdio/netbird/management/server/users" ) -func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { +func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) if err != nil { t.Fatalf("Failed to create test store: %v", err) @@ -43,7 +47,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee t.Fatalf("Failed to create metrics: %v", err) } - peersUpdateManager := server.NewPeersUpdateManager(nil) + peersUpdateManager := update_channel.NewPeersUpdateManager(nil) updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId) done := make(chan struct{}) if validateUpdate { @@ -63,7 +67,11 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee userManager := users.NewManager(store) permissionsManager := permissions.NewManager(store) settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) - am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + + ctx := context.Background() + requestBuffer := server.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock()) + am, err := server.BuildManager(ctx, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) if err != nil { t.Fatalf("Failed to create manager: %v", err) } @@ -83,7 +91,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee groupsManagerMock := groups.NewManagerMock() peersManager := peers.NewManager(store, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -91,7 +99,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee return apiHandler, am, done } -func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) { +func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { case msg := <-updateMessage: @@ -101,7 +109,7 @@ func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server } } -func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { +func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage, expected *network_map.UpdateMessage) { t.Helper() select { diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index a34d2086b..fc67e01af 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -22,10 +22,14 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -321,99 +325,6 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ return loginResp, nil } -func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { - testingServerKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err) - } - - testingClientKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err) - } - - testCases := []struct { - name string - inputFlow *config.DeviceAuthorizationFlow - expectedFlow *mgmtProto.DeviceAuthorizationFlow - expectedErrFunc require.ErrorAssertionFunc - expectedErrMSG string - expectedComparisonFunc require.ComparisonAssertionFunc - expectedComparisonMSG string - }{ - { - name: "Testing No Device Flow Config", - inputFlow: nil, - expectedErrFunc: require.Error, - expectedErrMSG: "should return error", - }, - { - name: "Testing Invalid Device Flow Provider Config", - inputFlow: &config.DeviceAuthorizationFlow{ - Provider: "NoNe", - ProviderConfig: config.ProviderConfig{ - ClientID: "test", - }, - }, - expectedErrFunc: require.Error, - expectedErrMSG: "should return error", - }, - { - name: "Testing Full Device Flow Config", - inputFlow: &config.DeviceAuthorizationFlow{ - Provider: "hosted", - ProviderConfig: config.ProviderConfig{ - ClientID: "test", - }, - }, - expectedFlow: &mgmtProto.DeviceAuthorizationFlow{ - Provider: 0, - ProviderConfig: &mgmtProto.ProviderConfig{ - ClientID: "test", - }, - }, - expectedErrFunc: require.NoError, - expectedErrMSG: "should not return error", - expectedComparisonFunc: require.Equal, - expectedComparisonMSG: "should match", - }, - } - - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - mgmtServer := &GRPCServer{ - wgKey: testingServerKey, - config: &config.Config{ - DeviceAuthorizationFlow: testCase.inputFlow, - }, - } - - message := &mgmtProto.DeviceAuthorizationFlowRequest{} - - encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message) - require.NoError(t, err, "should be able to encrypt message") - - resp, err := mgmtServer.GetDeviceAuthorizationFlow( - context.TODO(), - &mgmtProto.EncryptedMessage{ - WgPubKey: testingClientKey.PublicKey().String(), - Body: encryptedMSG, - }, - ) - testCase.expectedErrFunc(t, err, testCase.expectedErrMSG) - if testCase.expectedComparisonFunc != nil { - flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{} - - err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp) - require.NoError(t, err, "should be able to decrypt") - - testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG) - testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG) - } - }) - } -} - func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { t.Helper() lis, err := net.Listen("tcp", "localhost:0") @@ -427,7 +338,6 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config t.Fatal(err) } - peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck @@ -451,7 +361,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config permissionsManager := permissions.NewManager(store) groupsManager := groups.NewManagerMock() - accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { @@ -459,10 +372,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config return nil, nil, "", cleanup, err } - secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) ephemeralMgr := manager.NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}, networkMapController) if err != nil { return nil, nil, "", cleanup, err } @@ -764,9 +677,38 @@ func Test_LoginPerformance(t *testing.T) { peerLogin := types.PeerLogin{ WireGuardPubKey: key.String(), SSHKey: "random", - Meta: extractPeerMeta(context.Background(), meta), - SetupKey: setupKey.Key, - ConnectionIP: net.IP{1, 1, 1, 1}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: meta.GetHostname(), + GoOS: meta.GetGoOS(), + Kernel: meta.GetKernel(), + Platform: meta.GetPlatform(), + OS: meta.GetOS(), + OSVersion: meta.GetOSVersion(), + WtVersion: meta.GetNetbirdVersion(), + UIVersion: meta.GetUiVersion(), + KernelVersion: meta.GetKernelVersion(), + SystemSerialNumber: meta.GetSysSerialNumber(), + SystemProductName: meta.GetSysProductName(), + SystemManufacturer: meta.GetSysManufacturer(), + Environment: nbpeer.Environment{ + Cloud: meta.GetEnvironment().GetCloud(), + Platform: meta.GetEnvironment().GetPlatform(), + }, + Flags: nbpeer.Flags{ + RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(), + RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(), + ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(), + DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(), + DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(), + DisableDNS: meta.GetFlags().GetDisableDNS(), + DisableFirewall: meta.GetFlags().GetDisableFirewall(), + BlockLANAccess: meta.GetFlags().GetBlockLANAccess(), + BlockInbound: meta.GetFlags().GetBlockInbound(), + LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(), + }, + }, + SetupKey: setupKey.Key, + ConnectionIP: net.IP{1, 1, 1, 1}, } login := func() error { diff --git a/management/server/management_test.go b/management/server/management_test.go index 1a5e47354..930ecfb5a 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -20,7 +20,10 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" @@ -176,7 +179,6 @@ func startServer( log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } - peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) @@ -199,13 +201,18 @@ func startServer( AnyTimes() permissionsManager := permissions.NewManager(str) + + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := server.NewAccountRequestBuffer(ctx, str) + networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := server.BuildManager( context.Background(), str, - peersUpdateManager, + networkMapController, nil, "", - "netbird.selfhosted", eventStore, nil, false, @@ -220,18 +227,18 @@ func startServer( } groupsManager := groups.NewManager(str, permissionsManager, accountManager) - secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer( - context.Background(), + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer( config, accountManager, settingsMockManager, - peersUpdateManager, + updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, server.MockIntegratedValidator{}, + networkMapController, ) if err != nil { t.Fatalf("failed creating management server: %v", err) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 8baffa58b..781d84f5f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -38,7 +38,7 @@ type MockAccountManager struct { ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) @@ -94,7 +94,7 @@ type MockAccountManager struct { GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error @@ -178,11 +178,11 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncAndMarkPeerFunc != nil { return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) } - return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") + return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { @@ -747,11 +747,11 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncPeerFunc != nil { return am.SyncPeerFunc(ctx, sync, accountID) } - return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") + return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } // GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface diff --git a/management/server/nameserver.go b/management/server/nameserver.go index ee77a65bb..f278e1761 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -83,9 +83,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } am.UpdateAccountPeers(ctx, accountID) } @@ -137,9 +134,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -183,9 +177,6 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 6c985410c..35291b30c 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -11,6 +11,8 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -785,7 +787,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { AnyTimes() permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + + return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createNSStore(t *testing.T) (store.Store, error) { @@ -975,7 +983,7 @@ func TestValidateDomain(t *testing.T) { } func TestNameServerAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) var newNameServerGroupA *nbdns.NameServerGroup var newNameServerGroupB *nbdns.NameServerGroup @@ -994,9 +1002,9 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Creating a nameserver group with a distribution group no peers should not update account peers diff --git a/management/server/networkmap.go b/management/server/networkmap.go deleted file mode 100644 index 2a0627643..000000000 --- a/management/server/networkmap.go +++ /dev/null @@ -1,80 +0,0 @@ -package server - -import ( - "context" - - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - - nbdns "github.com/netbirdio/netbird/dns" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" -) - -func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { - am.enrichAccountFromHolder(account) - account.InitNetworkMapBuilderIfNeeded(validatedPeers) -} - -func (am *DefaultAccountManager) getPeerNetworkMapExp( - ctx context.Context, - accountId string, - peerId string, - validatedPeers map[string]struct{}, - customZone nbdns.CustomZone, - metrics *telemetry.AccountManagerMetrics, -) *types.NetworkMap { - account := am.getAccountFromHolderOrInit(accountId) - if account == nil { - log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) - return &types.NetworkMap{ - Network: &types.Network{}, - } - } - return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) -} - -func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error { - am.enrichAccountFromHolder(account) - return account.OnPeerAddedUpdNetworkMapCache(peerId) -} - -func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { - am.enrichAccountFromHolder(account) - return account.OnPeerDeletedUpdNetworkMapCache(peerId) -} - -func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { - account := am.getAccountFromHolder(accountId) - if account == nil { - return - } - account.UpdatePeerInNetworkMapCache(peer) -} - -func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { - account.RecalculateNetworkMapCache(validatedPeers) - am.updateAccountInHolder(account) -} - -func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { - if am.experimentalNetworkMap(accountId) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) - if err != nil { - return err - } - validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) - return err - } - am.recalculateNetworkMapCache(account, validatedPeers) - } - return nil -} - -func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool { - _, ok := am.expNewNetworkMapAIDs[accountId] - return am.expNewNetworkMap || ok -} diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index 0e6d1631b..b6706ca45 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -177,9 +177,6 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw event() } - if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index b740610c2..66484d120 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -157,9 +157,6 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc event() } - if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { - return nil, err - } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -260,9 +257,6 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc event() } - if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { - return nil, err - } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -337,9 +331,6 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net event() } - if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 89ac419fd..82cac424a 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -119,9 +119,6 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) - if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { - return nil, err - } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -186,9 +183,6 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) - if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { - return nil, err - } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -223,9 +217,6 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo event() - if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/peer.go b/management/server/peer.go index 4c605b5eb..cd9fbe4c8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -8,8 +8,6 @@ import ( "net" "slices" "strings" - "sync" - "sync/atomic" "time" "github.com/rs/xid" @@ -23,7 +21,6 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/shared/management/domain" - "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" @@ -31,7 +28,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -140,12 +136,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if expired { - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } - // we need to update other peers because when peer login expires all other peers are notified to disconnect from - // the expired one. Here we notify them that connection is now allowed again. - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.OnPeerUpdated(accountID, peer) } return nil @@ -201,7 +192,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user var peer *nbpeer.Peer var settings *types.Settings var peerGroupList []string - var requiresPeerUpdates bool var peerLabelChanged bool var sshChanged bool var loginExpirationChanged bool @@ -224,9 +214,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return err } - dnsDomain = am.GetDNSDomain(settings) + dnsDomain = am.networkMapController.GetDNSDomain(settings) - update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra) + update, _, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra) if err != nil { return err } @@ -319,15 +309,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } - - if peerLabelChanged || requiresPeerUpdates { - am.UpdateAccountPeers(ctx, accountID) - } else if sshChanged { - am.UpdateAccountPeer(ctx, accountID, peer.ID) - } + am.networkMapController.OnPeerUpdated(accountID, peer) return peer, nil } @@ -383,20 +365,13 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if am.experimentalNetworkMap(accountID) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return err - } - - if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) - } - + err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err) } - if userID != activity.SystemInitiator { - am.BufferUpdateAccountPeers(ctx, accountID) + if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peerID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) } return nil @@ -404,47 +379,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { - account, err := am.Store.GetAccountByPeerID(ctx, peerID) - if err != nil { - return nil, err - } - - peer := account.GetPeer(peerID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) - } - - groups := make(map[string][]string) - for groupID, group := range account.Groups { - groups[groupID] = group.Peers - } - - validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - return nil, err - } - customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return nil, err - } - - var networkMap *types.NetworkMap - - if am.experimentalNetworkMap(peer.AccountID) { - networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) - } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) - } - - proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] - if ok { - networkMap.Merge(proxyNetworkMap) - } - - return networkMap, nil + return am.networkMapController.GetNetworkMap(ctx, peerID) } // GetPeerNetwork returns the Network for a given peer @@ -703,27 +638,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe } opEvent.TargetID = newPeer.ID - opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings)) + opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings)) if !addedByUser { opEvent.Meta["setup_key_name"] = setupKeyName } am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - if am.experimentalNetworkMap(accountID) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - - if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) - } + if err := am.networkMapController.OnPeerAdded(ctx, accountID, newPeer.ID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) } - am.BufferUpdateAccountPeers(ctx, accountID) - - return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) + p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer) + return p, nmap, pc, err } func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { @@ -738,7 +665,7 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { var peer *nbpeer.Peer var peerNotValid bool var isStatusChanged bool @@ -748,7 +675,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, 0, err } err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -798,17 +725,14 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return nil }) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, 0, err } if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.OnPeerUpdated(accountID, peer) } - return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) + return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) } func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { @@ -933,15 +857,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } - startBuffer := time.Now() - am.BufferUpdateAccountPeers(ctx, accountID) - log.WithContext(ctx).Debugf("LoginPeer: BufferUpdateAccountPeers took %v", time.Since(startBuffer)) + am.networkMapController.OnPeerUpdated(accountID, peer) } - return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) + p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) + return p, nmap, pc, err } // getPeerPostureChecks returns the posture checks for the peer. @@ -1033,68 +953,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } -func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - if isRequiresApproval { - network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return nil, nil, nil, err - } - - emptyMap := &types.NetworkMap{ - Network: network.Copy(), - } - return peer, emptyMap, nil, nil - } - - var ( - account *types.Account - err error - ) - if am.experimentalNetworkMap(accountID) { - account = am.getAccountFromHolderOrInit(accountID) - } else { - account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - return nil, nil, nil, err - } - - startPosture := time.Now() - postureChecks, err := am.getPeerPostureChecks(account, peer.ID) - if err != nil { - return nil, nil, nil, err - } - log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture)) - - customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return nil, nil, nil, err - } - - var networkMap *types.NetworkMap - - if am.experimentalNetworkMap(accountID) { - networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) - } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) - } - - proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] - if ok { - networkMap.Merge(proxyNetworkMap) - } - - return peer, networkMap, postureChecks, nil -} - func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction store.Store, user *types.User, peer *nbpeer.Peer) error { err := checkAuth(ctx, user.Id, peer) if err != nil { @@ -1118,7 +976,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact return fmt.Errorf("failed to get account settings: %w", err) } - am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings))) + am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.networkMapController.GetDNSDomain(settings))) return nil } @@ -1214,232 +1072,17 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun // UpdateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { - log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) - var ( - account *types.Account - err error - ) - if am.experimentalNetworkMap(accountID) { - account = am.getAccountFromHolderOrInit(accountID) - } else { - account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) - return - } - } - - globalStart := time.Now() - - hasPeersConnected := false - for _, peer := range account.Peers { - if am.peersUpdateManager.HasChannel(peer.ID) { - hasPeersConnected = true - break - } - - } - - if !hasPeersConnected { - return - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) - return - } - - var wg sync.WaitGroup - semaphore := make(chan struct{}, 10) - - dnsCache := &DNSConfigCache{} - dnsDomain := am.GetDNSDomain(account.Settings) - customZone := account.GetPeersCustomZone(ctx, dnsDomain) - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - if am.experimentalNetworkMap(accountID) { - am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) - } - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return - } - - extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err) - return - } - - dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) - - for _, peer := range account.Peers { - if !am.peersUpdateManager.HasChannel(peer.ID) { - log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) - continue - } - - wg.Add(1) - semaphore <- struct{}{} - go func(p *nbpeer.Peer) { - defer wg.Done() - defer func() { <-semaphore }() - - start := time.Now() - - postureChecks, err := am.getPeerPostureChecks(account, p.ID) - if err != nil { - log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err) - return - } - - am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start)) - start = time.Now() - - var remotePeerNetworkMap *types.NetworkMap - - if am.experimentalNetworkMap(accountID) { - remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) - } else { - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) - } - - am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start)) - start = time.Now() - - proxyNetworkMap, ok := proxyNetworkMaps[p.ID] - if ok { - remotePeerNetworkMap.Merge(proxyNetworkMap) - } - am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start)) - - peerGroups := account.GetPeerGroups(p.ID) - start = time.Now() - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) - am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) - - am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) - }(peer) - } - - // - - wg.Wait() - if am.metrics != nil { - am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart)) - } -} - -type bufferUpdate struct { - mu sync.Mutex - next *time.Timer - update atomic.Bool + _ = am.networkMapController.UpdateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) - - bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) - b := bufUpd.(*bufferUpdate) - - if !b.mu.TryLock() { - b.update.Store(true) - return - } - - if b.next != nil { - b.next.Stop() - } - - go func() { - defer b.mu.Unlock() - am.UpdateAccountPeers(ctx, accountID) - if !b.update.Load() { - return - } - b.update.Store(false) - if b.next == nil { - b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() { - am.UpdateAccountPeers(ctx, accountID) - }) - return - } - b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load())) - }() + _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID) } // UpdateAccountPeer updates a single peer that belongs to an account. // Should be called when changes need to be synced to a specific peer only. func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) { - if !am.peersUpdateManager.HasChannel(peerId) { - log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId) - return - } - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err) - return - } - - peer := account.GetPeer(peerId) - if peer == nil { - log.WithContext(ctx).Tracef("peer %s doesn't exists in account %s", peerId, accountId) - return - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err) - return - } - - dnsCache := &DNSConfigCache{} - dnsDomain := am.GetDNSDomain(account.Settings) - customZone := account.GetPeersCustomZone(ctx, dnsDomain) - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - postureChecks, err := am.getPeerPostureChecks(account, peerId) - if err != nil { - log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err) - return - } - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return - } - - var remotePeerNetworkMap *types.NetworkMap - - if am.experimentalNetworkMap(accountId) { - remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) - } else { - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) - } - - proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] - if ok { - remotePeerNetworkMap.Merge(proxyNetworkMap) - } - - extraSettings, err := am.settingsManager.GetExtraSettings(ctx, peer.AccountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get extra settings: %v", err) - return - } - - peerGroups := account.GetPeerGroups(peerId) - dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) - - update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) + _ = am.networkMapController.UpdateAccountPeer(ctx, accountId, peerId) } // getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. @@ -1594,14 +1237,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto if err != nil { return nil, err } - dnsDomain := am.GetDNSDomain(settings) - - network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return nil, err - } - - dnsFwdPort := computeForwarderPort(peers, dnsForwarderPortMinVersion) + dnsDomain := am.networkMapController.GetDNSDomain(settings) for _, peer := range peers { if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { @@ -1635,24 +1271,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { return nil, err } - - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{ - Update: &proto.SyncResponse{ - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - NetworkMap: &proto.NetworkMap{ - Serial: network.CurrentSerial(), - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - FirewallRules: []*proto.FirewallRule{}, - FirewallRulesIsEmpty: true, - DNSConfig: &proto.DNSConfig{ - ForwarderPort: dnsFwdPort, - }, - }, - }, - }) - am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) }) @@ -1661,14 +1279,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return peerDeletedEvents, nil } -func ConvertSliceToMap(existingLabels []string) map[string]struct{} { - labelMap := make(map[string]struct{}, len(existingLabels)) - for _, label := range existingLabels { - labelMap[label] = struct{}{} - } - return labelMap -} - // validatePeerDelete checks if the peer can be deleted. func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transaction store.Store, accountId, peerId string) error { linkedInIngressPorts, err := am.proxyController.IsPeerInIngressPorts(ctx, accountId, peerId) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index e151f5abb..95c609595 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -13,7 +13,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "testing" "time" @@ -25,10 +24,14 @@ import ( "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/shared/management/status" @@ -172,12 +175,12 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { } func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testGetNetworkMapGeneral(t) } func testGetNetworkMapGeneral(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -249,7 +252,7 @@ func testGetNetworkMapGeneral(t *testing.T) { func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { // TODO: disable until we start use policy again t.Skip() - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -426,7 +429,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } func TestAccountManager_GetPeerNetwork(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -487,7 +490,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { } func TestDefaultAccountManager_GetPeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -674,7 +677,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -742,12 +745,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { } } -func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) { +func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, *update_channel.PeersUpdateManager, string, string, error) { b.Helper() - manager, err := createManager(b) + manager, updateManager, err := createManager(b) if err != nil { - return nil, "", "", err + return nil, nil, "", "", err } accountID := "test_account" @@ -798,7 +801,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou ips := account.GetTakenIPs() peerIP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { - return nil, "", "", err + return nil, nil, "", "", err } peerKey, _ := wgtypes.GeneratePrivateKey() @@ -904,10 +907,10 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou err = manager.Store.SaveAccount(context.Background(), account) if err != nil { - return nil, "", "", err + return nil, nil, "", "", err } - return manager, accountID, regularUser, nil + return manager, updateManager, accountID, regularUser, nil } func BenchmarkGetPeers(b *testing.B) { @@ -928,7 +931,7 @@ func BenchmarkGetPeers(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, _, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -968,7 +971,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -980,14 +983,10 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) - for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels - b.ResetTimer() start := time.Now() @@ -1013,7 +1012,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } func TestUpdateAccountPeers_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testUpdateAccountPeers(t) } @@ -1037,7 +1036,7 @@ func testUpdateAccountPeers(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups) if err != nil { t.Fatalf("Failed to setup test account manager: %v", err) } @@ -1049,13 +1048,12 @@ func testUpdateAccountPeers(t *testing.T) { t.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + peerChannels := make(map[string]chan *network_map.UpdateMessage) for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels manager.UpdateAccountPeers(ctx, account.Id) for _, channel := range peerChannels { @@ -1097,7 +1095,7 @@ func TestToSyncResponse(t *testing.T) { DNSLabel: "peer1", SSHKey: "peer1-ssh-key", } - turnRelayToken := &Token{ + turnRelayToken := &grpc.Token{ Payload: "turn-user", Signature: "turn-pass", } @@ -1177,9 +1175,9 @@ func TestToSyncResponse(t *testing.T) { }, }, } - dnsCache := &DNSConfigCache{} + dnsCache := &cache.DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) + response := grpc.ToSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config @@ -1289,7 +1287,12 @@ func Test_RegisterPeerByUser(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1369,7 +1372,12 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { AnyTimes() permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1517,7 +1525,12 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1566,7 +1579,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { } func Test_LoginPeer(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } @@ -1592,7 +1605,12 @@ func Test_LoginPeer(t *testing.T) { AnyTimes() permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1725,7 +1743,7 @@ func Test_LoginPeer(t *testing.T) { } func TestPeerAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) require.NoError(t, err) @@ -1782,13 +1800,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) { var peer5 *nbpeer.Peer var peer6 *nbpeer.Peer - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) { + t.Skip("Currently all updates will trigger a network map") done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) @@ -1890,6 +1909,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) { }) t.Run("validator requires no update", func(t *testing.T) { + t.Skip("Currently all updates will trigger a network map") + requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) { return update, false, nil } @@ -2091,7 +2112,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { } func Test_DeletePeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -2188,7 +2209,7 @@ func Test_IsUniqueConstraintError(t *testing.T) { } func Test_AddPeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -2276,136 +2297,8 @@ func Test_AddPeer(t *testing.T) { assert.Equal(t, uint64(totalPeers), account.Network.Serial) } -func TestBufferUpdateAccountPeers(t *testing.T) { - const ( - peersCount = 1000 - updateAccountInterval = 50 * time.Millisecond - ) - - var ( - deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32 - uapLastRun, dpLastRun atomic.Int64 - - totalNewRuns, totalOldRuns int - ) - - uap := func(ctx context.Context, accountID string) { - updatePeersDeleted.Store(deletedPeers.Load()) - updatePeersRuns.Add(1) - uapLastRun.Store(time.Now().UnixMilli()) - time.Sleep(100 * time.Millisecond) - } - - t.Run("new approach", func(t *testing.T) { - updatePeersRuns.Store(0) - updatePeersDeleted.Store(0) - deletedPeers.Store(0) - - var mustore sync.Map - bufupd := func(ctx context.Context, accountID string) { - mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{}) - b := mu.(*bufferUpdate) - - if !b.mu.TryLock() { - b.update.Store(true) - return - } - - if b.next != nil { - b.next.Stop() - } - - go func() { - defer b.mu.Unlock() - uap(ctx, accountID) - if !b.update.Load() { - return - } - b.update.Store(false) - b.next = time.AfterFunc(updateAccountInterval, func() { - uap(ctx, accountID) - }) - }() - } - dp := func(ctx context.Context, accountID, peerID, userID string) error { - deletedPeers.Add(1) - dpLastRun.Store(time.Now().UnixMilli()) - time.Sleep(10 * time.Millisecond) - bufupd(ctx, accountID) - return nil - } - - am := mock_server.MockAccountManager{ - UpdateAccountPeersFunc: uap, - BufferUpdateAccountPeersFunc: bufupd, - DeletePeerFunc: dp, - } - empty := "" - for range peersCount { - //nolint - am.DeletePeer(context.Background(), empty, empty, empty) - } - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") - assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") - assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") - - totalNewRuns = int(updatePeersRuns.Load()) - }) - - t.Run("old approach", func(t *testing.T) { - updatePeersRuns.Store(0) - updatePeersDeleted.Store(0) - deletedPeers.Store(0) - - var mustore sync.Map - bufupd := func(ctx context.Context, accountID string) { - mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{}) - b := mu.(*sync.Mutex) - - if !b.TryLock() { - return - } - - go func() { - time.Sleep(updateAccountInterval) - b.Unlock() - uap(ctx, accountID) - }() - } - dp := func(ctx context.Context, accountID, peerID, userID string) error { - deletedPeers.Add(1) - dpLastRun.Store(time.Now().UnixMilli()) - time.Sleep(10 * time.Millisecond) - bufupd(ctx, accountID) - return nil - } - - am := mock_server.MockAccountManager{ - UpdateAccountPeersFunc: uap, - BufferUpdateAccountPeersFunc: bufupd, - DeletePeerFunc: dp, - } - empty := "" - for range peersCount { - //nolint - am.DeletePeer(context.Background(), empty, empty, empty) - } - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") - assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") - assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") - - totalOldRuns = int(updatePeersRuns.Load()) - }) - assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) - t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) -} - func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2442,7 +2335,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { } func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2476,7 +2369,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { } func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2541,7 +2434,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { } func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/management/server/policy.go b/management/server/policy.go index ff02d46aa..3e84c3d10 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -10,7 +10,6 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/posture" @@ -77,9 +76,6 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } am.UpdateAccountPeers(ctx, accountID) } @@ -123,9 +119,6 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -258,31 +251,3 @@ func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []strin return validIDs } - -// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. -func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - - fwRule := &proto.FirewallRule{ - PolicyID: []byte(rule.PolicyID), - PeerIP: rule.PeerIP, - Direction: getProtoDirection(rule.Direction), - Action: getProtoAction(rule.Action), - Protocol: getProtoProtocol(rule.Protocol), - Port: rule.Port, - } - - if shouldUsePortRange(fwRule) { - fwRule.PortInfo = rule.PortRange.ToProto() - } - - result[i] = fwRule - } - return result -} - -func shouldUsePortRange(rule *proto.FirewallRule) bool { - return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP) -} diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 97ebbcf5a..90fe8f036 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -1135,7 +1135,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int { } func TestPolicyAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) g := []*types.Group{ { @@ -1164,9 +1164,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { assert.NoError(t, err) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) var policyWithGroupRulesNoPeers *types.Policy diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index f457b994b..ac8ea35de 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,19 +2,15 @@ package server import ( "context" - "errors" - "fmt" "slices" "github.com/rs/xid" - "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -80,9 +76,6 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } am.UpdateAccountPeers(ctx, accountID) } @@ -139,27 +132,6 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID) } -// getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) { - peerPostureChecks := make(map[string]*posture.Checks) - - if len(account.PostureChecks) == 0 { - return nil, nil - } - - for _, policy := range account.Policies { - if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { - continue - } - - if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil { - return nil, err - } - } - - return maps.Values(peerPostureChecks), nil -} - // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) @@ -214,50 +186,6 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account return nil } -// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. -func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error { - isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) - if err != nil { - return err - } - - if !isInGroup { - return nil - } - - for _, sourcePostureCheckID := range policy.SourcePostureChecks { - postureCheck := account.GetPostureChecks(sourcePostureCheckID) - if postureCheck == nil { - return errors.New("failed to add policy posture checks: posture checks not found") - } - peerPostureChecks[sourcePostureCheckID] = postureCheck - } - - return nil -} - -// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) { - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - for _, sourceGroup := range rule.Sources { - group := account.GetGroup(sourceGroup) - if group == nil { - return false, fmt.Errorf("failed to check peer in policy source group: group not found") - } - - if slices.Contains(group.Peers, peerID) { - return true, nil - } - } - } - - return false, nil -} - // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 67760d55a..13152ed12 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -21,7 +21,7 @@ const ( ) func TestDefaultAccountManager_PostureCheck(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) if err != nil { t.Error("failed to create account manager") } @@ -123,7 +123,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er } func TestPostureCheckAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) g := []*types.Group{ { @@ -147,9 +147,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { assert.NoError(t, err) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) postureCheckA := &posture.Checks{ @@ -359,9 +359,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture check to policy where destination has peers but source does not // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) { - updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) + updMsg1 := updateManager.CreateChannel(context.Background(), peer2.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) + updateManager.CloseChannel(context.Background(), peer2.ID) }) _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ @@ -445,7 +445,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } func TestArePostureCheckChangesAffectPeers(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "failed to create account manager") account, err := initTestPostureChecksAccount(manager) diff --git a/management/server/route.go b/management/server/route.go index 05f7acf9e..2b4f11d05 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -16,7 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -192,9 +191,6 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } am.UpdateAccountPeers(ctx, accountID) } @@ -249,9 +245,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) if oldRouteAffectsPeers || newRouteAffectsPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -295,9 +288,6 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -381,103 +371,12 @@ func validateRouteGroups(ctx context.Context, transaction store.Store, accountID return groupsMap, nil } -func toProtocolRoute(route *route.Route) *proto.Route { - return &proto.Route{ - ID: string(route.ID), - NetID: string(route.NetID), - Network: route.Network.String(), - Domains: route.Domains.ToPunycodeList(), - NetworkType: int64(route.NetworkType), - Peer: route.Peer, - Metric: int64(route.Metric), - Masquerade: route.Masquerade, - KeepRoute: route.KeepRoute, - SkipAutoApply: route.SkipAutoApply, - } -} - -func toProtocolRoutes(routes []*route.Route) []*proto.Route { - protoRoutes := make([]*proto.Route, 0, len(routes)) - for _, r := range routes { - protoRoutes = append(protoRoutes, toProtocolRoute(r)) - } - return protoRoutes -} - // getPlaceholderIP returns a placeholder IP address for the route if domains are used func getPlaceholderIP() netip.Prefix { // Using an IP from the documentation range to minimize impact in case older clients try to set a route return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } -func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { - result := make([]*proto.RouteFirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - result[i] = &proto.RouteFirewallRule{ - SourceRanges: rule.SourceRanges, - Action: getProtoAction(rule.Action), - Destination: rule.Destination, - Protocol: getProtoProtocol(rule.Protocol), - PortInfo: getProtoPortInfo(rule), - IsDynamic: rule.IsDynamic, - Domains: rule.Domains.ToPunycodeList(), - PolicyID: []byte(rule.PolicyID), - RouteID: string(rule.RouteID), - } - } - - return result -} - -// getProtoDirection converts the direction to proto.RuleDirection. -func getProtoDirection(direction int) proto.RuleDirection { - if direction == types.FirewallRuleDirectionOUT { - return proto.RuleDirection_OUT - } - return proto.RuleDirection_IN -} - -// getProtoAction converts the action to proto.RuleAction. -func getProtoAction(action string) proto.RuleAction { - if action == string(types.PolicyTrafficActionDrop) { - return proto.RuleAction_DROP - } - return proto.RuleAction_ACCEPT -} - -// getProtoProtocol converts the protocol to proto.RuleProtocol. -func getProtoProtocol(protocol string) proto.RuleProtocol { - switch types.PolicyRuleProtocolType(protocol) { - case types.PolicyRuleProtocolALL: - return proto.RuleProtocol_ALL - case types.PolicyRuleProtocolTCP: - return proto.RuleProtocol_TCP - case types.PolicyRuleProtocolUDP: - return proto.RuleProtocol_UDP - case types.PolicyRuleProtocolICMP: - return proto.RuleProtocol_ICMP - default: - return proto.RuleProtocol_UNKNOWN - } -} - -// getProtoPortInfo converts the port info to proto.PortInfo. -func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { - var portInfo proto.PortInfo - if rule.Port != 0 { - portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} - } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { - portInfo.PortSelection = &proto.PortInfo_Range_{ - Range: &proto.PortInfo_Range{ - Start: uint32(portRange.Start), - End: uint32(portRange.End), - }, - } - } - return &portInfo -} - // areRouteChangesAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers. func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) { diff --git a/management/server/route_test.go b/management/server/route_test.go index 388db140c..27fe033c8 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -432,7 +434,7 @@ func TestCreateRoute(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -922,7 +924,7 @@ func TestSaveRoute(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1024,7 +1026,7 @@ func TestDeleteRoute(t *testing.T) { Enabled: true, } - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1071,7 +1073,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { AccessControlGroups: []string{routeGroup1}, } - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1163,7 +1165,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { AccessControlGroups: []string{routeGroup1}, } - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1250,11 +1252,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1") } -func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { +func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) { t.Helper() store, err := createRouterStore(t) if err != nil { - return nil, err + return nil, nil, err } eventStore := &activity.InMemoryEventStore{} @@ -1285,7 +1287,16 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + if err != nil { + return nil, nil, err + } + return am, updateManager, nil } func createRouterStore(t *testing.T) (store.Store, error) { @@ -1948,7 +1959,7 @@ func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFi } func TestRouteAccountPeersUpdate(t *testing.T) { - manager, err := createRouterManager(t) + manager, updateManager, err := createRouterManager(t) require.NoError(t, err, "failed to create account manager") account, err := initTestRouteAccount(t, manager) @@ -1976,9 +1987,9 @@ func TestRouteAccountPeersUpdate(t *testing.T) { require.NoError(t, err, "failed to create group %s", group.Name) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID) + updateManager.CloseChannel(context.Background(), peer1ID) }) // Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index e55b33c94..bc361bbd7 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -18,7 +18,7 @@ import ( ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -93,7 +93,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -198,7 +198,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } func TestGetSetupKeys(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -396,7 +396,7 @@ func TestSetupKey_Copy(t *testing.T) { } func TestSetupKeyAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", @@ -420,9 +420,9 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) require.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) var setupKey *types.SetupKey @@ -465,7 +465,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { } func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/management/server/user.go b/management/server/user.go index 66bea314f..be4e491a8 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -965,7 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if err != nil { return err } - dnsDomain := am.GetDNSDomain(settings) + dnsDomain := am.networkMapController.GetDNSDomain(settings) var peerIDs []string for _, peer := range peers { @@ -992,16 +992,13 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } + am.networkMapController.OnPeerUpdated(accountID, peer) } if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) - am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.DisconnectPeers(ctx, peerIDs) } return nil } @@ -1115,6 +1112,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI var addPeerRemovedEvents []func() var updateAccountPeers bool + var userPeers []*nbpeer.Peer var targetUser *types.User var err error @@ -1124,7 +1122,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return fmt.Errorf("failed to get user to delete: %w", err) } - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID) + userPeers, err = transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID) if err != nil { return fmt.Errorf("failed to get user peers: %w", err) } @@ -1147,6 +1145,17 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return false, err } + for _, peer := range userPeers { + err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err) + } + + if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peer.ID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peer.ID, err) + } + } + for _, addPeerRemovedEvent := range addPeerRemovedEvents { addPeerRemovedEvent() } diff --git a/management/server/user_test.go b/management/server/user_test.go index 5920a2a33..69b8c85ee 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1161,7 +1161,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { } func TestDefaultAccountManager_SaveUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1333,7 +1333,7 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { func TestUserAccountPeersUpdate(t *testing.T) { // account groups propagation is enabled - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", @@ -1357,9 +1357,9 @@ func TestUserAccountPeersUpdate(t *testing.T) { _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) require.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Creating a new regular user should not update account peers and not send peer update @@ -1468,9 +1468,9 @@ func TestUserAccountPeersUpdate(t *testing.T) { } }) - peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID) + peer4UpdMsg := updateManager.CreateChannel(context.Background(), peer4.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID) + updateManager.CloseChannel(context.Background(), peer4.ID) }) // deleting user with linked peers should update account peers and send peer update @@ -1748,7 +1748,7 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { } func TestApproveUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -1807,7 +1807,7 @@ func TestApproveUser(t *testing.T) { } func TestRejectUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index d4a9f1823..d3f341529 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -18,6 +18,9 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" @@ -68,7 +71,6 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { } t.Cleanup(cleanUp) - peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} ctrl := gomock.NewController(t) @@ -111,15 +113,19 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { Return(&types.ExtraSettings{}, nil). AnyTimes() - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } groupsManager := groups.NewManagerMock() - secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) } From 6fb568728fa9d8806e662703f1597fe0e14ce2b6 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 13 Nov 2025 12:51:03 +0100 Subject: [PATCH 071/120] [management] Removed policy posture checks on original peer (#4779) Co-authored-by: crn4 --- management/server/types/networkmapbuilder.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go index 58f1bfa30..41aaa7fc8 100644 --- a/management/server/types/networkmapbuilder.go +++ b/management/server/types/networkmapbuilder.go @@ -257,8 +257,6 @@ func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) { func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, ) ([]*nbpeer.Peer, []*FirewallRule) { - ctx := context.Background() - peerID := peer.ID peerGroups := b.cache.peerToGroups[peerID] @@ -275,9 +273,6 @@ func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *n for _, group := range peerGroups { policies := b.cache.groupToPolicies[group] for _, policy := range policies { - if isValid := account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID); !isValid { - continue - } rules := b.cache.policyToRules[policy.ID] for _, rule := range rules { var sourcePeers, destinationPeers []*nbpeer.Peer From 27957036c9130caeb1cc490a8a994eb1fcad01dc Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 13 Nov 2025 13:24:51 +0100 Subject: [PATCH 072/120] [client] Fix shutdown blocking on stuck ICE agent close (#4780) --- client/internal/peer/ice/agent.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index e80c98884..7b929c29d 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -22,6 +22,8 @@ const ( iceFailedTimeoutDefault = 6 * time.Second // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package iceRelayAcceptanceMinWaitDefault = 2 * time.Second + // iceAgentCloseTimeout is the maximum time to wait for ICE agent close to complete + iceAgentCloseTimeout = 3 * time.Second ) type ThreadSafeAgent struct { @@ -32,7 +34,17 @@ type ThreadSafeAgent struct { func (a *ThreadSafeAgent) Close() error { var err error a.once.Do(func() { - err = a.Agent.Close() + done := make(chan error, 1) + go func() { + done <- a.Agent.Close() + }() + + select { + case err = <-done: + case <-time.After(iceAgentCloseTimeout): + log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout) + err = nil + } }) return err } From 3176b53968f326beb7ddd21b76eb45bfdf66cb28 Mon Sep 17 00:00:00 2001 From: Diego Romar Date: Thu, 13 Nov 2025 10:25:19 -0300 Subject: [PATCH 073/120] [client] Add quick actions window (#4717) * Open quick settings window if netbird-ui is already running * [client-ui] fix connection status comparison * [client-ui] modularize quick actions code * [client-ui] add netbird-disconnected logo * [client-ui] change quickactions UI It now displays the NetBird logo and a single button with a round icon * [client-ui] add hint message to quick actions screen This also updates fyne to v2.7.0 * [client-ui] remove unnecessary default clause * [client-ui] remove commented code * [client-ui] remove unused dependency * [client-ui] close quick actions on connection change * [client-ui] add function to get image from embed resources * [client] Return error when calling sendShowWindowSignal from Windows * [client-ui] Add commentary on empty OnTapped function for toggleConnectionButton * [client-ui] Fix tests * [client-ui] Add context to menuUpClick call * [client-ui] Pass serviceClient app as parameter To use its clipboard rather than the window's when showing the upload success dialog * [client-ui] Replace for select with for range chan * [client-ui] Replace settings change listener channel Settings now accept a function callback * [client-ui] Add missing iconAboutDisconnected to icons_windows.go * [client] Add quick actions signal handler for Windows with named events * [client] Run go mod tidy * [client] Remove line break * [client] Log unexpected status in separate function * [client-ui] Refactor quick actions window To address racing conditions, it also replaces usage of pause and resume channels with an atomic bool. * [client-ui] use derived context from ServiceClient * [client] Update signal_windows log message Also, format error when trying to set event on sendShowWindowSignal * go mod tidy * [client-ui] Add struct to pass fewer parameters to applyQuickActionsUiState function * [client] Add missing import --------- Co-authored-by: Viktor Liu --- client/ui/assets/netbird-disconnected.ico | Bin 0 -> 5056 bytes client/ui/assets/netbird-disconnected.png | Bin 0 -> 7537 bytes client/ui/client_ui.go | 70 +-- client/ui/debug.go | 8 +- client/ui/icons.go | 3 + client/ui/icons_windows.go | 3 + client/ui/quickactions.go | 349 +++++++++++++++ client/ui/quickactions_assets.go | 23 + client/ui/signal_unix.go | 76 ++++ client/ui/signal_windows.go | 171 +++++++ go.mod | 39 +- go.sum | 520 ++-------------------- 12 files changed, 734 insertions(+), 528 deletions(-) create mode 100644 client/ui/assets/netbird-disconnected.ico create mode 100644 client/ui/assets/netbird-disconnected.png create mode 100644 client/ui/quickactions.go create mode 100644 client/ui/quickactions_assets.go create mode 100644 client/ui/signal_unix.go create mode 100644 client/ui/signal_windows.go diff --git a/client/ui/assets/netbird-disconnected.ico b/client/ui/assets/netbird-disconnected.ico new file mode 100644 index 0000000000000000000000000000000000000000..812e9d283d096d823824fd66e6533359fb375b4a GIT binary patch literal 5056 zcmbtYi$7EU|39ci7|A8M&h;ZCN^PPP$}QB~%Kd&X z5v|!rNJVU@G}rMvKEFTVd-m94XRmY4>+*cP-p|+TeEE+o=GrUFHx3^L5htA5B=438qg#UQWPc{T#JKG-w#US$_RqN6XAC0Rg`8Vf zE84eI+8+kT?CY(*w6|8$v_58qt6?@yt`rRp4o+W0p+KAe`!VQB69F>lMHYEDU8EmB z7N@tePkbIZ{1AEdzEZ7BL{up;Ey@4@PDO64l+0{y6W^~71xI8AOk)V?GiS9KG!e-4 zwjjCp;E}0K3Bys`GmDPHa?G@xr-)fPk3C+p&GpsEgf&dW(Pj0RYnxJ7DEWCOedBaWq|c2bvwe+p6o9FF)4nvq zKCNMCvD4f>DK8AcQqF$g!Jl1wAen&vLg$#`0rKDmcRS+plHAfvew8*DY=&-Q;H)N! z4SG0urfykX;0$2Y6*6H4x(9;Ks%gaw*zCx-1Ft@;OnIe|`oPh#0CB;+ny;V~3-~IH}MFv~PgfWlpAHKd3FhEOs=&?6y z^0-#;tnI|q6urN_mT+&yl>D>p2u$}RJWgmfQsv3~#@#hj+(((pdoP;8@yS@CpHW>~ z`vH8ezSTEIE!*kgnIm_5p4^7a%+2*m32y~YpG*Iv_ZFA}Z5IsBrS;65dcS@G#l5JP z`jmP-gb8F4hjIl#-SDDQ6$f*-gYzvU%9 zE{yj>S@mQ=BSiyw1Mzw%3w2#vqqyV76CBh2BS8AyyLFyID?$^si8xve6t^Fr+bP!628V)ipPv2q0T+VR#CGb|F{A1Rhg^*|SA@Oz@>iB_**o% zpfBiko{N!-aLX`zq(pGi2$T3QJUpz0YNBvspB-^8bAf?K;{;q-MQJlxsI>6v|*J{;CB_i|a?cTp$9zi!hUZspw%`F+PgVm}^< z8h-a~@8c&=M!LH8B)UmQTw8u+o!()a6Pv+Ll@Ppu)6|Sw{tK!ow>RY|yYUiP?hz52 zc^+FVTZy)>7pDqX1epBN)>h==LcwsdxrIgFY81mXkxF7F)^P7x55ZWW4V(Dn3yc=Q z-l0{l?X_rhQt+)?#R$k2lz;Z0u@Q}Vrn8?=uKJGdk6`aM;CP(2x>jV`1w+;T%}VPTNxcY28R zZ{NzSc$o)C;)|ZDa>Z*mUc4T?AB%*I-V6z`>mNFZ|M3006e2-Nx&cJ-r(rx6UhJFQ zPH+ZePU^s(=0}JpNH;C3mt}ZCIpLajgSG{$V56ucw;#0x7hUAklh~_acMN7OJP2_7 z+@Y&vsduVixOfD*KO`1^#P$~Fr0n#tYkb?(Yb1}2r=%>6P2qD+d}}+xF_qo=HTBEV zL{K=h_~n1*;9YHLXrgs`b~_Y6y+HJR+tdQxXAC?Y_SU0NC^X8m)aWsb*hjKb0CIvC zGPu9k#V)Z$HFB~jHC2EyVCnJz1agE2K=G#7me02n#MIT*;XFZbtgOP;JK08f*l|H& z83H=Vr0%P|y?yb7|A0%9PoeO(=YvO(i`WCwc`05yPQ9>2DK52G5=zBsXs}ee>uK86 z*WBy4G@aDt^pz@YD}B*c+dAe%m+iB!KZ8B>QkaIE-}UAmw87lJBJfqf{l%+fs1F z8SMk~;5<%0M>(dZv*89}wWOBJrld`26p*ZFCtWVv9oeUN3L$2a4brtTgbcSyJXCUH zkRqsyq`paXECHp~5b~TF-0`8^cn31XbDp8SE_(9*ZI*|y@w|VFewF~BN0kje{&8r# z>a>D_F&rBqs=sV2zxB>q89=&~THcc(wWwa?4^2~&0y!>g@9yoY-#44BK(RqV=3*UU zwXUvCd%&er+=LGhAKZ@_S&iQD_!$|W(Vy7j)XxXxhx5fQEcng{#O%wyc*1#6EK3sv z@)QqXr5Ue-PI5j5=-1cRkCbEP^sQEOIpQNDC&04V6FV*_ti)Y>qnd!Mzru5LCA{ew zOTX~>_d_pnX@^ui>If(_vHI;>oRZFO^~fCNy@4~j!?B7H*NT@tmc632F0Qb5pZ z{{xL$dk3A^@{iex19kyI8^tq>Jfn}|uJ>dLy8ffK;!DgjYlN1YQ3Kf1%T>rSHoPfm z1`>p1V5;NM^pc*oJ4u=~*nZHV-2Oj@-spUvdpf{0NFkd%T z4{Gz{RjjG1`lCBW>css!SwZV_7cW2;f19*+=?1fS#*06#hv+fxge!Y<=mWipO|eY{ zbP8RNA8!5oz@&R7&|}HNWIDi}!~+apL9R(h9F^WW2u=wmDY1As5gev`4g$%^^%IND zgC+PEX^r;Adz{F0Z&j_{?Oc!ID_|strNyP-u62Z8veca7cNs%3eMG@0@}< zuoFZ_rvsL;JgdCukd_^dUcfU;$qw$dGcWu`y>)Fg1_8l5>n*{1Z`uVgO4`z1AGing4;(AB5Xw1;Kde~O(a~`;?R8yhvko-PSd?*_hSit?2G#sW&uFd07<^^WDBRyh)jN!&zVxV5{%R;b@cf00Pe4zjbwc=wp*%8lJR>jfSPCrvZ|{tfw> zJK%yCq&mRe(`0rAdrL8Zd|SxG74h)?b0=YtzZytjgUMtPJSxo&!U5B%+?L1dlW?At zw+g^0SYHW1#z%1UGrirXn;Ya)$LvI#$jK(>&nfIw41nz!3Ea4TrA{K@bJNlpX6-d1 z5vr>Lxh>e5nm-rMI#a?6!TfXf?~=w!J&9S}+_?lb-q$A8q+pftLgk7Io zy4L(@HLAX`F{|mD^XBwF$CE#BB9O}DH2HxIi}6a5TF%PE&B;EO8|8^^M(HyenN31= zAild(9MMm1U{y88%%{egMsc~+yhQgObT|YDL z*hb5d`S({hiy#2>LMf3Mx-(Yh@PrjoEGVjj2xKlz(6srobwbd@FRPO=cdK@I;DF3G zGw#wypBiru4}FfI^0C=QP6YR(&f)8^Gcz;gih?Z%xj_o2Pb+G@l)Pim5y9O!GafC2 zfjVLYJ4Q(@^5bkZhMI|@hqtPj5!VI}P$j;w=`stm0e*fMxbPtD<+ei0)(E9Z=n4kV zz%Ga@`5UB`J6SuTm0bd$*rvtz$O*Zlz)4nq{X6Gki?s)0ag?R#boR19Z@1fCO z!#(-8`T}zg{JTU*8|MQUAe=~;f52m*IDpg=HySQCP9rwEgL#%rcnWqhhG}MdVD9z; zy;760pr}t$0q7y;t%TVFE{4Z^eSK?kWjoP-j%kJ``0ntGsX>!%cFY(Hy=4>r2T)@^ zrO58cgOv|SI%CY}Yy8+b7Aq|dlhKql68OHUsY$b%i?AOs1>UcVFnp!kgP$55eSs*i zs1X0Fpv&gbVGx^3?f)skrzxea%Nm#SkQeH6a5dZ!zlL#8X8Ss&1pDY&O zCSD){-(OSCe#&-06+p=;yNIdfjQdH{$}fO0&d=S7rN+suZi-w}6&vs0=RPNrwo?a`L22A%*d57c;21=15}9?ktrkLheG# zynCP+AlDK?7Zw(*`BGd5DVMv2A;X*ji%;H41GGji7hJ%S$|s<&(}d4XnxkDDC;+e%G{H+S_Y0e2hGl|Fxq)#s;ORGgch^=Q)zZB=5dqa6?Dm zZxZ?lP8_)N=jFb6fxBuhhZ*3#DWk^}#0Xxf=E_c;HFtyRpPhX(Nl}{st0p2~4nCUx*N7L{Z1yuY-~ayB3e=i+|Lo;vZtUta M);HIyz`Mr%AJn`~GXMYp literal 0 HcmV?d00001 diff --git a/client/ui/assets/netbird-disconnected.png b/client/ui/assets/netbird-disconnected.png new file mode 100644 index 0000000000000000000000000000000000000000..79d4775eab038831d38abcc2845f671adc7a1b4c GIT binary patch literal 7537 zcmdsci9b~T_y23gsIe6yYs|=&!VoEmL6+>KvcwRw@9R6ujJ1e}A|<=X63Uu2(nhu{ zF@)@dDU7kq%y0C5e?IT`@BR3G|AFs)-1~T)d(L^D<=%78oO>tM$UuvQk(Utwz;a3Z zq6q*H8VLdL1GMG3Pti@<0&`T=R|TLlk!jDCo_3C~(>Bou;Fc%=xF`U2Xe`_U0RAWd zmTUk}d;|cOSH^2&C7K}0!R*paeSIK7qv3!a!Vc(Y6lC8Egcs~v+xMpf;rlOb0uld9 z2MWM#X8`+4$AY%*Z`!nDAM@u*mks&bVm9-1M`@XJ=gvu^WTjF;+zNNC^Df44v5^mD%XwW{?@z}-T=}+~rZAx}5xmxf%{Aqp;p3edADySMS^}Cak1ATi##q1N@zcllPBW9J zZ>V!m6g;1qWnq{jvCS_YUWMJng#6SHg7n5cHSkfzMG8cQF#SR0Yti4UszWACuC{cR zihwb2^#u!qQ?CNdL19E{jnZXu>w#6pr!r)?R-%CW_j8m9lKY~0Gg9vS$K>+noh6IZ zhjO_1C`vFg>Cj#SH|a@ZcIIsG8|4;wkY|rVg%BjEFvs5!7DqQR68*sYSolZ+pTePa z24V0^uFB&n?m)ADe%a|RX6!cepR5EBAMA%GgDce2uIUnaL$KaMzU>(0l!tpULmI6I zAY+@J<)x)cs@A-FXK(>M$G4ZuTUOH?jcD{W}@y?gft*|jD2g0=Vz~3Of&|gzWO}Wp2?`_1Sf8Q5&6Q z9O#^MT=Ey?NbN49_4%F&nG@bgiQvZrE@|uQjv6u@-cMUf`ya)g`8+}%O2g=H<#mdZJLg5P^r=GQx)9JANw5UN1jDe#eO{M z4s_`Vgq|*<;N6+dae-~=@@zT}mx@{qm$gPc1mZn)BM13m!LfX0=MzIri|ezPmXML| zf_IscKQiVHJB{l?8Fc*NQ;nXM=Ud0IOlmP#v6A7ett>aTXtmosTu?WkPxODIFjoIO zfWX~0sp6hRec4}f*j_kd$8(mJDPe{UcMSAo81|+Q-DRj+Q3=mLITi3|Z*H|D=<0m^ zv~sP>(SuHVy54&zEed_##0;Nt7u7CYKo;$ZSsce)rM8ZqrTS+{X=Jr9p}9B|Kl^{v ze{Emndr+)3J!Oj^KVDe(iW&uJzH&`SY3s{Uxd;O4^y;1T?x~WY>vSL3bHBLRTD{9= z*5;=&jMX;;VZWowIgQSKAaGr$3-sCLbUi-|UeYc4wXL{WQjJ+BXEHD7P2Oso4}BeC ze)2~HpT>%&X{@Mivz}K*A9$J4XB<$`y&!ENBuw}<+o|pZd%_B38%oiUaok{{BAOzZ_iX!iObxzSm~1VYKGDj-!=j z#CVLBX11I!gh0!!n-qEWN_hhzb(=ehvSt?mZD5syPxD;=*5{SKcK7+k*8+9Hr zk1Mw*Vj&cE2ts??5>X?|7%3f7ngW-eeE7kU8cMqe;m63U(YDLyvvFwC{82x}uY!=e z`LcR)T!Du$2x5JLv*Yi52~B8T{Zar+h{ts>*&G=8s3l~R@l{;lOAQuw)@zv=D=LzD zzc=b?VU+=`39<9If-an@RE|kM>b{rd-3gdmF`@s!EmzmVV(itsS81Nq;lg#VuD^F_ zbC(QbpY&Alu-&S%)G9>uLQhq8o}! zH-_6tZO@DEc`N&t9oE9p^$)BJ*Hn>6v2)B-(X~O=5s3?9n6cS_&NyJh#%c7{T|IrS z`@C;fpu-5go@J37*+bb{-I=lTyk$LP`4djO{Yfv3!{OPXnIwFkB@%plr|Yrn79N*!7&pzRQdYl-|o_OqXx|JRrB% zx2dq920=@Qr)t7(*0ikmE);9eTm*#g9MB++ukIGC1+UbzrQ0zn{U+Qp-MG|S(9qZ9 zI1z7A5QL=XIPpE8VZ3&#{F3`whD&PMk2yS(L-{Db3MxHdU5=S$C$qIR+s0K`u*xmX zoxOI~%B}qmBDyjT%IwV=pBn@(nJyjqpk24S^>a0=yi6`!DPenOGVw!`6EOGBVd%11 z*;B+1wH$-0H<0yQkX%z4Lz98n)O;=$M2Re2++@0=%5O39g5BlIQ=UJVn6bm5HoYal zCi+W}+HGZ1NJ$QGYU>VHA?;bLBT50YvqhyLo@nD}J&R(YGDDA+IHAkcOQsxV+Ja0AU2cE?n$HNTqCIaFbfGp?Z`wnlKa7EXdI<>Nhb z*ed!;*eZg&qmDlLE~(b6fRYsW5^y0ri5crJuliy6vg~uzO74CVH)^ASX>P+Zs}jGXrzLhs0+@+0K-N=dvn!9ybvAMkyR%hMrk0wQ z8GF41wm9cXoYNn9ctKj#wBbo=>mtInm@gL5Zaq#WT1rb}&dD#<SHxpYN-k; zYpEfWME8&J5tV}{%MuklT7)p3Fu?SzNUdYed~B{Y(frC>o_pLxtEtb=DUX zvjBn&LS_rZC)?K#OVfk!6UDtZYGXTAYRyW1m6=5yJ4LnTrTiLb;smGia~YdlC5k&_ z>dC*3Rk%EbN)w*9F|Tjd6^Awd*s=hu$u}7Ea?k7yvP30`!C83bO71!Te6$6{(F?p{ z%^p2gW}b2bSiHlNozJ%!_VnqH(B)vxt8Ua zt;M#p+>lnCXKFSry3y*23d9j_{n3a>m5Tbq4UYx_#_~Os^BI0|mdZpOogT-2DBmUZ zo_i%r(0{YJxyf)d?q)*>d25xHcyyauylXm3hJ$0_=R<1%-3sW zaNbYe;vOwG?C4JdGd4!Z7-ud!RCbTEVPWyictobbp<9!eAC+Ie(K#F;2R&NCsp}+D zT=)L{`<(iwxX)95_;zV#dPh>Hf>AFo`YM|f^b+%9utt#!a(*s4nImaa`?`uE7^?pDbRRm-Jw6pm=QqXBH4FVDWc<>a+1|Ji27~%W#V}el7}Rtf=#8#_R8(U zf~z*WxA&AWL(hq>&Dhm>&z5%peEZDdu6c~WlP5irAmtdM`e+KUe(2FuAyyw^pk zD(jz%Hvz2dqLW5P&dSmci8Rga2IgD~l^X4vwFBc}ua-(~huT*`ah0&0dlJI!Sq*Q# zFUjCAr2*HLrhEm;0AfcEZX-Wc(;4kjTw<{Fax7E z{aBV#yZiIS+d#7W#M5`*t~zYil049+>*P)|huwULy)Yrm8Ja|SUqAlzs{htnRL_F_ z-L(T3 zXw`6?eZl1*j*vLIa?M^EVkDgG5tu+LBFu1uu|{82Lk#-;7Bm;WekF3g3(0^xt!)sP zUpz@%%$#%e*_C&6I4u$_<#?dRESL6xnFKQ6M5;`XUXGVuMz_oENlq?=<|P~jIhBG- z!=bN1DqjWiLP+wS0Q1GbG?Bt7L)+$5W1rx?{sP^U!^))A8KFGyUCtbBI&t>rUp;t@ zq5`p!?d(pr!8{50@~^V&Kto6#-x&qYXO4x*gJ4B_agfWj?hs)PZH+qG|4c1gt~hBd zECzK!y*v8z+2YpcBd6vRlY@H=-@YkH9yu`uDOo_;jA$T!!S&nXP_oH+Y$OnOdB`MjN{JdmnIBgHVl14Z&%d|vf4|CaMB6)d0ye$v%=?Ktx9u|toc6&0)1PCgm0^PcYQ6N&7+ z3e8qev7PL3JfH*j1B(j-vJg~2ErzWchQ6=ezCwHX^(Gjue>?$uK^riq)(cao7!XVA zYDj_^qJ08-|KmkyHpu@ds(w4?wEM$?j@=ottkP6g@w_i<%@GHZ(1@|MS{MbR56#x# zD=4N`EtxVxMqyRVy4eU2z_#+yqNvu^_2$}A{}mSVXW{BuXU-+OqQf)>^$9=kKPyD| zD3*nCtS3v5bm)q}WSd|jy(E{Sf>knLhn zxA<`c2a=X1SW94N7V+0)uON$*{4Y7x)$$DC2C$yY_&<1Gc{Xp~3?SYiK441%;&~CF zi}5Yn_%P0oiYs%wP!6cyHRnYG9RGHk;a0}Glt>s_!_v;{+q)d+4qMlT6EF_l)R|tPC2z_zeuC(#S*SQzzob5`q}DD7YyDPkSSuj9})%P#}SaEyohZ|6F0>t?UdE*9=LVT2{B+|2m-O) zTPnqd1$kFs!oWmUZtqR>z>}%+V%3*q4rokPUAMup&nnd-do#r4-tX+-M5F0uR)CE& z*_58+xfY}0=he?o6jF9f6)m*TilPhoDPg?P+r^QKOWZBzNxx=wuZqcd(x=NZ`h5 zJzm3p%Ih3mphxfgn25IZoY#Tt87Lhm{VE>|OT9H-G@-S^JJ)H#$L*kw1fCsH0pT)I zb-O);7g*VMuN#Wio#X*8g;2olITj|Gpu086oyXWmxNR7=JTj60>eQw47A%M>%AO79{F%OQ&KHMM}ZfdUmcj*3)Zk z!*SgkDGv7{?cYCdvx}90T`Mxh9&B>+gChKQ^^aBX3wP?l(M2piWM!#1MTxGoe9GcY&;)ObcGCRw^x%X_M1+QIGrz(OjwXAa<6Wd( z=0L2NAjvFXCFkB3d2KGFRWwqHaedIRBtea?6+y}o%%EHs*(g%$wUL`>xxLEK1nzuZ zaqJgUOOVK?y=#_@NxBa}aWO5GO1szmBYSfh`(l31edent7xNDMq7>1i6-kAu>GWyG zqQQg^v#sqT>Co}QmejZ*_@i&+EaL5c>dB^Cn=WlxIH7fT?Mt161^6E9y( z2cy}XBuvO@^0~9q#x6i^PZj(yzsac|pzK2X^0PS}I6*$i`yt#^Sce4>N<5s;up=!h z4>`cQ>1_O{+}7gwYa9X@v(1~k$a4kGcE1rWu2!4X7Q)iUB$`qrN4w>yMMqce2UAuL zjia+~9k%e|rEtu-)`{NH0UuTzzuhMtK#DQ5RA+ytam7h>9&(!Kdh+O&P_s1ZqLU11 zFXider>-lyFHSeG-_fK)JMs>8DqHa>W0fcPXQsjGyY2DrVcYlz_?~mKJ4_dmV1cME z_N}PA_}&oP7`k6E`$%XgzZgB6z2rl5a8$LZi zmuk$Ba~KJ>rs+O!ZRIts)Mr_p42~-lN>_( z)H|$m*kv98ix>nb(kyW~r_<}%v_l1qmHUSxwN^JH=?7<7JmTUE~xm$+F7#7mlaAIe#%+Ib5WQ`xx0NDk$u2# zOTlH>n%%c%a%stBwwhzEggrswU<3+!qnYa7_)2**<}stJ7#i8ydm{t&0K+kN(y)EP ze{H`Xph$sx8b!&R`@Nvhk&4=h-h$HFK%wmQB z5$q0ofR`@qj#?Z^aqP(y*>7kiGZJO+JGlQ-dj=~Bdt4U2cyIH?+x Date: Thu, 13 Nov 2025 20:16:45 +0100 Subject: [PATCH 074/120] [client] Use stdnet with a context to avoid DNS deadlocks (#4781) --- client/iface/iface_test.go | 21 ++-- client/iface/udpmux/mux.go | 6 +- client/internal/dns/server_test.go | 6 +- client/internal/engine_stdnet.go | 2 +- client/internal/engine_stdnet_android.go | 2 +- client/internal/engine_test.go | 6 +- client/internal/peer/guard/ice_monitor.go | 2 +- client/internal/peer/ice/agent.go | 5 +- client/internal/peer/ice/stdnet.go | 6 +- client/internal/peer/ice/stdnet_android.go | 10 +- client/internal/peer/worker_ice.go | 2 +- client/internal/relay/relay.go | 4 +- client/internal/routemanager/manager_test.go | 4 +- .../systemops/systemops_generic_test.go | 4 +- client/internal/stdnet/stdnet.go | 112 +++++++++++++++++- 15 files changed, 153 insertions(+), 39 deletions(-) diff --git a/client/iface/iface_test.go b/client/iface/iface_test.go index e890b30f3..6bbfeaa63 100644 --- a/client/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -1,6 +1,7 @@ package iface import ( + "context" "fmt" "net" "net/netip" @@ -9,13 +10,13 @@ import ( "time" "github.com/google/uuid" - "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/stdnet" ) // keep darwin compatibility @@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) addr := "100.64.0.1/8" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) { func Test_CreateInterface(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1) wgIP := "10.99.99.1/32" - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -166,7 +167,7 @@ func Test_Close(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) wgIP := "10.99.99.2/32" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) wgIP := "10.99.99.2/32" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3) wgIP := "10.99.99.5/30" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) { func Test_UpdatePeer(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) wgIP := "10.99.99.9/30" - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) { func Test_RemovePeer(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) wgIP := "10.99.99.13/30" - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) { peer2wgPort := 33200 keepAlive := 1 * time.Second - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) { guid = fmt.Sprintf("{%s}", uuid.New().String()) device.CustomWindowsGUIDString = strings.ToLower(guid) - newNet, err = stdnet.NewNet() + newNet, err = stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/client/iface/udpmux/mux.go b/client/iface/udpmux/mux.go index 319724926..c5d2de4a5 100644 --- a/client/iface/udpmux/mux.go +++ b/client/iface/udpmux/mux.go @@ -1,6 +1,7 @@ package udpmux import ( + "context" "fmt" "io" "net" @@ -12,8 +13,9 @@ import ( "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3" - "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/stdnet" ) /* @@ -199,7 +201,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() { if len(networks) > 0 { if m.params.Net == nil { var err error - if m.params.Net, err = stdnet.NewNet(); err != nil { + if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil { m.params.Logger.Errorf("failed to get create network: %v", err) } } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 451b83f92..d12070128 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { privKey, _ := wgtypes.GenerateKey() - newNet, err := stdnet.NewNet(nil) + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet([]string{"utun2301"}) + newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"}) if err != nil { t.Errorf("create stdnet: %v", err) return @@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet([]string{"utun2301"}) + newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"}) if err != nil { t.Fatalf("create stdnet: %v", err) return nil, err diff --git a/client/internal/engine_stdnet.go b/client/internal/engine_stdnet.go index 9e171b0b2..1ebb5779c 100644 --- a/client/internal/engine_stdnet.go +++ b/client/internal/engine_stdnet.go @@ -7,5 +7,5 @@ import ( ) func (e *Engine) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNet(e.config.IFaceBlackList) + return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList) } diff --git a/client/internal/engine_stdnet_android.go b/client/internal/engine_stdnet_android.go index 68a0ae719..de3c80bcf 100644 --- a/client/internal/engine_stdnet_android.go +++ b/client/internal/engine_stdnet_android.go @@ -3,5 +3,5 @@ package internal import "github.com/netbirdio/netbird/client/internal/stdnet" func (e *Engine) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList) + return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 15ac0a947..d15a07f9d 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -14,7 +14,7 @@ import ( "github.com/golang/mock/gomock" "github.com/google/uuid" - "github.com/pion/transport/v3/stdnet" + "github.com/netbirdio/netbird/client/internal/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -774,7 +774,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { MTU: iface.DefaultMTU, }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.ctx = ctx - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -977,7 +977,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.ctx = ctx - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index 0f22ee7b0..a201dd095 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) { func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) { log.Debugf("Gathering ICE candidates") - agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd) + agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd) if err != nil { return false, fmt.Errorf("create ICE agent: %w", err) } diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index 7b929c29d..79f68d279 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -1,6 +1,7 @@ package ice import ( + "context" "sync" "time" @@ -49,13 +50,13 @@ func (a *ThreadSafeAgent) Close() error { return err } -func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { +func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() iceFailedTimeout := iceFailedTimeout() iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList) + transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList) if err != nil { log.Errorf("failed to create pion's stdnet: %s", err) } diff --git a/client/internal/peer/ice/stdnet.go b/client/internal/peer/ice/stdnet.go index 3ce83727e..685ed0363 100644 --- a/client/internal/peer/ice/stdnet.go +++ b/client/internal/peer/ice/stdnet.go @@ -3,9 +3,11 @@ package ice import ( + "context" + "github.com/netbirdio/netbird/client/internal/stdnet" ) -func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { - return stdnet.NewNet(ifaceBlacklist) +func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNet(ctx, ifaceBlacklist) } diff --git a/client/internal/peer/ice/stdnet_android.go b/client/internal/peer/ice/stdnet_android.go index 84c665e6f..5033ec1b9 100644 --- a/client/internal/peer/ice/stdnet_android.go +++ b/client/internal/peer/ice/stdnet_android.go @@ -1,7 +1,11 @@ package ice -import "github.com/netbirdio/netbird/client/internal/stdnet" +import ( + "context" -func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist) + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 5d8ebfe45..840fc9241 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -209,7 +209,7 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) { - agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) + agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) } diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 693ea1f31..59be5b0a7 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri } }() - net, err := stdnet.NewNet(nil) + net, err := stdnet.NewNet(ctx, nil) if err != nil { probeErr = fmt.Errorf("new net: %w", err) return @@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri } }() - net, err := stdnet.NewNet(nil) + net, err := stdnet.NewNet(ctx, nil) if err != nil { probeErr = fmt.Errorf("new net: %w", err) return diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index d2f02526c..3697545ae 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -6,7 +6,7 @@ import ( "net/netip" "testing" - "github.com/pion/transport/v3/stdnet" + "github.com/netbirdio/netbird/client/internal/stdnet" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/stretchr/testify/require" @@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { peerPrivateKey, _ := wgtypes.GeneratePrivateKey() - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index d9b109beb..01916fbe3 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -15,7 +15,7 @@ import ( "syscall" "testing" - "github.com/pion/transport/v3/stdnet" + "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen peerPrivateKey, err := wgtypes.GeneratePrivateKey() require.NoError(t, err) - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) require.NoError(t, err) opts := iface.WGIFaceOpts{ diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go index 4b031c05c..381886ac6 100644 --- a/client/internal/stdnet/stdnet.go +++ b/client/internal/stdnet/stdnet.go @@ -4,17 +4,28 @@ package stdnet import ( + "context" + "errors" "fmt" + "net" + "net/netip" "slices" + "strconv" "sync" "time" - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" + + "github.com/netbirdio/netbird/client/iface/netstack" ) -const updateInterval = 30 * time.Second +const ( + updateInterval = 30 * time.Second + dnsResolveTimeout = 30 * time.Second +) + +var errNoSuitableAddress = errors.New("no suitable address found") // Net is an implementation of the net.Net interface // based on functions of the standard net package. @@ -28,12 +39,19 @@ type Net struct { // mu is shared between interfaces and lastUpdate mu sync.Mutex + + // ctx is the context for network operations that supports cancellation + ctx context.Context } // NewNetWithDiscover creates a new StdNet instance. -func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) { +func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) { + if ctx == nil { + ctx = context.Background() + } n := &Net{ interfaceFilter: InterfaceFilter(disallowList), + ctx: ctx, } // current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client // so in android cli use pionDiscover @@ -46,14 +64,64 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri } // NewNet creates a new StdNet instance. -func NewNet(disallowList []string) (*Net, error) { +func NewNet(ctx context.Context, disallowList []string) (*Net, error) { + if ctx == nil { + ctx = context.Background() + } n := &Net{ iFaceDiscover: pionDiscover{}, interfaceFilter: InterfaceFilter(disallowList), + ctx: ctx, } return n, n.UpdateInterfaces() } +// resolveAddr performs DNS resolution with context support and timeout. +func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) { + host, portStr, err := net.SplitHostPort(address) + if err != nil { + return netip.AddrPort{}, err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err) + } + if port < 0 || port > 65535 { + return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port) + } + + ipNet := "ip" + switch network { + case "tcp4", "udp4": + ipNet = "ip4" + case "tcp6", "udp6": + ipNet = "ip6" + } + + if host == "" { + addr := netip.IPv4Unspecified() + if ipNet == "ip6" { + addr = netip.IPv6Unspecified() + } + return netip.AddrPortFrom(addr, uint16(port)), nil + } + + ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout) + defer cancel() + + addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host) + if err != nil { + return netip.AddrPort{}, err + } + + if len(addrs) == 0 { + return netip.AddrPort{}, errNoSuitableAddress + } + + return netip.AddrPortFrom(addrs[0], uint16(port)), nil +} + // UpdateInterfaces updates the internal list of network interfaces // and associated addresses filtering them by name. // The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one @@ -137,3 +205,39 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I } return result } + +// ResolveUDPAddr resolves UDP addresses with context support and timeout. +func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) { + switch network { + case "udp", "udp4", "udp6": + case "": + network = "udp" + default: + return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)} + } + + addrPort, err := n.resolveAddr(network, address) + if err != nil { + return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err} + } + + return net.UDPAddrFromAddrPort(addrPort), nil +} + +// ResolveTCPAddr resolves TCP addresses with context support and timeout. +func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) { + switch network { + case "tcp", "tcp4", "tcp6": + case "": + network = "tcp" + default: + return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)} + } + + addrPort, err := n.resolveAddr(network, address) + if err != nil { + return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err} + } + + return net.TCPAddrFromAddrPort(addrPort), nil +} From e4b41d0ad70676b3f3f4a18f621b5467bd1c509c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 14 Nov 2025 00:25:00 +0100 Subject: [PATCH 075/120] [client] Replace ipset lib (#4777) * Replace ipset lib * Update .github/workflows/check-license-dependencies.yml Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * Ignore internal licenses * Ignore dependencies from AGPL code * Use exported errors * Use fixed version --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- .../workflows/check-license-dependencies.yml | 73 +++++++++- client/firewall/iptables/acl_linux.go | 135 ++++++++++++++---- client/firewall/iptables/router_linux.go | 50 +++++-- go.mod | 6 +- go.sum | 12 +- 5 files changed, 225 insertions(+), 51 deletions(-) diff --git a/.github/workflows/check-license-dependencies.yml b/.github/workflows/check-license-dependencies.yml index d3da427b0..2a3e7d424 100644 --- a/.github/workflows/check-license-dependencies.yml +++ b/.github/workflows/check-license-dependencies.yml @@ -3,10 +3,19 @@ name: Check License Dependencies on: push: branches: [ main ] + paths: + - 'go.mod' + - 'go.sum' + - '.github/workflows/check-license-dependencies.yml' pull_request: + paths: + - 'go.mod' + - 'go.sum' + - '.github/workflows/check-license-dependencies.yml' jobs: - check-dependencies: + check-internal-dependencies: + name: Check Internal AGPL Dependencies runs-on: ubuntu-latest steps: @@ -33,9 +42,67 @@ jobs: if [ $FOUND_ISSUES -eq 1 ]; then echo "" echo "❌ Found dependencies on management/, signal/, or relay/ packages" - echo "These packages will change license and should not be imported by client or shared code" + echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code" exit 1 else echo "" - echo "✅ All license dependencies are clean" + echo "✅ All internal license dependencies are clean" fi + + check-external-licenses: + name: Check External GPL/AGPL Licenses + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: 'go.mod' + cache: true + + - name: Install go-licenses + run: go install github.com/google/go-licenses@v1.6.0 + + - name: Check for GPL/AGPL licensed dependencies + run: | + echo "Checking for GPL/AGPL/LGPL licensed dependencies..." + echo "" + + # Check all Go packages for copyleft licenses, excluding internal netbird packages + COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true) + + if [ -n "$COPYLEFT_DEPS" ]; then + echo "Found copyleft licensed dependencies:" + echo "$COPYLEFT_DEPS" + echo "" + + # Filter out dependencies that are only pulled in by internal AGPL packages + INCOMPATIBLE="" + while IFS=',' read -r package url license; do + if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then + # Find ALL packages that import this GPL package using go list + IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath") + + # Check if any importer is NOT in management/signal/relay + BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1) + + if [ -n "$BSD_IMPORTER" ]; then + echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER" + INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n" + else + echo "✓ $package ($license) is only used by internal AGPL packages - OK" + fi + fi + done <<< "$COPYLEFT_DEPS" + + if [ -n "$INCOMPATIBLE" ]; then + echo "" + echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:" + echo -e "$INCOMPATIBLE" + exit 1 + fi + fi + + echo "✅ All external license dependencies are compatible with BSD-3-Clause" diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index d78372c9e..5ccaf17ba 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -1,13 +1,14 @@ package iptables import ( + "errors" "fmt" "net" "slices" "github.com/coreos/go-iptables/iptables" "github.com/google/uuid" - "github.com/nadoo/ipset" + ipset "github.com/lrh3321/ipset-go" log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -40,19 +41,13 @@ type aclManager struct { } func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) { - m := &aclManager{ + return &aclManager{ iptablesClient: iptablesClient, wgIface: wgIface, entries: make(map[string][][]string), optionalEntries: make(map[string][]entry), ipsetStore: newIpsetStore(), - } - - if err := ipset.Init(); err != nil { - return nil, fmt.Errorf("init ipset: %w", err) - } - - return m, nil + }, nil } func (m *aclManager) init(stateManager *statemanager.Manager) error { @@ -98,8 +93,8 @@ func (m *aclManager) AddPeerFiltering( specs = append(specs, "-j", actionToStr(action)) if ipsetName != "" { if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { - if err := ipset.Add(ipsetName, ip.String()); err != nil { - return nil, fmt.Errorf("failed to add IP to ipset: %w", err) + if err := m.addToIPSet(ipsetName, ip); err != nil { + return nil, fmt.Errorf("add IP to ipset: %w", err) } // if ruleset already exists it means we already have the firewall rule // so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager. @@ -113,14 +108,18 @@ func (m *aclManager) AddPeerFiltering( }}, nil } - if err := ipset.Flush(ipsetName); err != nil { - log.Errorf("flush ipset %s before use it: %s", ipsetName, err) + if err := m.flushIPSet(ipsetName); err != nil { + if errors.Is(err, ipset.ErrSetNotExist) { + log.Debugf("flush ipset %s before use: %v", ipsetName, err) + } else { + log.Errorf("flush ipset %s before use: %v", ipsetName, err) + } } - if err := ipset.Create(ipsetName); err != nil { - return nil, fmt.Errorf("failed to create ipset: %w", err) + if err := m.createIPSet(ipsetName); err != nil { + return nil, fmt.Errorf("create ipset: %w", err) } - if err := ipset.Add(ipsetName, ip.String()); err != nil { - return nil, fmt.Errorf("failed to add IP to ipset: %w", err) + if err := m.addToIPSet(ipsetName, ip); err != nil { + return nil, fmt.Errorf("add IP to ipset: %w", err) } ipList := newIpList(ip.String()) @@ -172,11 +171,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { return fmt.Errorf("invalid rule type") } + shouldDestroyIpset := false if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { // delete IP from ruleset IPs list and ipset if _, ok := ipsetList.ips[r.ip]; ok { - if err := ipset.Del(r.ipsetName, r.ip); err != nil { - return fmt.Errorf("failed to delete ip from ipset: %w", err) + ip := net.ParseIP(r.ip) + if ip == nil { + return fmt.Errorf("parse IP %s", r.ip) + } + if err := m.delFromIPSet(r.ipsetName, ip); err != nil { + return fmt.Errorf("delete ip from ipset: %w", err) } delete(ipsetList.ips, r.ip) } @@ -190,10 +194,7 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { // we delete last IP from the set, that means we need to delete // set itself and associated firewall rule too m.ipsetStore.deleteIpset(r.ipsetName) - - if err := ipset.Destroy(r.ipsetName); err != nil { - log.Errorf("delete empty ipset: %v", err) - } + shouldDestroyIpset = true } if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil { @@ -206,6 +207,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { } } + if shouldDestroyIpset { + if err := m.destroyIPSet(r.ipsetName); err != nil { + if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) { + log.Debugf("destroy empty ipset: %v", err) + } else { + log.Errorf("destroy empty ipset: %v", err) + } + } + } + m.updateState() return nil @@ -264,11 +275,19 @@ func (m *aclManager) cleanChains() error { } for _, ipsetName := range m.ipsetStore.ipsetNames() { - if err := ipset.Flush(ipsetName); err != nil { - log.Errorf("flush ipset %q during reset: %v", ipsetName, err) + if err := m.flushIPSet(ipsetName); err != nil { + if errors.Is(err, ipset.ErrSetNotExist) { + log.Debugf("flush ipset %q during reset: %v", ipsetName, err) + } else { + log.Errorf("flush ipset %q during reset: %v", ipsetName, err) + } } - if err := ipset.Destroy(ipsetName); err != nil { - log.Errorf("delete ipset %q during reset: %v", ipsetName, err) + if err := m.destroyIPSet(ipsetName); err != nil { + if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) { + log.Debugf("destroy ipset %q during reset: %v", ipsetName, err) + } else { + log.Errorf("destroy ipset %q during reset: %v", ipsetName, err) + } } m.ipsetStore.deleteIpset(ipsetName) } @@ -368,8 +387,8 @@ func (m *aclManager) updateState() { // filterRuleSpecs returns the specs of a filtering rule func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { matchByIP := true - // don't use IP matching if IP is ip 0.0.0.0 - if ip.String() == "0.0.0.0" { + // don't use IP matching if IP is 0.0.0.0 + if ip.IsUnspecified() { matchByIP = false } @@ -416,3 +435,61 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi return ipsetName + actionSuffix } } + +func (m *aclManager) createIPSet(name string) error { + opts := ipset.CreateOptions{ + Replace: true, + } + + if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { + return fmt.Errorf("create ipset %s: %w", name, err) + } + + log.Debugf("created ipset %s with type hash:net", name) + return nil +} + +func (m *aclManager) addToIPSet(name string, ip net.IP) error { + cidr := uint8(32) + if ip.To4() == nil { + cidr = 128 + } + + entry := &ipset.Entry{ + IP: ip, + CIDR: cidr, + Replace: true, + } + + if err := ipset.Add(name, entry); err != nil { + return fmt.Errorf("add IP to ipset %s: %w", name, err) + } + + return nil +} + +func (m *aclManager) delFromIPSet(name string, ip net.IP) error { + cidr := uint8(32) + if ip.To4() == nil { + cidr = 128 + } + + entry := &ipset.Entry{ + IP: ip, + CIDR: cidr, + } + + if err := ipset.Del(name, entry); err != nil { + return fmt.Errorf("delete IP from ipset %s: %w", name, err) + } + + return nil +} + +func (m *aclManager) flushIPSet(name string) error { + return ipset.Flush(name) +} + +func (m *aclManager) destroyIPSet(name string) error { + return ipset.Destroy(name) +} diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 305b0bf28..1fe4c149f 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -10,7 +10,7 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/hashicorp/go-multierror" - "github.com/nadoo/ipset" + ipset "github.com/lrh3321/ipset-go" log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" @@ -107,10 +107,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1 }, ) - if err := ipset.Init(); err != nil { - return nil, fmt.Errorf("init ipset: %w", err) - } - return r, nil } @@ -232,12 +228,12 @@ func (r *router) findSets(rule []string) []string { } func (r *router) createIpSet(setName string, sources []netip.Prefix) error { - if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil { + if err := r.createIPSet(setName); err != nil { return fmt.Errorf("create set %s: %w", setName, err) } for _, prefix := range sources { - if err := ipset.AddPrefix(setName, prefix); err != nil { + if err := r.addPrefixToIPSet(setName, prefix); err != nil { return fmt.Errorf("add element to set %s: %w", setName, err) } } @@ -246,7 +242,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error { } func (r *router) deleteIpSet(setName string) error { - if err := ipset.Destroy(setName); err != nil { + if err := r.destroyIPSet(setName); err != nil { return fmt.Errorf("destroy set %s: %w", setName, err) } @@ -915,8 +911,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) continue } - if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil { - merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err)) + if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) } } if merr == nil { @@ -993,3 +989,37 @@ func applyPort(flag string, port *firewall.Port) []string { return []string{flag, strconv.Itoa(int(port.Values[0]))} } + +func (r *router) createIPSet(name string) error { + opts := ipset.CreateOptions{ + Replace: true, + } + + if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { + return fmt.Errorf("create ipset %s: %w", name, err) + } + + log.Debugf("created ipset %s with type hash:net", name) + return nil +} + +func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error { + addr := prefix.Addr() + ip := addr.AsSlice() + + entry := &ipset.Entry{ + IP: ip, + CIDR: uint8(prefix.Bits()), + Replace: true, + } + + if err := ipset.Add(name, entry); err != nil { + return fmt.Errorf("add prefix to ipset %s: %w", name, err) + } + + return nil +} + +func (r *router) destroyIPSet(name string) error { + return ipset.Destroy(name) +} diff --git a/go.mod b/go.mod index 14cd49192..2d7e0d31c 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 - github.com/vishvananda/netlink v1.3.0 + github.com/vishvananda/netlink v1.3.1 golang.org/x/crypto v0.40.0 golang.org/x/sys v0.34.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 @@ -59,10 +59,10 @@ require ( github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 github.com/libp2p/go-netroute v0.2.1 + github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 @@ -236,7 +236,7 @@ require ( github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect - github.com/vishvananda/netns v0.0.4 // indirect + github.com/vishvananda/netns v0.0.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/wlynxg/anet v0.0.3 // indirect github.com/yuin/goldmark v1.7.8 // indirect diff --git a/go.sum b/go.sum index ee5027d61..f4b62dff0 100644 --- a/go.sum +++ b/go.sum @@ -316,6 +316,8 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= +github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU= +github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= @@ -356,8 +358,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= -github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= @@ -519,10 +519,10 @@ github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYg github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= -github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= -github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= -github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= +github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= +github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= +github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= From 0d79301141d9b856d05a000e1e5a21517f5aad5e Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Mon, 17 Nov 2025 15:28:20 +0100 Subject: [PATCH 076/120] Update client login success page (#4797) --- client/internal/auth/pkce_flow.go | 11 +- client/internal/templates/pkce-auth-msg.html | 155 ++++----- .../internal/templates/pkce_auth_msg_test.go | 299 ++++++++++++++++++ 3 files changed, 386 insertions(+), 79 deletions(-) create mode 100644 client/internal/templates/pkce_auth_msg_test.go diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 738d3e34f..48873f640 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -192,17 +192,20 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, if authError := query.Get(queryError); authError != "" { authErrorDesc := query.Get(queryErrorDesc) - return nil, fmt.Errorf("%s.%s", authError, authErrorDesc) + if authErrorDesc != "" { + return nil, fmt.Errorf("authentication failed: %s", authErrorDesc) + } + return nil, fmt.Errorf("authentication failed: %s", authError) } // Prevent timing attacks on the state if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { - return nil, fmt.Errorf("invalid state") + return nil, fmt.Errorf("authentication failed: Invalid state") } code := query.Get(queryCode) if code == "" { - return nil, fmt.Errorf("missing code") + return nil, fmt.Errorf("authentication failed: missing code") } return p.oAuthConfig.Exchange( @@ -231,7 +234,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, } if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil { - return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) + return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err) } email, err := parseEmailFromIDToken(tokenInfo.IDToken) diff --git a/client/internal/templates/pkce-auth-msg.html b/client/internal/templates/pkce-auth-msg.html index 4825c48e7..175a6f05c 100644 --- a/client/internal/templates/pkce-auth-msg.html +++ b/client/internal/templates/pkce-auth-msg.html @@ -1,88 +1,93 @@ + - + + + + NetBird Login + + + - NetBird Login Successful + -
- -
- {{ if .Error }} - - - - -
-
- Login failed +
+
+ + +
+ + + + + + + + + + + + + + + + + + +
- {{ .Error }}. -
- {{ else }} - - - - -
-
- Login successful + +
+ +
+ + {{ if .Error }} + +
+ + + + +
+ {{ else }} + +
+ + + + +
+ {{ end }} + + +
+ {{ if .Error }} +

Login Failed

+ {{ else }} +

Login Successful

+ {{ end }} +
+ + + {{ if .Error }} +
+ {{ .Error }} +
+ {{ else }} + +
+ Your device is now registered and logged in to NetBird. You can now close this window. +
+ {{ end }} + +
- Your device is now registered and logged in to NetBird. -
- You can now close this window.
- {{ end }}
+ diff --git a/client/internal/templates/pkce_auth_msg_test.go b/client/internal/templates/pkce_auth_msg_test.go new file mode 100644 index 000000000..75b1c9e76 --- /dev/null +++ b/client/internal/templates/pkce_auth_msg_test.go @@ -0,0 +1,299 @@ +package templates + +import ( + "html/template" + "os" + "path/filepath" + "testing" +) + +func TestPKCEAuthMsgTemplate(t *testing.T) { + tests := []struct { + name string + data map[string]string + outputFile string + expectedTitle string + expectedInContent []string + notExpectedInContent []string + }{ + { + name: "error_state", + data: map[string]string{ + "Error": "authentication failed: invalid state", + }, + outputFile: "pkce-auth-error.html", + expectedTitle: "Login Failed", + expectedInContent: []string{ + "authentication failed: invalid state", + "Login Failed", + }, + notExpectedInContent: []string{ + "Login Successful", + "Your device is now registered and logged in to NetBird", + }, + }, + { + name: "success_state", + data: map[string]string{ + // No error field means success + }, + outputFile: "pkce-auth-success.html", + expectedTitle: "Login Successful", + expectedInContent: []string{ + "Login Successful", + "Your device is now registered and logged in to NetBird. You can now close this window.", + }, + notExpectedInContent: []string{ + "Login Failed", + }, + }, + { + name: "error_state_timeout", + data: map[string]string{ + "Error": "authentication timeout: request expired after 5 minutes", + }, + outputFile: "pkce-auth-timeout.html", + expectedTitle: "Login Failed", + expectedInContent: []string{ + "authentication timeout: request expired after 5 minutes", + "Login Failed", + }, + notExpectedInContent: []string{ + "Login Successful", + "Your device is now registered and logged in to NetBird", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Parse the template + tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + // Create temp directory for this test + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, tt.outputFile) + + // Create output file + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + + // Execute the template + if err := tmpl.Execute(file, tt.data); err != nil { + file.Close() + t.Fatalf("Failed to execute template: %v", err) + } + file.Close() + + t.Logf("Generated test output: %s", outputPath) + + // Read the generated file + content, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + contentStr := string(content) + + // Verify file has content + if len(contentStr) == 0 { + t.Error("Output file is empty") + } + + // Verify basic HTML structure + basicElements := []string{ + "", + "", + "", + "NetBird", + } + + for _, elem := range basicElements { + if !contains(contentStr, elem) { + t.Errorf("Expected HTML to contain '%s', but it was not found", elem) + } + } + + // Verify expected title + if !contains(contentStr, tt.expectedTitle) { + t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle) + } + + // Verify expected content is present + for _, expected := range tt.expectedInContent { + if !contains(contentStr, expected) { + t.Errorf("Expected HTML to contain '%s', but it was not found", expected) + } + } + + // Verify unexpected content is not present + for _, notExpected := range tt.notExpectedInContent { + if contains(contentStr, notExpected) { + t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected) + } + } + }) + } +} + +func TestPKCEAuthMsgTemplateValidation(t *testing.T) { + // Test that the template can be parsed without errors + tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl) + if err != nil { + t.Fatalf("Template parsing failed: %v", err) + } + + // Test with empty data + t.Run("empty_data", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "empty-data.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + if err := tmpl.Execute(file, nil); err != nil { + t.Errorf("Template execution with nil data failed: %v", err) + } + }) + + // Test with error data + t.Run("with_error", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "with-error.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + data := map[string]string{ + "Error": "test error message", + } + if err := tmpl.Execute(file, data); err != nil { + t.Errorf("Template execution with error data failed: %v", err) + } + }) +} + +func TestPKCEAuthMsgTemplateContent(t *testing.T) { + // Test that the template contains expected elements + tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl) + if err != nil { + t.Fatalf("Template parsing failed: %v", err) + } + + t.Run("success_content", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "success.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + data := map[string]string{} + if err := tmpl.Execute(file, data); err != nil { + t.Fatalf("Template execution failed: %v", err) + } + + // Read the file and verify it contains expected content + content, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + // Check for success indicators + contentStr := string(content) + if len(contentStr) == 0 { + t.Error("Generated HTML is empty") + } + + // Basic HTML structure checks + requiredElements := []string{ + "", + "", + "", + "Login Successful", + "NetBird", + } + + for _, elem := range requiredElements { + if !contains(contentStr, elem) { + t.Errorf("Expected HTML to contain '%s', but it was not found", elem) + } + } + }) + + t.Run("error_content", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "error.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + errorMsg := "test error message" + data := map[string]string{ + "Error": errorMsg, + } + if err := tmpl.Execute(file, data); err != nil { + t.Fatalf("Template execution failed: %v", err) + } + + // Read the file and verify it contains expected content + content, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + // Check for error indicators + contentStr := string(content) + if len(contentStr) == 0 { + t.Error("Generated HTML is empty") + } + + // Basic HTML structure checks + requiredElements := []string{ + "", + "", + "", + "Login Failed", + errorMsg, + } + + for _, elem := range requiredElements { + if !contains(contentStr, elem) { + t.Errorf("Expected HTML to contain '%s', but it was not found", elem) + } + } + }) +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > 0 && len(substr) > 0 && containsHelper(s, substr))) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} From d71a82769c32238cf29bd1bb3bd3ec520c0c1ab9 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 17 Nov 2025 17:10:41 +0100 Subject: [PATCH 077/120] [client,management] Rewrite the SSH feature (#4015) --- .../workflows/check-license-dependencies.yml | 52 +- client/android/client.go | 2 +- client/android/preferences.go | 88 + client/cmd/root.go | 3 - client/cmd/ssh.go | 866 +++++++++- client/cmd/ssh_exec_unix.go | 74 + client/cmd/ssh_sftp_unix.go | 94 ++ client/cmd/ssh_sftp_windows.go | 94 ++ client/cmd/ssh_test.go | 717 ++++++++ client/cmd/status.go | 2 +- client/cmd/testutil_test.go | 3 +- client/cmd/up.go | 68 + client/embed/embed.go | 70 +- client/firewall/uspfilter/filter.go | 12 +- client/firewall/uspfilter/filter_test.go | 136 ++ .../firewall/uspfilter/nat_stateful_test.go | 85 + client/internal/acl/manager.go | 17 - client/internal/acl/manager_test.go | 67 - client/internal/connect.go | 38 +- client/internal/debug/debug.go | 12 + client/internal/engine.go | 143 +- client/internal/engine_ssh.go | 355 ++++ client/internal/engine_test.go | 134 +- client/internal/login.go | 10 + client/internal/peer/conn.go | 2 +- client/internal/peer/env.go | 4 + client/internal/peer/status.go | 19 +- client/internal/profilemanager/config.go | 128 +- client/internal/profilemanager/config_test.go | 8 +- .../internal/profilemanager/profilemanager.go | 18 + client/internal/routemanager/dynamic/route.go | 2 +- client/internal/routemanager/manager.go | 2 +- client/ios/NetBirdSDK/client.go | 2 +- client/proto/daemon.pb.go | 1458 +++++++++++++---- client/proto/daemon.proto | 177 +- client/proto/daemon_grpc.pb.go | 114 ++ client/server/jwt_cache.go | 79 + client/server/network.go | 2 +- client/server/server.go | 317 +++- client/server/server_test.go | 3 +- client/server/setconfig_test.go | 110 +- client/server/state_generic.go | 2 + client/server/state_linux.go | 2 + client/ssh/client.go | 118 -- client/ssh/client/client.go | 699 ++++++++ client/ssh/client/client_test.go | 512 ++++++ client/ssh/client/terminal_unix.go | 127 ++ client/ssh/client/terminal_windows.go | 265 +++ client/ssh/common.go | 171 ++ client/ssh/config/manager.go | 282 ++++ client/ssh/config/manager_test.go | 159 ++ client/ssh/config/shutdown_state.go | 22 + client/ssh/detection/detection.go | 99 ++ client/ssh/login.go | 53 - client/ssh/lookup.go | 14 - client/ssh/lookup_darwin.go | 51 - client/ssh/proxy/proxy.go | 392 +++++ client/ssh/proxy/proxy_test.go | 367 +++++ client/ssh/server.go | 280 ---- client/ssh/server/command_execution.go | 206 +++ client/ssh/server/command_execution_js.go | 52 + client/ssh/server/command_execution_unix.go | 329 ++++ .../ssh/server/command_execution_windows.go | 430 +++++ client/ssh/server/compatibility_test.go | 722 ++++++++ client/ssh/server/executor_unix.go | 253 +++ client/ssh/server/executor_unix_test.go | 262 +++ client/ssh/server/executor_windows.go | 570 +++++++ client/ssh/server/jwt_test.go | 629 +++++++ client/ssh/server/port_forwarding.go | 386 +++++ client/ssh/server/server.go | 712 ++++++++ client/ssh/server/server_config_test.go | 394 +++++ client/ssh/server/server_test.go | 441 +++++ client/ssh/server/session_handlers.go | 168 ++ client/ssh/server/session_handlers_js.go | 22 + client/ssh/server/sftp.go | 81 + client/ssh/server/sftp_js.go | 12 + client/ssh/server/sftp_test.go | 228 +++ client/ssh/server/sftp_unix.go | 71 + client/ssh/server/sftp_windows.go | 91 + client/ssh/server/shell.go | 180 ++ client/ssh/server/test.go | 45 + client/ssh/server/user_utils.go | 411 +++++ client/ssh/server/user_utils_js.go | 8 + client/ssh/server/user_utils_test.go | 908 ++++++++++ client/ssh/server/userswitching_js.go | 8 + client/ssh/server/userswitching_unix.go | 233 +++ client/ssh/server/userswitching_windows.go | 274 ++++ client/ssh/server/winpty/conpty.go | 487 ++++++ client/ssh/server/winpty/conpty_test.go | 290 ++++ client/ssh/server_mock.go | 46 - client/ssh/server_test.go | 123 -- client/ssh/{util.go => ssh.go} | 10 +- client/ssh/testutil/user_helpers.go | 172 ++ client/ssh/window_freebsd.go | 10 - client/ssh/window_unix.go | 14 - client/ssh/window_windows.go | 9 - client/status/status.go | 89 +- client/status/status_test.go | 17 +- client/system/info.go | 24 + client/ui/client_ui.go | 492 ++++-- client/wasm/cmd/main.go | 39 +- client/wasm/internal/ssh/client.go | 51 +- client/wasm/internal/ssh/key.go | 50 - go.mod | 16 +- go.sum | 56 +- management/internals/server/modules.go | 2 +- .../internals/shared/grpc/conversion.go | 67 +- management/internals/shared/grpc/server.go | 2 +- management/server/account.go | 28 +- management/server/account/manager.go | 11 +- management/server/account_test.go | 54 +- management/server/auth/manager.go | 17 +- management/server/auth/manager_mock.go | 15 +- management/server/auth/manager_test.go | 12 +- management/server/context/auth.go | 38 +- management/server/dns_test.go | 2 +- .../accounts/accounts_handler_test.go | 3 +- .../http/handlers/dns/dns_settings_handler.go | 2 +- .../handlers/dns/dns_settings_handler_test.go | 5 +- .../handlers/dns/nameservers_handler_test.go | 3 +- .../handlers/events/events_handler_test.go | 5 +- .../http/handlers/groups/groups_handler.go | 2 +- .../handlers/groups/groups_handler_test.go | 13 +- .../server/http/handlers/networks/handler.go | 6 +- .../handlers/networks/resources_handler.go | 4 +- .../http/handlers/networks/routers_handler.go | 4 +- .../http/handlers/peers/peers_handler_test.go | 7 +- .../policies/geolocation_handler_test.go | 7 +- .../handlers/policies/geolocations_handler.go | 4 +- .../handlers/policies/policies_handler.go | 2 +- .../policies/policies_handler_test.go | 9 +- .../policies/posture_checks_handler.go | 2 +- .../policies/posture_checks_handler_test.go | 7 +- .../handlers/routes/routes_handler_test.go | 3 +- .../handlers/setup_keys/setupkeys_handler.go | 2 +- .../setup_keys/setupkeys_handler_test.go | 7 +- .../server/http/handlers/users/pat_handler.go | 2 +- .../http/handlers/users/pat_handler_test.go | 7 +- .../http/handlers/users/users_handler_test.go | 35 +- .../server/http/middleware/auth_middleware.go | 35 +- .../http/middleware/auth_middleware_test.go | 63 +- .../testing/testing_tools/channel/channel.go | 15 +- management/server/idp/pocketid_test.go | 199 ++- management/server/management_proto_test.go | 2 +- management/server/management_test.go | 1 + management/server/mock_server/account_mock.go | 16 +- management/server/nameserver_test.go | 2 +- .../server/networks/resources/manager_test.go | 2 +- .../networks/resources/types/resource.go | 2 +- .../server/networks/routers/manager_test.go | 2 +- .../server/networks/routers/types/router.go | 2 +- management/server/peer_test.go | 8 +- management/server/posture/checks.go | 2 +- management/server/route_test.go | 2 +- .../server/types/route_firewall_rule.go | 2 +- management/server/user.go | 13 +- management/server/user_test.go | 26 +- relay/server/peer.go | 4 +- .../server => shared}/auth/jwt/extractor.go | 8 +- .../server => shared}/auth/jwt/validator.go | 0 shared/auth/user.go | 28 + shared/context/keys.go | 2 +- shared/management/client/client_test.go | 3 +- shared/management/operations/operation.go | 2 +- shared/management/proto/management.pb.go | 1414 +++++++++------- shared/management/proto/management.proto | 18 + shared/relay/client/dialer/quic/quic.go | 2 +- shared/relay/client/dialer/ws/ws.go | 2 +- shared/relay/constants.go | 2 +- version/url_windows.go | 6 +- 170 files changed, 18744 insertions(+), 2853 deletions(-) create mode 100644 client/cmd/ssh_exec_unix.go create mode 100644 client/cmd/ssh_sftp_unix.go create mode 100644 client/cmd/ssh_sftp_windows.go create mode 100644 client/cmd/ssh_test.go create mode 100644 client/firewall/uspfilter/nat_stateful_test.go create mode 100644 client/internal/engine_ssh.go create mode 100644 client/server/jwt_cache.go delete mode 100644 client/ssh/client.go create mode 100644 client/ssh/client/client.go create mode 100644 client/ssh/client/client_test.go create mode 100644 client/ssh/client/terminal_unix.go create mode 100644 client/ssh/client/terminal_windows.go create mode 100644 client/ssh/common.go create mode 100644 client/ssh/config/manager.go create mode 100644 client/ssh/config/manager_test.go create mode 100644 client/ssh/config/shutdown_state.go create mode 100644 client/ssh/detection/detection.go delete mode 100644 client/ssh/login.go delete mode 100644 client/ssh/lookup.go delete mode 100644 client/ssh/lookup_darwin.go create mode 100644 client/ssh/proxy/proxy.go create mode 100644 client/ssh/proxy/proxy_test.go delete mode 100644 client/ssh/server.go create mode 100644 client/ssh/server/command_execution.go create mode 100644 client/ssh/server/command_execution_js.go create mode 100644 client/ssh/server/command_execution_unix.go create mode 100644 client/ssh/server/command_execution_windows.go create mode 100644 client/ssh/server/compatibility_test.go create mode 100644 client/ssh/server/executor_unix.go create mode 100644 client/ssh/server/executor_unix_test.go create mode 100644 client/ssh/server/executor_windows.go create mode 100644 client/ssh/server/jwt_test.go create mode 100644 client/ssh/server/port_forwarding.go create mode 100644 client/ssh/server/server.go create mode 100644 client/ssh/server/server_config_test.go create mode 100644 client/ssh/server/server_test.go create mode 100644 client/ssh/server/session_handlers.go create mode 100644 client/ssh/server/session_handlers_js.go create mode 100644 client/ssh/server/sftp.go create mode 100644 client/ssh/server/sftp_js.go create mode 100644 client/ssh/server/sftp_test.go create mode 100644 client/ssh/server/sftp_unix.go create mode 100644 client/ssh/server/sftp_windows.go create mode 100644 client/ssh/server/shell.go create mode 100644 client/ssh/server/test.go create mode 100644 client/ssh/server/user_utils.go create mode 100644 client/ssh/server/user_utils_js.go create mode 100644 client/ssh/server/user_utils_test.go create mode 100644 client/ssh/server/userswitching_js.go create mode 100644 client/ssh/server/userswitching_unix.go create mode 100644 client/ssh/server/userswitching_windows.go create mode 100644 client/ssh/server/winpty/conpty.go create mode 100644 client/ssh/server/winpty/conpty_test.go delete mode 100644 client/ssh/server_mock.go delete mode 100644 client/ssh/server_test.go rename client/ssh/{util.go => ssh.go} (86%) create mode 100644 client/ssh/testutil/user_helpers.go delete mode 100644 client/ssh/window_freebsd.go delete mode 100644 client/ssh/window_unix.go delete mode 100644 client/ssh/window_windows.go delete mode 100644 client/wasm/internal/ssh/key.go rename {management/server => shared}/auth/jwt/extractor.go (92%) rename {management/server => shared}/auth/jwt/validator.go (100%) create mode 100644 shared/auth/user.go diff --git a/.github/workflows/check-license-dependencies.yml b/.github/workflows/check-license-dependencies.yml index 2a3e7d424..543ba2ab2 100644 --- a/.github/workflows/check-license-dependencies.yml +++ b/.github/workflows/check-license-dependencies.yml @@ -19,35 +19,37 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4 - - name: Check for problematic license dependencies - run: | - echo "Checking for dependencies on management/, signal/, and relay/ packages..." + - name: Check for problematic license dependencies + run: | + echo "Checking for dependencies on management/, signal/, and relay/ packages..." + echo "" - # Find all directories except the problematic ones and system dirs - FOUND_ISSUES=0 - find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do - echo "=== Checking $dir ===" - # Search for problematic imports, excluding test files - RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true) - if [ ! -z "$RESULTS" ]; then - echo "❌ Found problematic dependencies:" - echo "$RESULTS" - FOUND_ISSUES=1 + # Find all directories except the problematic ones and system dirs + FOUND_ISSUES=0 + while IFS= read -r dir; do + echo "=== Checking $dir ===" + # Search for problematic imports, excluding test files + RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true) + if [ -n "$RESULTS" ]; then + echo "❌ Found problematic dependencies:" + echo "$RESULTS" + FOUND_ISSUES=1 + else + echo "✓ No problematic dependencies found" + fi + done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort) + + echo "" + if [ $FOUND_ISSUES -eq 1 ]; then + echo "❌ Found dependencies on management/, signal/, or relay/ packages" + echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code" + exit 1 else - echo "✓ No problematic dependencies found" + echo "" + echo "✅ All internal license dependencies are clean" fi - done - if [ $FOUND_ISSUES -eq 1 ]; then - echo "" - echo "❌ Found dependencies on management/, signal/, or relay/ packages" - echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code" - exit 1 - else - echo "" - echo "✅ All internal license dependencies are clean" - fi check-external-licenses: name: Check External GPL/AGPL Licenses diff --git a/client/android/client.go b/client/android/client.go index d2d0c37f6..86fb1445d 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -17,9 +17,9 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/client/net" ) // ConnectionListener export internal Listener for mobile diff --git a/client/android/preferences.go b/client/android/preferences.go index 9a5d6bb21..c3c8eb3fb 100644 --- a/client/android/preferences.go +++ b/client/android/preferences.go @@ -201,6 +201,94 @@ func (p *Preferences) SetServerSSHAllowed(allowed bool) { p.configInput.ServerSSHAllowed = &allowed } +// GetEnableSSHRoot reads SSH root login setting from config file +func (p *Preferences) GetEnableSSHRoot() (bool, error) { + if p.configInput.EnableSSHRoot != nil { + return *p.configInput.EnableSSHRoot, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.EnableSSHRoot == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.EnableSSHRoot, err +} + +// SetEnableSSHRoot stores the given value and waits for commit +func (p *Preferences) SetEnableSSHRoot(enabled bool) { + p.configInput.EnableSSHRoot = &enabled +} + +// GetEnableSSHSFTP reads SSH SFTP setting from config file +func (p *Preferences) GetEnableSSHSFTP() (bool, error) { + if p.configInput.EnableSSHSFTP != nil { + return *p.configInput.EnableSSHSFTP, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.EnableSSHSFTP == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.EnableSSHSFTP, err +} + +// SetEnableSSHSFTP stores the given value and waits for commit +func (p *Preferences) SetEnableSSHSFTP(enabled bool) { + p.configInput.EnableSSHSFTP = &enabled +} + +// GetEnableSSHLocalPortForwarding reads SSH local port forwarding setting from config file +func (p *Preferences) GetEnableSSHLocalPortForwarding() (bool, error) { + if p.configInput.EnableSSHLocalPortForwarding != nil { + return *p.configInput.EnableSSHLocalPortForwarding, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.EnableSSHLocalPortForwarding == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.EnableSSHLocalPortForwarding, err +} + +// SetEnableSSHLocalPortForwarding stores the given value and waits for commit +func (p *Preferences) SetEnableSSHLocalPortForwarding(enabled bool) { + p.configInput.EnableSSHLocalPortForwarding = &enabled +} + +// GetEnableSSHRemotePortForwarding reads SSH remote port forwarding setting from config file +func (p *Preferences) GetEnableSSHRemotePortForwarding() (bool, error) { + if p.configInput.EnableSSHRemotePortForwarding != nil { + return *p.configInput.EnableSSHRemotePortForwarding, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + if cfg.EnableSSHRemotePortForwarding == nil { + // Default to false for security on Android + return false, nil + } + return *cfg.EnableSSHRemotePortForwarding, err +} + +// SetEnableSSHRemotePortForwarding stores the given value and waits for commit +func (p *Preferences) SetEnableSSHRemotePortForwarding(enabled bool) { + p.configInput.EnableSSHRemotePortForwarding = &enabled +} + // GetBlockInbound reads block inbound setting from config file func (p *Preferences) GetBlockInbound() (bool, error) { if p.configInput.BlockInbound != nil { diff --git a/client/cmd/root.go b/client/cmd/root.go index 11e5228f1..9f2eb109c 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -35,7 +35,6 @@ const ( wireguardPortFlag = "wireguard-port" networkMonitorFlag = "network-monitor" disableAutoConnectFlag = "disable-auto-connect" - serverSSHAllowedFlag = "allow-server-ssh" extraIFaceBlackListFlag = "extra-iface-blacklist" dnsRouteIntervalFlag = "dns-router-interval" enableLazyConnectionFlag = "enable-lazy-connection" @@ -64,7 +63,6 @@ var ( customDNSAddress string rosenpassEnabled bool rosenpassPermissive bool - serverSSHAllowed bool interfaceName string wireguardPort uint16 networkMonitor bool @@ -176,7 +174,6 @@ func init() { ) upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") - upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.") diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 5358ddacb..70c7dbcff 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -3,125 +3,809 @@ package cmd import ( "context" "errors" + "flag" "fmt" + "net" "os" "os/signal" + "os/user" + "slices" + "strconv" "strings" "syscall" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "golang.org/x/crypto/ssh" "github.com/netbirdio/netbird/client/internal" - "github.com/netbirdio/netbird/client/internal/profilemanager" - nbssh "github.com/netbirdio/netbird/client/ssh" + sshclient "github.com/netbirdio/netbird/client/ssh/client" + "github.com/netbirdio/netbird/client/ssh/detection" + sshproxy "github.com/netbirdio/netbird/client/ssh/proxy" + sshserver "github.com/netbirdio/netbird/client/ssh/server" "github.com/netbirdio/netbird/util" ) -var ( - port int - userName = "root" - host string +const ( + sshUsernameDesc = "SSH username" + hostArgumentRequired = "host argument required" + + serverSSHAllowedFlag = "allow-server-ssh" + enableSSHRootFlag = "enable-ssh-root" + enableSSHSFTPFlag = "enable-ssh-sftp" + enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding" + enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding" + disableSSHAuthFlag = "disable-ssh-auth" + sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl" ) -var sshCmd = &cobra.Command{ - Use: "ssh [user@]host", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 1 { - return errors.New("requires a host argument") - } +var ( + port int + username string + host string + command string + localForwards []string + remoteForwards []string + strictHostKeyChecking bool + knownHostsFile string + identityFile string + skipCachedToken bool + requestPTY bool +) - split := strings.Split(args[0], "@") - if len(split) == 2 { - userName = split[0] - host = split[1] - } else { - host = args[0] - } +var ( + serverSSHAllowed bool + enableSSHRoot bool + enableSSHSFTP bool + enableSSHLocalPortForward bool + enableSSHRemotePortForward bool + disableSSHAuth bool + sshJWTCacheTTL int +) - return nil - }, - Short: "Connect to a remote SSH server", - RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - SetFlagsFromEnvVars(cmd) +func init() { + upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer") + upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server") + upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server") + upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server") + upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server") + upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication") + upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)") - cmd.SetOut(cmd.OutOrStdout()) + sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port") + sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc) + sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)") + sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation") + sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)") + sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)") + sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)") + _ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used") + sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") - err := util.InitLog(logLevel, util.LogConsole) - if err != nil { - return fmt.Errorf("failed initializing log %v", err) - } + sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport") + sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport") - if !util.IsAdmin() { - cmd.Printf("error: you must have Administrator privileges to run this command\n") - return nil - } - - ctx := internal.CtxInitState(cmd.Context()) - - sm := profilemanager.NewServiceManager(configPath) - activeProf, err := sm.GetActiveProfileState() - if err != nil { - return fmt.Errorf("get active profile: %v", err) - } - profPath, err := activeProf.FilePath() - if err != nil { - return fmt.Errorf("get active profile path: %v", err) - } - - config, err := profilemanager.ReadConfig(profPath) - if err != nil { - return fmt.Errorf("read profile config: %v", err) - } - - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT) - sshctx, cancel := context.WithCancel(ctx) - - go func() { - // blocking - if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil { - cmd.Printf("Error: %v\n", err) - os.Exit(1) - } - cancel() - }() - - select { - case <-sig: - cancel() - case <-sshctx.Done(): - } - - return nil - }, + sshCmd.AddCommand(sshSftpCmd) + sshCmd.AddCommand(sshProxyCmd) + sshCmd.AddCommand(sshDetectCmd) } -func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error { - c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey) - if err != nil { - cmd.Printf("Error: %v\n", err) - cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" + - "\nYou can verify the connection by running:\n\n" + - " netbird status\n\n") - return err - } - go func() { - <-ctx.Done() - err = c.Close() - if err != nil { - return +var sshCmd = &cobra.Command{ + Use: "ssh [flags] [user@]host [command]", + Short: "Connect to a NetBird peer via SSH", + Long: `Connect to a NetBird peer using SSH with support for port forwarding. + +Port Forwarding: + -L [bind_address:]port:host:hostport Local port forwarding + -L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket + -R [bind_address:]port:host:hostport Remote port forwarding + -R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket + +SSH Options: + -p, --port int Remote SSH port (default 22) + -u, --user string SSH username + --login string SSH username (alias for --user) + -t, --tty Force pseudo-terminal allocation + --strict-host-key-checking Enable strict host key checking (default: true) + -o, --known-hosts string Path to known_hosts file + +Examples: + netbird ssh peer-hostname + netbird ssh root@peer-hostname + netbird ssh --login root peer-hostname + netbird ssh peer-hostname ls -la + netbird ssh peer-hostname whoami + netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen + netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo + netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding + netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding + netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces + netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`, + DisableFlagParsing: true, + Args: validateSSHArgsWithoutFlagParsing, + RunE: sshFn, + Aliases: []string{"ssh"}, +} + +func sshFn(cmd *cobra.Command, args []string) error { + for _, arg := range args { + if arg == "-h" || arg == "--help" { + return cmd.Help() } + } + + SetFlagsFromEnvVars(rootCmd) + SetFlagsFromEnvVars(cmd) + + cmd.SetOut(cmd.OutOrStdout()) + + logOutput := "console" + if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile { + logOutput = firstLogFile + } + if err := util.InitLog(logLevel, logOutput); err != nil { + return fmt.Errorf("init log: %w", err) + } + + ctx := internal.CtxInitState(cmd.Context()) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT) + sshctx, cancel := context.WithCancel(ctx) + + errCh := make(chan error, 1) + go func() { + if err := runSSH(sshctx, host, cmd); err != nil { + errCh <- err + } + cancel() }() - err = c.OpenTerminal() - if err != nil { + select { + case <-sig: + cancel() + <-sshctx.Done() + return nil + case err := <-errCh: return err + case <-sshctx.Done(): } return nil } -func init() { - sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort)) +// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes +func getEnvOrDefault(flagName, defaultValue string) string { + if envValue := os.Getenv("WT_" + flagName); envValue != "" { + return envValue + } + if envValue := os.Getenv("NB_" + flagName); envValue != "" { + return envValue + } + return defaultValue +} + +// resetSSHGlobals sets SSH globals to their default values +func resetSSHGlobals() { + port = sshserver.DefaultSSHPort + username = "" + host = "" + command = "" + localForwards = nil + remoteForwards = nil + strictHostKeyChecking = true + knownHostsFile = "" + identityFile = "" +} + +// parseCustomSSHFlags extracts -L, -R flags and returns filtered args +func parseCustomSSHFlags(args []string) ([]string, []string, []string) { + var localForwardFlags []string + var remoteForwardFlags []string + var filteredArgs []string + + for i := 0; i < len(args); i++ { + arg := args[i] + switch { + case strings.HasPrefix(arg, "-L"): + localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags) + case strings.HasPrefix(arg, "-R"): + remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags) + default: + filteredArgs = append(filteredArgs, arg) + } + } + + return filteredArgs, localForwardFlags, remoteForwardFlags +} + +func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) { + if arg == "-L" || arg == "-R" { + if i+1 < len(args) { + flags = append(flags, args[i+1]) + i++ + } + } else if len(arg) > 2 { + flags = append(flags, arg[2:]) + } + return flags, i +} + +// extractGlobalFlags parses global flags that were passed before 'ssh' command +func extractGlobalFlags(args []string) { + sshPos := findSSHCommandPosition(args) + if sshPos == -1 { + return + } + + globalArgs := args[:sshPos] + parseGlobalArgs(globalArgs) +} + +// findSSHCommandPosition locates the 'ssh' command in the argument list +func findSSHCommandPosition(args []string) int { + for i, arg := range args { + if arg == "ssh" { + return i + } + } + return -1 +} + +const ( + configFlag = "config" + logLevelFlag = "log-level" + logFileFlag = "log-file" +) + +// parseGlobalArgs processes the global arguments and sets the corresponding variables +func parseGlobalArgs(globalArgs []string) { + flagHandlers := map[string]func(string){ + configFlag: func(value string) { configPath = value }, + logLevelFlag: func(value string) { logLevel = value }, + logFileFlag: func(value string) { + if !slices.Contains(logFiles, value) { + logFiles = append(logFiles, value) + } + }, + } + + shortFlags := map[string]string{ + "c": configFlag, + "l": logLevelFlag, + } + + for i := 0; i < len(globalArgs); i++ { + arg := globalArgs[i] + + if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled { + i = nextIndex + } + } +} + +// parseFlag handles generic flag parsing for both long and short forms +func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) { + if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found { + flagHandlers[parsedValue.flagName](parsedValue.value) + return true, currentIndex + } + + if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found { + flagHandlers[parsedValue.flagName](parsedValue.value) + return true, currentIndex + 1 + } + + return false, currentIndex +} + +type parsedFlag struct { + flagName string + value string +} + +// parseEqualsFormat handles --flag=value and -f=value formats +func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) { + if !strings.Contains(arg, "=") { + return parsedFlag{}, false + } + + parts := strings.SplitN(arg, "=", 2) + if len(parts) != 2 { + return parsedFlag{}, false + } + + if strings.HasPrefix(parts[0], "--") { + flagName := strings.TrimPrefix(parts[0], "--") + if _, exists := flagHandlers[flagName]; exists { + return parsedFlag{flagName: flagName, value: parts[1]}, true + } + } + + if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 { + shortFlag := strings.TrimPrefix(parts[0], "-") + if longFlag, exists := shortFlags[shortFlag]; exists { + if _, exists := flagHandlers[longFlag]; exists { + return parsedFlag{flagName: longFlag, value: parts[1]}, true + } + } + } + + return parsedFlag{}, false +} + +// parseSpacedFormat handles --flag value and -f value formats +func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) { + if currentIndex+1 >= len(args) { + return parsedFlag{}, false + } + + if strings.HasPrefix(arg, "--") { + flagName := strings.TrimPrefix(arg, "--") + if _, exists := flagHandlers[flagName]; exists { + return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true + } + } + + if strings.HasPrefix(arg, "-") && len(arg) == 2 { + shortFlag := strings.TrimPrefix(arg, "-") + if longFlag, exists := shortFlags[shortFlag]; exists { + if _, exists := flagHandlers[longFlag]; exists { + return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true + } + } + } + + return parsedFlag{}, false +} + +// createSSHFlagSet creates and configures the flag set for SSH command parsing +// sshFlags contains all SSH-related flags and parameters +type sshFlags struct { + Port int + Username string + Login string + RequestPTY bool + StrictHostKeyChecking bool + KnownHostsFile string + IdentityFile string + SkipCachedToken bool + ConfigPath string + LogLevel string + LocalForwards []string + RemoteForwards []string + Host string + Command string +} + +func createSSHFlagSet() (*flag.FlagSet, *sshFlags) { + defaultConfigPath := getEnvOrDefault("CONFIG", configPath) + defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + + fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError) + fs.SetOutput(nil) + + flags := &sshFlags{} + + fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port") + fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port") + fs.StringVar(&flags.Username, "u", "", sshUsernameDesc) + fs.StringVar(&flags.Username, "user", "", sshUsernameDesc) + fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)") + fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation") + fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation") + + fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking") + fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file") + fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file") + fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file") + fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file") + fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") + + fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location") + fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location") + fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level") + fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level") + + return fs, flags +} + +func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error { + if len(args) < 1 { + return errors.New(hostArgumentRequired) + } + + resetSSHGlobals() + + if len(os.Args) > 2 { + extractGlobalFlags(os.Args[1:]) + } + + filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args) + + fs, flags := createSSHFlagSet() + + if err := fs.Parse(filteredArgs); err != nil { + if errors.Is(err, flag.ErrHelp) { + return nil + } + return err + } + + remaining := fs.Args() + if len(remaining) < 1 { + return errors.New(hostArgumentRequired) + } + + port = flags.Port + if flags.Username != "" { + username = flags.Username + } else if flags.Login != "" { + username = flags.Login + } + + requestPTY = flags.RequestPTY + strictHostKeyChecking = flags.StrictHostKeyChecking + knownHostsFile = flags.KnownHostsFile + identityFile = flags.IdentityFile + skipCachedToken = flags.SkipCachedToken + + if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) { + configPath = flags.ConfigPath + } + if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) { + logLevel = flags.LogLevel + } + + localForwards = localForwardFlags + remoteForwards = remoteForwardFlags + + return parseHostnameAndCommand(remaining) +} + +func parseHostnameAndCommand(args []string) error { + if len(args) < 1 { + return errors.New(hostArgumentRequired) + } + + arg := args[0] + if strings.Contains(arg, "@") { + parts := strings.SplitN(arg, "@", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return errors.New("invalid user@host format") + } + if username == "" { + username = parts[0] + } + host = parts[1] + } else { + host = arg + } + + if username == "" { + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + username = sudoUser + } else if currentUser, err := user.Current(); err == nil { + username = currentUser.Username + } else { + username = "root" + } + } + + // Everything after hostname becomes the command + if len(args) > 1 { + command = strings.Join(args[1:], " ") + } + + return nil +} + +func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error { + target := fmt.Sprintf("%s:%d", addr, port) + c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{ + KnownHostsFile: knownHostsFile, + IdentityFile: identityFile, + DaemonAddr: daemonAddr, + SkipCachedToken: skipCachedToken, + InsecureSkipVerify: !strictHostKeyChecking, + }) + + if err != nil { + cmd.Printf("Failed to connect to %s@%s\n", username, target) + cmd.Printf("\nTroubleshooting steps:\n") + cmd.Printf(" 1. Check peer connectivity: netbird status -d\n") + cmd.Printf(" 2. Verify SSH server is enabled on the peer\n") + cmd.Printf(" 3. Ensure correct hostname/IP is used\n") + return fmt.Errorf("dial %s: %w", target, err) + } + + sshCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + <-sshCtx.Done() + if err := c.Close(); err != nil { + cmd.Printf("Error closing SSH connection: %v\n", err) + } + }() + + if err := startPortForwarding(sshCtx, c, cmd); err != nil { + return fmt.Errorf("start port forwarding: %w", err) + } + + if command != "" { + return executeSSHCommand(sshCtx, c, command) + } + return openSSHTerminal(sshCtx, c) +} + +// executeSSHCommand executes a command over SSH. +func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error { + var err error + if requestPTY { + err = c.ExecuteCommandWithPTY(ctx, command) + } else { + err = c.ExecuteCommandWithIO(ctx, command) + } + + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil + } + + var exitErr *ssh.ExitError + if errors.As(err, &exitErr) { + os.Exit(exitErr.ExitStatus()) + } + + var exitMissingErr *ssh.ExitMissingError + if errors.As(err, &exitMissingErr) { + log.Debugf("Remote command exited without exit status: %v", err) + return nil + } + + return fmt.Errorf("execute command: %w", err) + } + return nil +} + +// openSSHTerminal opens an interactive SSH terminal. +func openSSHTerminal(ctx context.Context, c *sshclient.Client) error { + if err := c.OpenTerminal(ctx); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil + } + + var exitMissingErr *ssh.ExitMissingError + if errors.As(err, &exitMissingErr) { + log.Debugf("Remote terminal exited without exit status: %v", err) + return nil + } + + return fmt.Errorf("open terminal: %w", err) + } + return nil +} + +// startPortForwarding starts local and remote port forwarding based on command line flags +func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error { + for _, forward := range localForwards { + if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil { + return fmt.Errorf("local port forward %s: %w", forward, err) + } + } + + for _, forward := range remoteForwards { + if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil { + return fmt.Errorf("remote port forward %s: %w", forward, err) + } + } + + return nil +} + +// parseAndStartLocalForward parses and starts a local port forward (-L) +func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error { + localAddr, remoteAddr, err := parsePortForwardSpec(forward) + if err != nil { + return err + } + + cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr) + + go func() { + if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) { + cmd.Printf("Local port forward error: %v\n", err) + } + }() + + return nil +} + +// parseAndStartRemoteForward parses and starts a remote port forward (-R) +func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error { + remoteAddr, localAddr, err := parsePortForwardSpec(forward) + if err != nil { + return err + } + + cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr) + + go func() { + if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) { + cmd.Printf("Remote port forward error: %v\n", err) + } + }() + + return nil +} + +// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80". +// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket". +func parsePortForwardSpec(spec string) (string, string, error) { + // Support formats: + // port:host:hostport -> localhost:port -> host:hostport + // host:port:host:hostport -> host:port -> host:hostport + // [host]:port:host:hostport -> [host]:port -> host:hostport + // port:unix_socket_path -> localhost:port -> unix_socket_path + // host:port:unix_socket_path -> host:port -> unix_socket_path + + if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") { + return parseIPv6ForwardSpec(spec) + } + + parts := strings.Split(spec, ":") + if len(parts) < 2 { + return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec) + } + + switch len(parts) { + case 2: + return parseTwoPartForwardSpec(parts, spec) + case 3: + return parseThreePartForwardSpec(parts) + case 4: + return parseFourPartForwardSpec(parts) + default: + return "", "", fmt.Errorf("invalid port forward specification: %s", spec) + } +} + +// parseTwoPartForwardSpec handles "port:unix_socket" format. +func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) { + if isUnixSocket(parts[1]) { + localAddr := "localhost:" + parts[0] + remoteAddr := parts[1] + return localAddr, remoteAddr, nil + } + return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec) +} + +// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats. +func parseThreePartForwardSpec(parts []string) (string, string, error) { + if isUnixSocket(parts[2]) { + localHost := normalizeLocalHost(parts[0]) + localAddr := localHost + ":" + parts[1] + remoteAddr := parts[2] + return localAddr, remoteAddr, nil + } + localAddr := "localhost:" + parts[0] + remoteAddr := parts[1] + ":" + parts[2] + return localAddr, remoteAddr, nil +} + +// parseFourPartForwardSpec handles "host:port:host:hostport" format. +func parseFourPartForwardSpec(parts []string) (string, string, error) { + localHost := normalizeLocalHost(parts[0]) + localAddr := localHost + ":" + parts[1] + remoteAddr := parts[2] + ":" + parts[3] + return localAddr, remoteAddr, nil +} + +// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format. +func parseIPv6ForwardSpec(spec string) (string, string, error) { + idx := strings.Index(spec, "]:") + if idx == -1 { + return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec) + } + + ipv6Host := spec[:idx+1] + remaining := spec[idx+2:] + + parts := strings.Split(remaining, ":") + if len(parts) != 3 { + return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec) + } + + localAddr := ipv6Host + ":" + parts[0] + remoteAddr := parts[1] + ":" + parts[2] + return localAddr, remoteAddr, nil +} + +// isUnixSocket checks if a path is a Unix socket path. +func isUnixSocket(path string) bool { + return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") +} + +// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces. +func normalizeLocalHost(host string) string { + if host == "*" { + return "0.0.0.0" + } + return host +} + +var sshProxyCmd = &cobra.Command{ + Use: "proxy ", + Short: "Internal SSH proxy for native SSH client integration", + Long: "Internal command used by SSH ProxyCommand to handle JWT authentication", + Hidden: true, + Args: cobra.ExactArgs(2), + RunE: sshProxyFn, +} + +func sshProxyFn(cmd *cobra.Command, args []string) error { + logOutput := "console" + if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile { + logOutput = firstLogFile + } + if err := util.InitLog(logLevel, logOutput); err != nil { + return fmt.Errorf("init log: %w", err) + } + + host := args[0] + portStr := args[1] + + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("invalid port: %s", portStr) + } + + proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr()) + if err != nil { + return fmt.Errorf("create SSH proxy: %w", err) + } + defer func() { + if err := proxy.Close(); err != nil { + log.Debugf("close SSH proxy: %v", err) + } + }() + + if err := proxy.Connect(cmd.Context()); err != nil { + return fmt.Errorf("SSH proxy: %w", err) + } + + return nil +} + +var sshDetectCmd = &cobra.Command{ + Use: "detect ", + Short: "Detect if a host is running NetBird SSH", + Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH", + Hidden: true, + Args: cobra.ExactArgs(2), + RunE: sshDetectFn, +} + +func sshDetectFn(cmd *cobra.Command, args []string) error { + if err := util.InitLog(logLevel, "console"); err != nil { + os.Exit(detection.ServerTypeRegular.ExitCode()) + } + + host := args[0] + portStr := args[1] + + port, err := strconv.Atoi(portStr) + if err != nil { + os.Exit(detection.ServerTypeRegular.ExitCode()) + } + + dialer := &net.Dialer{Timeout: detection.Timeout} + serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port) + if err != nil { + os.Exit(detection.ServerTypeRegular.ExitCode()) + } + + os.Exit(serverType.ExitCode()) + return nil } diff --git a/client/cmd/ssh_exec_unix.go b/client/cmd/ssh_exec_unix.go new file mode 100644 index 000000000..2412f072c --- /dev/null +++ b/client/cmd/ssh_exec_unix.go @@ -0,0 +1,74 @@ +//go:build unix + +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + sshserver "github.com/netbirdio/netbird/client/ssh/server" +) + +var ( + sshExecUID uint32 + sshExecGID uint32 + sshExecGroups []uint + sshExecWorkingDir string + sshExecShell string + sshExecCommand string + sshExecPTY bool +) + +// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping +var sshExecCmd = &cobra.Command{ + Use: "exec", + Short: "Internal SSH execution with privilege dropping (hidden)", + Hidden: true, + RunE: runSSHExec, +} + +func init() { + sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID") + sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID") + sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)") + sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory") + sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute") + sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)") + sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute") + + if err := sshExecCmd.MarkFlagRequired("uid"); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err) + os.Exit(1) + } + if err := sshExecCmd.MarkFlagRequired("gid"); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err) + os.Exit(1) + } + + sshCmd.AddCommand(sshExecCmd) +} + +// runSSHExec handles the SSH exec subcommand execution. +func runSSHExec(cmd *cobra.Command, _ []string) error { + privilegeDropper := sshserver.NewPrivilegeDropper() + + var groups []uint32 + for _, groupInt := range sshExecGroups { + groups = append(groups, uint32(groupInt)) + } + + config := sshserver.ExecutorConfig{ + UID: sshExecUID, + GID: sshExecGID, + Groups: groups, + WorkingDir: sshExecWorkingDir, + Shell: sshExecShell, + Command: sshExecCommand, + PTY: sshExecPTY, + } + + privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config) + return nil +} diff --git a/client/cmd/ssh_sftp_unix.go b/client/cmd/ssh_sftp_unix.go new file mode 100644 index 000000000..c06aab017 --- /dev/null +++ b/client/cmd/ssh_sftp_unix.go @@ -0,0 +1,94 @@ +//go:build unix + +package cmd + +import ( + "errors" + "io" + "os" + + "github.com/pkg/sftp" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + sshserver "github.com/netbirdio/netbird/client/ssh/server" +) + +var ( + sftpUID uint32 + sftpGID uint32 + sftpGroupsInt []uint + sftpWorkingDir string +) + +var sshSftpCmd = &cobra.Command{ + Use: "sftp", + Short: "SFTP server with privilege dropping (internal use)", + Hidden: true, + RunE: sftpMain, +} + +func init() { + sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID") + sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID") + sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)") + sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory") +} + +func sftpMain(cmd *cobra.Command, _ []string) error { + privilegeDropper := sshserver.NewPrivilegeDropper() + + var groups []uint32 + for _, groupInt := range sftpGroupsInt { + groups = append(groups, uint32(groupInt)) + } + + config := sshserver.ExecutorConfig{ + UID: sftpUID, + GID: sftpGID, + Groups: groups, + WorkingDir: sftpWorkingDir, + Shell: "", + Command: "", + } + + log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups) + + if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil { + cmd.PrintErrf("privilege drop failed: %v\n", err) + os.Exit(sshserver.ExitCodePrivilegeDropFail) + } + + if config.WorkingDir != "" { + if err := os.Chdir(config.WorkingDir); err != nil { + cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err) + } + } + + sftpServer, err := sftp.NewServer(struct { + io.Reader + io.WriteCloser + }{ + Reader: os.Stdin, + WriteCloser: os.Stdout, + }) + if err != nil { + cmd.PrintErrf("SFTP server creation failed: %v\n", err) + os.Exit(sshserver.ExitCodeShellExecFail) + } + + log.Tracef("starting SFTP server with dropped privileges") + if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) { + cmd.PrintErrf("SFTP server error: %v\n", err) + if closeErr := sftpServer.Close(); closeErr != nil { + cmd.PrintErrf("SFTP server close error: %v\n", closeErr) + } + os.Exit(sshserver.ExitCodeShellExecFail) + } + + if closeErr := sftpServer.Close(); closeErr != nil { + cmd.PrintErrf("SFTP server close error: %v\n", closeErr) + } + os.Exit(sshserver.ExitCodeSuccess) + return nil +} diff --git a/client/cmd/ssh_sftp_windows.go b/client/cmd/ssh_sftp_windows.go new file mode 100644 index 000000000..ffd2d1148 --- /dev/null +++ b/client/cmd/ssh_sftp_windows.go @@ -0,0 +1,94 @@ +//go:build windows + +package cmd + +import ( + "errors" + "fmt" + "io" + "os" + "os/user" + "strings" + + "github.com/pkg/sftp" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + sshserver "github.com/netbirdio/netbird/client/ssh/server" +) + +var ( + sftpWorkingDir string + windowsUsername string + windowsDomain string +) + +var sshSftpCmd = &cobra.Command{ + Use: "sftp", + Short: "SFTP server with user switching for Windows (internal use)", + Hidden: true, + RunE: sftpMain, +} + +func init() { + sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory") + sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching") + sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching") +} + +func sftpMain(cmd *cobra.Command, _ []string) error { + return sftpMainDirect(cmd) +} + +func sftpMainDirect(cmd *cobra.Command) error { + currentUser, err := user.Current() + if err != nil { + cmd.PrintErrf("failed to get current user: %v\n", err) + os.Exit(sshserver.ExitCodeValidationFail) + } + + if windowsUsername != "" { + expectedUsername := windowsUsername + if windowsDomain != "" { + expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername) + } + if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) { + cmd.PrintErrf("user switching failed\n") + os.Exit(sshserver.ExitCodeValidationFail) + } + } + + log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name) + + if sftpWorkingDir != "" { + if err := os.Chdir(sftpWorkingDir); err != nil { + cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err) + } + } + + sftpServer, err := sftp.NewServer(struct { + io.Reader + io.WriteCloser + }{ + Reader: os.Stdin, + WriteCloser: os.Stdout, + }) + if err != nil { + cmd.PrintErrf("SFTP server creation failed: %v\n", err) + os.Exit(sshserver.ExitCodeShellExecFail) + } + + log.Debugf("starting SFTP server") + exitCode := sshserver.ExitCodeSuccess + if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) { + cmd.PrintErrf("SFTP server error: %v\n", err) + exitCode = sshserver.ExitCodeShellExecFail + } + + if err := sftpServer.Close(); err != nil { + log.Debugf("SFTP server close error: %v", err) + } + + os.Exit(exitCode) + return nil +} diff --git a/client/cmd/ssh_test.go b/client/cmd/ssh_test.go new file mode 100644 index 000000000..43291fa87 --- /dev/null +++ b/client/cmd/ssh_test.go @@ -0,0 +1,717 @@ +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSSHCommand_FlagParsing(t *testing.T) { + tests := []struct { + name string + args []string + expectedHost string + expectedUser string + expectedPort int + expectedCmd string + expectError bool + }{ + { + name: "basic host", + args: []string{"hostname"}, + expectedHost: "hostname", + expectedUser: "", + expectedPort: 22, + expectedCmd: "", + }, + { + name: "user@host format", + args: []string{"user@hostname"}, + expectedHost: "hostname", + expectedUser: "user", + expectedPort: 22, + expectedCmd: "", + }, + { + name: "host with command", + args: []string{"hostname", "echo", "hello"}, + expectedHost: "hostname", + expectedUser: "", + expectedPort: 22, + expectedCmd: "echo hello", + }, + { + name: "command with flags should be preserved", + args: []string{"hostname", "ls", "-la", "/tmp"}, + expectedHost: "hostname", + expectedUser: "", + expectedPort: 22, + expectedCmd: "ls -la /tmp", + }, + { + name: "double dash separator", + args: []string{"hostname", "--", "ls", "-la"}, + expectedHost: "hostname", + expectedUser: "", + expectedPort: 22, + expectedCmd: "-- ls -la", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + // Mock command for testing + cmd := sshCmd + cmd.SetArgs(tt.args) + + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err, "SSH args validation should succeed for valid input") + assert.Equal(t, tt.expectedHost, host, "host mismatch") + if tt.expectedUser != "" { + assert.Equal(t, tt.expectedUser, username, "username mismatch") + } + assert.Equal(t, tt.expectedPort, port, "port mismatch") + assert.Equal(t, tt.expectedCmd, command, "command mismatch") + }) + } +} + +func TestSSHCommand_FlagConflictPrevention(t *testing.T) { + // Test that SSH flags don't conflict with command flags + tests := []struct { + name string + args []string + expectedCmd string + description string + }{ + { + name: "ls with -la flags", + args: []string{"hostname", "ls", "-la"}, + expectedCmd: "ls -la", + description: "ls flags should be passed to remote command", + }, + { + name: "grep with -r flag", + args: []string{"hostname", "grep", "-r", "pattern", "/path"}, + expectedCmd: "grep -r pattern /path", + description: "grep flags should be passed to remote command", + }, + { + name: "ps with aux flags", + args: []string{"hostname", "ps", "aux"}, + expectedCmd: "ps aux", + description: "ps flags should be passed to remote command", + }, + { + name: "command with double dash", + args: []string{"hostname", "--", "ls", "-la"}, + expectedCmd: "-- ls -la", + description: "double dash should be preserved in command", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + require.NoError(t, err, "SSH args validation should succeed for valid input") + + assert.Equal(t, tt.expectedCmd, command, tt.description) + }) + } +} + +func TestSSHCommand_NonInteractiveExecution(t *testing.T) { + // Test that commands with arguments should execute the command and exit, + // not drop to an interactive shell + tests := []struct { + name string + args []string + expectedCmd string + shouldExit bool + description string + }{ + { + name: "ls command should execute and exit", + args: []string{"hostname", "ls"}, + expectedCmd: "ls", + shouldExit: true, + description: "ls command should execute and exit, not drop to shell", + }, + { + name: "ls with flags should execute and exit", + args: []string{"hostname", "ls", "-la"}, + expectedCmd: "ls -la", + shouldExit: true, + description: "ls with flags should execute and exit, not drop to shell", + }, + { + name: "pwd command should execute and exit", + args: []string{"hostname", "pwd"}, + expectedCmd: "pwd", + shouldExit: true, + description: "pwd command should execute and exit, not drop to shell", + }, + { + name: "echo command should execute and exit", + args: []string{"hostname", "echo", "hello"}, + expectedCmd: "echo hello", + shouldExit: true, + description: "echo command should execute and exit, not drop to shell", + }, + { + name: "no command should open shell", + args: []string{"hostname"}, + expectedCmd: "", + shouldExit: false, + description: "no command should open interactive shell", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + require.NoError(t, err, "SSH args validation should succeed for valid input") + + assert.Equal(t, tt.expectedCmd, command, tt.description) + + // When command is present, it should execute the command and exit + // When command is empty, it should open interactive shell + hasCommand := command != "" + assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior") + }) + } +} + +func TestSSHCommand_FlagHandling(t *testing.T) { + // Test that flags after hostname are not parsed by netbird but passed to SSH command + tests := []struct { + name string + args []string + expectedHost string + expectedCmd string + expectError bool + description string + }{ + { + name: "ls with -la flag should not be parsed by netbird", + args: []string{"debian2", "ls", "-la"}, + expectedHost: "debian2", + expectedCmd: "ls -la", + expectError: false, + description: "ls -la should be passed as SSH command, not parsed as netbird flags", + }, + { + name: "command with netbird-like flags should be passed through", + args: []string{"hostname", "echo", "--help"}, + expectedHost: "hostname", + expectedCmd: "echo --help", + expectError: false, + description: "--help should be passed to echo, not parsed by netbird", + }, + { + name: "command with -p flag should not conflict with SSH port flag", + args: []string{"hostname", "ps", "-p", "1234"}, + expectedHost: "hostname", + expectedCmd: "ps -p 1234", + expectError: false, + description: "ps -p should be passed to ps command, not parsed as port", + }, + { + name: "tar with flags should be passed through", + args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"}, + expectedHost: "hostname", + expectedCmd: "tar -czf backup.tar.gz /home", + expectError: false, + description: "tar flags should be passed to tar command", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err, "SSH args validation should succeed for valid input") + assert.Equal(t, tt.expectedHost, host, "host mismatch") + assert.Equal(t, tt.expectedCmd, command, tt.description) + }) + } +} + +func TestSSHCommand_RegressionFlagParsing(t *testing.T) { + // Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la" + // should not parse -la as netbird flags but pass them to the SSH command + tests := []struct { + name string + args []string + expectedHost string + expectedCmd string + expectError bool + description string + }{ + { + name: "original issue: ls -la should be preserved", + args: []string{"debian2", "ls", "-la"}, + expectedHost: "debian2", + expectedCmd: "ls -la", + expectError: false, + description: "The original failing case should now work", + }, + { + name: "ls -l should be preserved", + args: []string{"hostname", "ls", "-l"}, + expectedHost: "hostname", + expectedCmd: "ls -l", + expectError: false, + description: "Single letter flags should be preserved", + }, + { + name: "SSH port flag should work", + args: []string{"-p", "2222", "hostname", "ls", "-la"}, + expectedHost: "hostname", + expectedCmd: "ls -la", + expectError: false, + description: "SSH -p flag should be parsed, command flags preserved", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err, "SSH args validation should succeed for valid input") + assert.Equal(t, tt.expectedHost, host, "host mismatch") + assert.Equal(t, tt.expectedCmd, command, tt.description) + + // Check port for the test case with -p flag + if len(tt.args) > 0 && tt.args[0] == "-p" { + assert.Equal(t, 2222, port, "port should be parsed from -p flag") + } + }) + } +} + +func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) { + tests := []struct { + name string + args []string + expectedHost string + expectedLocal []string + expectedRemote []string + expectError bool + description string + }{ + { + name: "local port forwarding -L", + args: []string{"-L", "8080:localhost:80", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{}, + expectError: false, + description: "Single -L flag should be parsed correctly", + }, + { + name: "remote port forwarding -R", + args: []string{"-R", "8080:localhost:80", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{}, + expectedRemote: []string{"8080:localhost:80"}, + expectError: false, + description: "Single -R flag should be parsed correctly", + }, + { + name: "multiple local port forwards", + args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"}, + expectedRemote: []string{}, + expectError: false, + description: "Multiple -L flags should be parsed correctly", + }, + { + name: "multiple remote port forwards", + args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{}, + expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"}, + expectError: false, + description: "Multiple -R flags should be parsed correctly", + }, + { + name: "mixed local and remote forwards", + args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{"9090:localhost:443"}, + expectError: false, + description: "Mixed -L and -R flags should be parsed correctly", + }, + { + name: "port forwarding with bind address", + args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"}, + expectedHost: "hostname", + expectedLocal: []string{"127.0.0.1:8080:localhost:80"}, + expectedRemote: []string{}, + expectError: false, + description: "Port forwarding with bind address should work", + }, + { + name: "port forwarding with command", + args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{}, + expectError: false, + description: "Port forwarding with command should work", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + localForwards = nil + remoteForwards = nil + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err, "SSH args validation should succeed for valid input") + assert.Equal(t, tt.expectedHost, host, "host mismatch") + // Handle nil vs empty slice comparison + if len(tt.expectedLocal) == 0 { + assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty") + } else { + assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards") + } + if len(tt.expectedRemote) == 0 { + assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty") + } else { + assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards") + } + }) + } +} + +func TestParsePortForward(t *testing.T) { + tests := []struct { + name string + spec string + expectedLocal string + expectedRemote string + expectError bool + description string + }{ + { + name: "simple port forward", + spec: "8080:localhost:80", + expectedLocal: "localhost:8080", + expectedRemote: "localhost:80", + expectError: false, + description: "Simple port:host:port format should work", + }, + { + name: "port forward with bind address", + spec: "127.0.0.1:8080:localhost:80", + expectedLocal: "127.0.0.1:8080", + expectedRemote: "localhost:80", + expectError: false, + description: "bind_address:port:host:port format should work", + }, + { + name: "port forward to different host", + spec: "8080:example.com:443", + expectedLocal: "localhost:8080", + expectedRemote: "example.com:443", + expectError: false, + description: "Forwarding to different host should work", + }, + { + name: "port forward with IPv6 (needs bracket support)", + spec: "::1:8080:localhost:80", + expectError: true, + description: "IPv6 without brackets fails as expected (feature to implement)", + }, + { + name: "invalid format - too few parts", + spec: "8080:localhost", + expectError: true, + description: "Invalid format with too few parts should fail", + }, + { + name: "invalid format - too many parts", + spec: "127.0.0.1:8080:localhost:80:extra", + expectError: true, + description: "Invalid format with too many parts should fail", + }, + { + name: "empty spec", + spec: "", + expectError: true, + description: "Empty spec should fail", + }, + { + name: "unix socket local forward", + spec: "8080:/tmp/socket", + expectedLocal: "localhost:8080", + expectedRemote: "/tmp/socket", + expectError: false, + description: "Unix socket forwarding should work", + }, + { + name: "unix socket with bind address", + spec: "127.0.0.1:8080:/tmp/socket", + expectedLocal: "127.0.0.1:8080", + expectedRemote: "/tmp/socket", + expectError: false, + description: "Unix socket with bind address should work", + }, + { + name: "wildcard bind all interfaces", + spec: "*:8080:localhost:80", + expectedLocal: "0.0.0.0:8080", + expectedRemote: "localhost:80", + expectError: false, + description: "Wildcard * should bind to all interfaces (0.0.0.0)", + }, + { + name: "wildcard for port only", + spec: "8080:*:80", + expectedLocal: "localhost:8080", + expectedRemote: "*:80", + expectError: false, + description: "Wildcard in remote host should be preserved", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec) + + if tt.expectError { + assert.Error(t, err, tt.description) + return + } + + require.NoError(t, err, tt.description) + assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address") + assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address") + }) + } +} + +func TestSSHCommand_IntegrationPortForwarding(t *testing.T) { + // Integration test for port forwarding with the actual SSH command implementation + tests := []struct { + name string + args []string + expectedHost string + expectedLocal []string + expectedRemote []string + expectedCmd string + description string + }{ + { + name: "local forward with command", + args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{}, + expectedCmd: "echo test", + description: "Local forwarding should work with commands", + }, + { + name: "remote forward with command", + args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"}, + expectedHost: "hostname", + expectedLocal: []string{}, + expectedRemote: []string{"8080:localhost:80"}, + expectedCmd: "ls -la", + description: "Remote forwarding should work with commands", + }, + { + name: "multiple forwards with user and command", + args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"}, + expectedHost: "hostname", + expectedLocal: []string{"8080:localhost:80"}, + expectedRemote: []string{"9090:localhost:443"}, + expectedCmd: "ps aux", + description: "Complex case with multiple forwards, user, and command", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + localForwards = nil + remoteForwards = nil + + cmd := sshCmd + err := validateSSHArgsWithoutFlagParsing(cmd, tt.args) + require.NoError(t, err, "SSH args validation should succeed for valid input") + + assert.Equal(t, tt.expectedHost, host, "host mismatch") + // Handle nil vs empty slice comparison + if len(tt.expectedLocal) == 0 { + assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty") + } else { + assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards") + } + if len(tt.expectedRemote) == 0 { + assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty") + } else { + assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards") + } + assert.Equal(t, tt.expectedCmd, command, tt.description+" - command") + }) + } +} + +func TestSSHCommand_ParameterIsolation(t *testing.T) { + tests := []struct { + name string + args []string + expectedCmd string + }{ + { + name: "cmd flag passed as command", + args: []string{"hostname", "--cmd", "echo test"}, + expectedCmd: "--cmd echo test", + }, + { + name: "uid flag passed as command", + args: []string{"hostname", "--uid", "1000"}, + expectedCmd: "--uid 1000", + }, + { + name: "shell flag passed as command", + args: []string{"hostname", "--shell", "/bin/bash"}, + expectedCmd: "--shell /bin/bash", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + host = "" + username = "" + port = 22 + command = "" + + err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args) + require.NoError(t, err) + + assert.Equal(t, "hostname", host) + assert.Equal(t, tt.expectedCmd, command) + }) + } +} + +func TestSSHCommand_InvalidFlagRejection(t *testing.T) { + // Test that invalid flags are properly rejected and not misinterpreted as hostnames + tests := []struct { + name string + args []string + description string + }{ + { + name: "invalid long flag before hostname", + args: []string{"--invalid-flag", "hostname"}, + description: "Invalid flag should return parse error, not treat flag as hostname", + }, + { + name: "invalid short flag before hostname", + args: []string{"-x", "hostname"}, + description: "Invalid short flag should return parse error", + }, + { + name: "invalid flag with value before hostname", + args: []string{"--invalid-option=value", "hostname"}, + description: "Invalid flag with value should return parse error", + }, + { + name: "typo in known flag", + args: []string{"--por", "2222", "hostname"}, + description: "Typo in flag name should return parse error (not silently ignored)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset global variables + host = "" + username = "" + port = 22 + command = "" + + err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args) + + // Should return an error for invalid flags + assert.Error(t, err, tt.description) + + // Should not have set host to the invalid flag + assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname") + }) + } +} diff --git a/client/cmd/status.go b/client/cmd/status.go index 6e57ceb89..06460a6a7 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { case yamlFlag: statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder) default: - statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false) + statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false) } if err != nil { diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 78bb0476b..e7b0279e8 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -117,7 +118,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) - accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index 80175f7be..140ba2cb2 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -355,6 +355,25 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro if cmd.Flag(serverSSHAllowedFlag).Changed { req.ServerSSHAllowed = &serverSSHAllowed } + if cmd.Flag(enableSSHRootFlag).Changed { + req.EnableSSHRoot = &enableSSHRoot + } + if cmd.Flag(enableSSHSFTPFlag).Changed { + req.EnableSSHSFTP = &enableSSHSFTP + } + if cmd.Flag(enableSSHLocalPortForwardFlag).Changed { + req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward + } + if cmd.Flag(enableSSHRemotePortForwardFlag).Changed { + req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward + } + if cmd.Flag(disableSSHAuthFlag).Changed { + req.DisableSSHAuth = &disableSSHAuth + } + if cmd.Flag(sshJWTCacheTTLFlag).Changed { + sshJWTCacheTTL32 := int32(sshJWTCacheTTL) + req.SshJWTCacheTTL = &sshJWTCacheTTL32 + } if cmd.Flag(interfaceNameFlag).Changed { if err := parseInterfaceName(interfaceName); err != nil { log.Errorf("parse interface name: %v", err) @@ -439,6 +458,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil ic.ServerSSHAllowed = &serverSSHAllowed } + if cmd.Flag(enableSSHRootFlag).Changed { + ic.EnableSSHRoot = &enableSSHRoot + } + + if cmd.Flag(enableSSHSFTPFlag).Changed { + ic.EnableSSHSFTP = &enableSSHSFTP + } + + if cmd.Flag(enableSSHLocalPortForwardFlag).Changed { + ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward + } + + if cmd.Flag(enableSSHRemotePortForwardFlag).Changed { + ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward + } + + if cmd.Flag(disableSSHAuthFlag).Changed { + ic.DisableSSHAuth = &disableSSHAuth + } + + if cmd.Flag(sshJWTCacheTTLFlag).Changed { + ic.SSHJWTCacheTTL = &sshJWTCacheTTL + } + if cmd.Flag(interfaceNameFlag).Changed { if err := parseInterfaceName(interfaceName); err != nil { return nil, err @@ -539,6 +582,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte loginRequest.ServerSSHAllowed = &serverSSHAllowed } + if cmd.Flag(enableSSHRootFlag).Changed { + loginRequest.EnableSSHRoot = &enableSSHRoot + } + + if cmd.Flag(enableSSHSFTPFlag).Changed { + loginRequest.EnableSSHSFTP = &enableSSHSFTP + } + + if cmd.Flag(enableSSHLocalPortForwardFlag).Changed { + loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward + } + + if cmd.Flag(enableSSHRemotePortForwardFlag).Changed { + loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward + } + + if cmd.Flag(disableSSHAuthFlag).Changed { + loginRequest.DisableSSHAuth = &disableSSHAuth + } + + if cmd.Flag(sshJWTCacheTTLFlag).Changed { + sshJWTCacheTTL32 := int32(sshJWTCacheTTL) + loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32 + } + if cmd.Flag(disableAutoConnectFlag).Changed { loginRequest.DisableAutoConnect = &autoConnectDisabled } diff --git a/client/embed/embed.go b/client/embed/embed.go index e918235ed..3090ca6a2 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -18,12 +18,16 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + sshcommon "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" ) -var ErrClientAlreadyStarted = errors.New("client already started") -var ErrClientNotStarted = errors.New("client not started") -var ErrConfigNotInitialized = errors.New("config not initialized") +var ( + ErrClientAlreadyStarted = errors.New("client already started") + ErrClientNotStarted = errors.New("client not started") + ErrEngineNotStarted = errors.New("engine not started") + ErrConfigNotInitialized = errors.New("config not initialized") +) // Client manages a netbird embedded client instance. type Client struct { @@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) { // Dial dials a network address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { - c.mu.Lock() - connect := c.connect - if connect == nil { - c.mu.Unlock() - return nil, ErrClientNotStarted - } - c.mu.Unlock() - - engine := connect.Engine() - if engine == nil { - return nil, errors.New("engine not started") + engine, err := c.getEngine() + if err != nil { + return nil, err } nsnet, err := engine.GetNet() @@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e return nsnet.DialContext(ctx, network, address) } +// DialContext dials a network address in the netbird network with context +func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return c.Dial(ctx, network, address) +} + // ListenTCP listens on the given address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) ListenTCP(address string) (net.Listener, error) { @@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client { } } -func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) { +// VerifySSHHostKey verifies an SSH host key against stored peer keys. +// Returns nil if the key matches, ErrPeerNotFound if peer is not in network, +// ErrNoStoredKey if peer has no stored key, or an error for verification failures. +func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error { + engine, err := c.getEngine() + if err != nil { + return err + } + + storedKey, found := engine.GetPeerSSHKey(peerAddress) + if !found { + return sshcommon.ErrPeerNotFound + } + + return sshcommon.VerifyHostKey(storedKey, key, peerAddress) +} + +// getEngine safely retrieves the engine from the client with proper locking. +// Returns ErrClientNotStarted if the client is not started. +// Returns ErrEngineNotStarted if the engine is not available. +func (c *Client) getEngine() (*internal.Engine, error) { c.mu.Lock() connect := c.connect - if connect == nil { - c.mu.Unlock() - return nil, netip.Addr{}, errors.New("client not started") - } c.mu.Unlock() + if connect == nil { + return nil, ErrClientNotStarted + } + engine := connect.Engine() if engine == nil { - return nil, netip.Addr{}, errors.New("engine not started") + return nil, ErrEngineNotStarted + } + + return engine, nil +} + +func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) { + engine, err := c.getEngine() + if err != nil { + return nil, netip.Addr{}, err } addr, err := engine.Address() diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 990630ee4..4e22bde3f 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -35,6 +35,12 @@ const ( ipTCPHeaderMinSize = 40 ) +// serviceKey represents a protocol/port combination for netstack service registry +type serviceKey struct { + protocol gopacket.LayerType + port uint16 +} + const ( // EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed. EnvDisableConntrack = "NB_DISABLE_CONNTRACK" @@ -59,12 +65,6 @@ const ( var errNatNotSupported = errors.New("nat not supported with userspace firewall") -// serviceKey represents a protocol/port combination for netstack service registry -type serviceKey struct { - protocol gopacket.LayerType - port uint16 -} - // RuleSet is a set of rules grouped by a string key type RuleSet map[string]PeerRule diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index c56a078fc..120a9f418 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/netflow" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -1114,3 +1115,138 @@ func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dst return buf.Bytes() } + +func TestShouldForward(t *testing.T) { + // Set up test addresses + wgIP := netip.MustParseAddr("100.10.0.1") + otherIP := netip.MustParseAddr("100.10.0.2") + + // Create test manager with mock interface + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + } + // Set the mock to return our test WG IP + ifaceMock.AddressFunc = func() wgaddr.Address { + return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)} + } + + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + // Helper to create decoder with TCP packet + createTCPDecoder := func(dstPort uint16) *decoder { + ipv4 := &layers.IPv4{ + Version: 4, + Protocol: layers.IPProtocolTCP, + SrcIP: net.ParseIP("192.168.1.100"), + DstIP: wgIP.AsSlice(), + } + tcp := &layers.TCP{ + SrcPort: 54321, + DstPort: layers.TCPPort(dstPort), + } + + err := tcp.SetNetworkLayerForChecksum(ipv4) + require.NoError(t, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")) + require.NoError(t, err) + + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + + err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded) + require.NoError(t, err) + + return d + } + + tests := []struct { + name string + localForwarding bool + netstack bool + dstIP netip.Addr + serviceRegistered bool + servicePort uint16 + expected bool + description string + }{ + { + name: "no local forwarding", + localForwarding: false, + netstack: true, + dstIP: wgIP, + expected: false, + description: "should never forward when local forwarding disabled", + }, + { + name: "traffic to other local interface", + localForwarding: true, + netstack: false, + dstIP: otherIP, + expected: true, + description: "should forward traffic to our other local interfaces (not NetBird IP)", + }, + { + name: "traffic to NetBird IP, no netstack", + localForwarding: true, + netstack: false, + dstIP: wgIP, + expected: false, + description: "should send to netstack listeners (final return false path)", + }, + { + name: "traffic to our IP, netstack mode, no service", + localForwarding: true, + netstack: true, + dstIP: wgIP, + expected: true, + description: "should forward when in netstack mode with no matching service", + }, + { + name: "traffic to our IP, netstack mode, with service", + localForwarding: true, + netstack: true, + dstIP: wgIP, + serviceRegistered: true, + servicePort: 22, + expected: false, + description: "should send to netstack listeners when service is registered", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Configure manager + manager.localForwarding = tt.localForwarding + manager.netstack = tt.netstack + + // Register service if needed + if tt.serviceRegistered { + manager.RegisterNetstackService(nftypes.TCP, tt.servicePort) + defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort) + } + + // Create decoder for the test + decoder := createTCPDecoder(tt.servicePort) + if !tt.serviceRegistered { + decoder = createTCPDecoder(8080) // Use non-registered port + } + + // Test the method + result := manager.shouldForward(decoder, tt.dstIP) + require.Equal(t, tt.expected, result, tt.description) + }) + } +} diff --git a/client/firewall/uspfilter/nat_stateful_test.go b/client/firewall/uspfilter/nat_stateful_test.go new file mode 100644 index 000000000..21c6da06e --- /dev/null +++ b/client/firewall/uspfilter/nat_stateful_test.go @@ -0,0 +1,85 @@ +package uspfilter + +import ( + "net/netip" + "testing" + + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +// TestPortDNATBasic tests basic port DNAT functionality +func TestPortDNATBasic(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + // Define peer IPs + peerA := netip.MustParseAddr("100.10.0.50") + peerB := netip.MustParseAddr("100.10.0.51") + + // Add SSH port redirection rule for peer B (the target) + err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022) + require.NoError(t, err) + + // Scenario: Peer A connects to Peer B on port 22 (should get NAT) + packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22) + d := parsePacket(t, packetAtoB) + translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB) + require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)") + + // Verify port was translated to 22022 + d = parsePacket(t, packetAtoB) + require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022") + + // Scenario: Return traffic from Peer B to Peer A should NOT be translated + // (prevents double NAT - original port stored in conntrack) + returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321) + d2 := parsePacket(t, returnPacket) + translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA) + require.False(t, translatedReturn, "Return traffic from same IP should not be translated") +} + +// TestPortDNATMultipleRules tests multiple port DNAT rules +func TestPortDNATMultipleRules(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + // Define peer IPs + peerA := netip.MustParseAddr("100.10.0.50") + peerB := netip.MustParseAddr("100.10.0.51") + + // Add SSH port redirection rules for both peers + err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022) + require.NoError(t, err) + err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022) + require.NoError(t, err) + + // Test traffic to peer B gets translated + packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22) + d1 := parsePacket(t, packetToB) + translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB) + require.True(t, translatedToB, "Traffic to peer B should be translated") + d1 = parsePacket(t, packetToB) + require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022") + + // Test traffic to peer A gets translated + packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22) + d2 := parsePacket(t, packetToA) + translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA) + require.True(t, translatedToA, "Traffic to peer A should be translated") + d2 = parsePacket(t, packetToA) + require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022") +} diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 965decc73..dd6f9479a 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -17,7 +17,6 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" - "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" ) @@ -83,22 +82,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { rules := networkMap.FirewallRules - enableSSH := networkMap.PeerConfig != nil && - networkMap.PeerConfig.SshConfig != nil && - networkMap.PeerConfig.SshConfig.SshEnabled - - // If SSH enabled, add default firewall rule which accepts connection to any peer - // in the network by SSH (TCP port defined by ssh.DefaultSSHPort). - if enableSSH { - rules = append(rules, &mgmProto.FirewallRule{ - PeerIP: "0.0.0.0", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: strconv.Itoa(ssh.DefaultSSHPort), - }) - } - // if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag // we have old version of management without rules handling, we should allow all traffic if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty { diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 638245bf7..4bc0fd800 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -272,70 +272,3 @@ func TestPortInfoEmpty(t *testing.T) { }) } } - -func TestDefaultManagerEnableSSHRules(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - PeerConfig: &mgmProto.PeerConfig{ - SshConfig: &mgmProto.SSHConfig{ - SshEnabled: true, - }, - }, - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - } - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - ifaceMock := mocks.NewMockIFaceMapper(ctrl) - ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() - ifaceMock.EXPECT().SetFilter(gomock.Any()) - network := netip.MustParsePrefix("172.0.0.1/32") - - ifaceMock.EXPECT().Name().Return("lo").AnyTimes() - ifaceMock.EXPECT().Address().Return(wgaddr.Address{ - IP: network.Addr(), - Network: network, - }).AnyTimes() - ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) - require.NoError(t, err) - defer func() { - err = fw.Close(nil) - require.NoError(t, err) - }() - - acl := NewDefaultManager(fw) - - acl.ApplyFiltering(networkMap, false) - - expectedRules := 3 - if fw.IsStateful() { - expectedRules = 3 // 2 inbound rules + SSH rule - } - assert.Equal(t, expectedRules, len(acl.peerRulesPairs)) -} diff --git a/client/internal/connect.go b/client/internal/connect.go index bb7c2b38b..6ad5f264b 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -416,20 +416,25 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf nm = *config.NetworkMonitor } engineConf := &EngineConfig{ - WgIfaceName: config.WgIface, - WgAddr: peerConfig.Address, - IFaceBlackList: config.IFaceBlackList, - DisableIPv6Discovery: config.DisableIPv6Discovery, - WgPrivateKey: key, - WgPort: config.WgPort, - NetworkMonitor: nm, - SSHKey: []byte(config.SSHKey), - NATExternalIPs: config.NATExternalIPs, - CustomDNSAddress: config.CustomDNSAddress, - RosenpassEnabled: config.RosenpassEnabled, - RosenpassPermissive: config.RosenpassPermissive, - ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), - DNSRouteInterval: config.DNSRouteInterval, + WgIfaceName: config.WgIface, + WgAddr: peerConfig.Address, + IFaceBlackList: config.IFaceBlackList, + DisableIPv6Discovery: config.DisableIPv6Discovery, + WgPrivateKey: key, + WgPort: config.WgPort, + NetworkMonitor: nm, + SSHKey: []byte(config.SSHKey), + NATExternalIPs: config.NATExternalIPs, + CustomDNSAddress: config.CustomDNSAddress, + RosenpassEnabled: config.RosenpassEnabled, + RosenpassPermissive: config.RosenpassPermissive, + ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), + EnableSSHRoot: config.EnableSSHRoot, + EnableSSHSFTP: config.EnableSSHSFTP, + EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding, + EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding, + DisableSSHAuth: config.DisableSSHAuth, + DNSRouteInterval: config.DNSRouteInterval, DisableClientRoutes: config.DisableClientRoutes, DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound, @@ -515,6 +520,11 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config.BlockLANAccess, config.BlockInbound, config.LazyConnectionEnabled, + config.EnableSSHRoot, + config.EnableSSHSFTP, + config.EnableSSHLocalPortForwarding, + config.EnableSSHRemotePortForwarding, + config.DisableSSHAuth, ) loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels) if err != nil { diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index fbec29ce3..58977b884 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -453,6 +453,18 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) if g.internalConfig.ServerSSHAllowed != nil { configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed)) } + if g.internalConfig.EnableSSHRoot != nil { + configContent.WriteString(fmt.Sprintf("EnableSSHRoot: %v\n", *g.internalConfig.EnableSSHRoot)) + } + if g.internalConfig.EnableSSHSFTP != nil { + configContent.WriteString(fmt.Sprintf("EnableSSHSFTP: %v\n", *g.internalConfig.EnableSSHSFTP)) + } + if g.internalConfig.EnableSSHLocalPortForwarding != nil { + configContent.WriteString(fmt.Sprintf("EnableSSHLocalPortForwarding: %v\n", *g.internalConfig.EnableSSHLocalPortForwarding)) + } + if g.internalConfig.EnableSSHRemotePortForwarding != nil { + configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding)) + } configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes)) diff --git a/client/internal/engine.go b/client/internal/engine.go index ebc05c453..1deb3d3cf 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -9,7 +9,6 @@ import ( "net/netip" "net/url" "os" - "reflect" "runtime" "slices" "sort" @@ -30,7 +29,6 @@ import ( firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" - nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" @@ -51,10 +49,10 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" cProto "github.com/netbirdio/netbird/client/proto" + sshconfig "github.com/netbirdio/netbird/client/ssh/config" "github.com/netbirdio/netbird/shared/management/domain" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" - nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" @@ -115,7 +113,12 @@ type EngineConfig struct { RosenpassEnabled bool RosenpassPermissive bool - ServerSSHAllowed bool + ServerSSHAllowed bool + EnableSSHRoot *bool + EnableSSHSFTP *bool + EnableSSHLocalPortForwarding *bool + EnableSSHRemotePortForwarding *bool + DisableSSHAuth *bool DNSRouteInterval time.Duration @@ -148,8 +151,6 @@ type Engine struct { // syncMsgMux is used to guarantee sequential Management Service message processing syncMsgMux *sync.Mutex - // sshMux protects sshServer field access - sshMux sync.Mutex config *EngineConfig mobileDep MobileDependency @@ -175,8 +176,7 @@ type Engine struct { networkMonitor *networkmonitor.NetworkMonitor - sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error) - sshServer nbssh.Server + sshServer sshServer statusRecorder *peer.Status @@ -246,7 +246,6 @@ func NewEngine( STUNs: []*stun.URI{}, TURNs: []*stun.URI{}, networkSerial: 0, - sshServerFunc: nbssh.DefaultSSHServer, statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), @@ -268,6 +267,7 @@ func NewEngine( path = mobileDep.StateFilePath } engine.stateManager = statemanager.New(path) + engine.stateManager.RegisterState(&sshconfig.ShutdownState{}) log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) return engine @@ -292,6 +292,12 @@ func (e *Engine) Stop() error { } log.Info("Network monitor: stopped") + if err := e.stopSSHServer(); err != nil { + log.Warnf("failed to stop SSH server: %v", err) + } + + e.cleanupSSHConfig() + // stop/restore DNS first so dbus and friends don't complain because of a missing interface e.stopDNSServer() @@ -703,16 +709,10 @@ func (e *Engine) removeAllPeers() error { return nil } -// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server +// removePeer closes an existing peer connection and removes a peer func (e *Engine) removePeer(peerKey string) error { log.Debugf("removing peer from engine %s", peerKey) - e.sshMux.Lock() - if !isNil(e.sshServer) { - e.sshServer.RemoveAuthorizedKey(peerKey) - } - e.sshMux.Unlock() - e.connMgr.RemovePeerConn(peerKey) err := e.statusRecorder.RemovePeer(peerKey) @@ -884,6 +884,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error { e.config.BlockLANAccess, e.config.BlockInbound, e.config.LazyConnectionEnabled, + e.config.EnableSSHRoot, + e.config.EnableSSHSFTP, + e.config.EnableSSHLocalPortForwarding, + e.config.EnableSSHRemotePortForwarding, + e.config.DisableSSHAuth, ) if err := e.mgmClient.SyncMeta(info); err != nil { @@ -893,74 +898,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error { return nil } -func isNil(server nbssh.Server) bool { - return server == nil || reflect.ValueOf(server).IsNil() -} - -func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { - if e.config.BlockInbound { - log.Infof("SSH server is disabled because inbound connections are blocked") - return nil - } - - if !e.config.ServerSSHAllowed { - log.Info("SSH server is not enabled") - return nil - } - - if sshConf.GetSshEnabled() { - if runtime.GOOS == "windows" { - log.Warnf("running SSH server on %s is not supported", runtime.GOOS) - return nil - } - e.sshMux.Lock() - // start SSH server if it wasn't running - if isNil(e.sshServer) { - listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort) - if nbnetstack.IsEnabled() { - listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort) - } - // nil sshServer means it has not yet been started - server, err := e.sshServerFunc(e.config.SSHKey, listenAddr) - if err != nil { - e.sshMux.Unlock() - return fmt.Errorf("create ssh server: %w", err) - } - - e.sshServer = server - e.sshMux.Unlock() - - go func() { - // blocking - err = server.Start() - if err != nil { - // will throw error when we stop it even if it is a graceful stop - log.Debugf("stopped SSH server with error %v", err) - } - e.sshMux.Lock() - e.sshServer = nil - e.sshMux.Unlock() - log.Infof("stopped SSH server") - }() - } else { - e.sshMux.Unlock() - log.Debugf("SSH server is already running") - } - } else { - e.sshMux.Lock() - if !isNil(e.sshServer) { - // Disable SSH server request, so stop it if it was running - err := e.sshServer.Stop() - if err != nil { - log.Warnf("failed to stop SSH server %v", err) - } - e.sshServer = nil - } - e.sshMux.Unlock() - } - return nil -} - func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { if e.wgInterface == nil { return errors.New("wireguard interface is not initialized") @@ -973,8 +910,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { } if conf.GetSshConfig() != nil { - err := e.updateSSH(conf.GetSshConfig()) - if err != nil { + if err := e.updateSSH(conf.GetSshConfig()); err != nil { log.Warnf("failed handling SSH server setup: %v", err) } } @@ -1012,6 +948,11 @@ func (e *Engine) receiveManagementEvents() { e.config.BlockLANAccess, e.config.BlockInbound, e.config.LazyConnectionEnabled, + e.config.EnableSSHRoot, + e.config.EnableSSHSFTP, + e.config.EnableSSHLocalPortForwarding, + e.config.EnableSSHRemotePortForwarding, + e.config.DisableSSHAuth, ) err = e.mgmClient.Sync(e.ctx, info, e.handleSync) @@ -1170,19 +1111,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.statusRecorder.FinishPeerListModifications() - // update SSHServer by adding remote peer SSH keys - e.sshMux.Lock() - if !isNil(e.sshServer) { - for _, config := range networkMap.GetRemotePeers() { - if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil { - err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey())) - if err != nil { - log.Warnf("failed adding authorized key to SSH DefaultServer %v", err) - } - } - } + e.updatePeerSSHHostKeys(networkMap.GetRemotePeers()) + + if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil { + log.Warnf("failed to update SSH client config: %v", err) } - e.sshMux.Unlock() } // must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store @@ -1544,15 +1477,6 @@ func (e *Engine) close() { e.statusRecorder.SetWgIface(nil) } - e.sshMux.Lock() - if !isNil(e.sshServer) { - err := e.sshServer.Stop() - if err != nil { - log.Warnf("failed stopping the SSH server: %v", err) - } - } - e.sshMux.Unlock() - if e.firewall != nil { err := e.firewall.Close(e.stateManager) if err != nil { @@ -1583,6 +1507,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err e.config.BlockLANAccess, e.config.BlockInbound, e.config.LazyConnectionEnabled, + e.config.EnableSSHRoot, + e.config.EnableSSHSFTP, + e.config.EnableSSHLocalPortForwarding, + e.config.EnableSSHRemotePortForwarding, + e.config.DisableSSHAuth, ) netMap, err := e.mgmClient.GetNetworkMap(info) diff --git a/client/internal/engine_ssh.go b/client/internal/engine_ssh.go new file mode 100644 index 000000000..861b3d6d2 --- /dev/null +++ b/client/internal/engine_ssh.go @@ -0,0 +1,355 @@ +package internal + +import ( + "context" + "errors" + "fmt" + "net/netip" + "strings" + + log "github.com/sirupsen/logrus" + + firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" + sshconfig "github.com/netbirdio/netbird/client/ssh/config" + sshserver "github.com/netbirdio/netbird/client/ssh/server" + mgmProto "github.com/netbirdio/netbird/shared/management/proto" +) + +type sshServer interface { + Start(ctx context.Context, addr netip.AddrPort) error + Stop() error + GetStatus() (bool, []sshserver.SessionInfo) +} + +func (e *Engine) setupSSHPortRedirection() error { + if e.firewall == nil || e.wgInterface == nil { + return nil + } + + localAddr := e.wgInterface.Address().IP + if !localAddr.IsValid() { + return errors.New("invalid local NetBird address") + } + + if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil { + return fmt.Errorf("add SSH port redirection: %w", err) + } + log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr) + + return nil +} + +func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { + if e.config.BlockInbound { + log.Info("SSH server is disabled because inbound connections are blocked") + return e.stopSSHServer() + } + + if !e.config.ServerSSHAllowed { + log.Info("SSH server is disabled in config") + return e.stopSSHServer() + } + + if !sshConf.GetSshEnabled() { + if e.config.ServerSSHAllowed { + log.Info("SSH server is locally allowed but disabled by management server") + } + return e.stopSSHServer() + } + + if e.sshServer != nil { + log.Debug("SSH server is already running") + return nil + } + + if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth { + log.Info("starting SSH server without JWT authentication (authentication disabled by config)") + return e.startSSHServer(nil) + } + + if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil { + jwtConfig := &sshserver.JWTConfig{ + Issuer: protoJWT.GetIssuer(), + Audience: protoJWT.GetAudience(), + KeysLocation: protoJWT.GetKeysLocation(), + MaxTokenAge: protoJWT.GetMaxTokenAge(), + } + + return e.startSSHServer(jwtConfig) + } + + return errors.New("SSH server requires valid JWT configuration") +} + +// updateSSHClientConfig updates the SSH client configuration with peer information +func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error { + peerInfo := e.extractPeerSSHInfo(remotePeers) + if len(peerInfo) == 0 { + log.Debug("no SSH-enabled peers found, skipping SSH config update") + return nil + } + + configMgr := sshconfig.New() + if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil { + log.Warnf("failed to update SSH client config: %v", err) + return nil // Don't fail engine startup on SSH config issues + } + + log.Debugf("updated SSH client config with %d peers", len(peerInfo)) + + if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{ + SSHConfigDir: configMgr.GetSSHConfigDir(), + SSHConfigFile: configMgr.GetSSHConfigFile(), + }); err != nil { + log.Warnf("failed to update SSH config state: %v", err) + } + + return nil +} + +// extractPeerSSHInfo extracts SSH information from peer configurations +func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo { + var peerInfo []sshconfig.PeerSSHInfo + + for _, peerConfig := range remotePeers { + if peerConfig.GetSshConfig() == nil { + continue + } + + sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey() + if len(sshPubKeyBytes) == 0 { + continue + } + + peerIP := e.extractPeerIP(peerConfig) + hostname := e.extractHostname(peerConfig) + + peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{ + Hostname: hostname, + IP: peerIP, + FQDN: peerConfig.GetFqdn(), + }) + } + + return peerInfo +} + +// extractPeerIP extracts IP address from peer's allowed IPs +func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string { + if len(peerConfig.GetAllowedIps()) == 0 { + return "" + } + + if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil { + return prefix.Addr().String() + } + return "" +} + +// extractHostname extracts short hostname from FQDN +func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string { + fqdn := peerConfig.GetFqdn() + if fqdn == "" { + return "" + } + + parts := strings.Split(fqdn, ".") + if len(parts) > 0 && parts[0] != "" { + return parts[0] + } + return "" +} + +// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access +func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) { + for _, peerConfig := range remotePeers { + if peerConfig.GetSshConfig() == nil { + continue + } + + sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey() + if len(sshPubKeyBytes) == 0 { + continue + } + + if err := e.statusRecorder.UpdatePeerSSHHostKey(peerConfig.GetWgPubKey(), sshPubKeyBytes); err != nil { + log.Warnf("failed to update SSH host key for peer %s: %v", peerConfig.GetWgPubKey(), err) + } + } + + log.Debugf("updated peer SSH host keys for daemon API access") +} + +// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN +func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) { + e.syncMsgMux.Lock() + statusRecorder := e.statusRecorder + e.syncMsgMux.Unlock() + + if statusRecorder == nil { + return nil, false + } + + fullStatus := statusRecorder.GetFullStatus() + for _, peerState := range fullStatus.Peers { + if peerState.IP == peerAddress || peerState.FQDN == peerAddress { + if len(peerState.SSHHostKey) > 0 { + return peerState.SSHHostKey, true + } + return nil, false + } + } + + return nil, false +} + +// cleanupSSHConfig removes NetBird SSH client configuration on shutdown +func (e *Engine) cleanupSSHConfig() { + configMgr := sshconfig.New() + + if err := configMgr.RemoveSSHClientConfig(); err != nil { + log.Warnf("failed to remove SSH client config: %v", err) + } else { + log.Debugf("SSH client config cleanup completed") + } +} + +// startSSHServer initializes and starts the SSH server with proper configuration. +func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error { + if e.wgInterface == nil { + return errors.New("wg interface not initialized") + } + + serverConfig := &sshserver.Config{ + HostKeyPEM: e.config.SSHKey, + JWT: jwtConfig, + } + server := sshserver.New(serverConfig) + + wgAddr := e.wgInterface.Address() + server.SetNetworkValidation(wgAddr) + + netbirdIP := wgAddr.IP + listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort) + + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + server.SetNetstackNet(netstackNet) + } + + e.configureSSHServer(server) + + if err := server.Start(e.ctx, listenAddr); err != nil { + return fmt.Errorf("start SSH server: %w", err) + } + + e.sshServer = server + + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + if registrar, ok := e.firewall.(interface { + RegisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.RegisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort) + log.Debugf("registered SSH service with netstack for TCP:%d", sshserver.InternalSSHPort) + } + } + + if err := e.setupSSHPortRedirection(); err != nil { + log.Warnf("failed to setup SSH port redirection: %v", err) + } + + return nil +} + +// configureSSHServer applies SSH configuration options to the server. +func (e *Engine) configureSSHServer(server *sshserver.Server) { + if e.config.EnableSSHRoot != nil && *e.config.EnableSSHRoot { + server.SetAllowRootLogin(true) + log.Info("SSH root login enabled") + } else { + server.SetAllowRootLogin(false) + log.Info("SSH root login disabled (default)") + } + + if e.config.EnableSSHSFTP != nil && *e.config.EnableSSHSFTP { + server.SetAllowSFTP(true) + log.Info("SSH SFTP subsystem enabled") + } else { + server.SetAllowSFTP(false) + log.Info("SSH SFTP subsystem disabled (default)") + } + + if e.config.EnableSSHLocalPortForwarding != nil && *e.config.EnableSSHLocalPortForwarding { + server.SetAllowLocalPortForwarding(true) + log.Info("SSH local port forwarding enabled") + } else { + server.SetAllowLocalPortForwarding(false) + log.Info("SSH local port forwarding disabled (default)") + } + + if e.config.EnableSSHRemotePortForwarding != nil && *e.config.EnableSSHRemotePortForwarding { + server.SetAllowRemotePortForwarding(true) + log.Info("SSH remote port forwarding enabled") + } else { + server.SetAllowRemotePortForwarding(false) + log.Info("SSH remote port forwarding disabled (default)") + } +} + +func (e *Engine) cleanupSSHPortRedirection() error { + if e.firewall == nil || e.wgInterface == nil { + return nil + } + + localAddr := e.wgInterface.Address().IP + if !localAddr.IsValid() { + return errors.New("invalid local NetBird address") + } + + if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil { + return fmt.Errorf("remove SSH port redirection: %w", err) + } + log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr) + + return nil +} + +func (e *Engine) stopSSHServer() error { + if e.sshServer == nil { + return nil + } + + if err := e.cleanupSSHPortRedirection(); err != nil { + log.Warnf("failed to cleanup SSH port redirection: %v", err) + } + + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + if registrar, ok := e.firewall.(interface { + UnregisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.UnregisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort) + log.Debugf("unregistered SSH service from netstack for TCP:%d", sshserver.InternalSSHPort) + } + } + + log.Info("stopping SSH server") + err := e.sshServer.Stop() + e.sshServer = nil + if err != nil { + return fmt.Errorf("stop: %w", err) + } + return nil +} + +// GetSSHServerStatus returns the SSH server status and active sessions +func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) { + e.syncMsgMux.Lock() + sshServer := e.sshServer + e.syncMsgMux.Unlock() + + if sshServer == nil { + return false, nil + } + + return sshServer.GetStatus() +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index d15a07f9d..3b7ff0eba 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -14,7 +14,6 @@ import ( "github.com/golang/mock/gomock" "github.com/google/uuid" - "github.com/netbirdio/netbird/client/internal/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,7 +24,10 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -46,7 +48,7 @@ import ( icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/routemanager" - "github.com/netbirdio/netbird/client/ssh" + nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" @@ -214,11 +216,13 @@ func TestMain(m *testing.M) { } func TestEngine_SSH(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("skipping TestEngine_SSH") + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return } - key, err := wgtypes.GeneratePrivateKey() + sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) if err != nil { t.Fatal(err) return @@ -240,6 +244,7 @@ func TestEngine_SSH(t *testing.T) { WgPort: 33100, ServerSSHAllowed: true, MTU: iface.DefaultMTU, + SSHKey: sshKey, }, MobileDependency{}, peer.NewRecorder("https://mgm"), @@ -250,35 +255,8 @@ func TestEngine_SSH(t *testing.T) { UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, } - var sshKeysAdded []string - var sshPeersRemoved []string - - sshCtx, cancel := context.WithCancel(context.Background()) - - engine.sshServerFunc = func(hostKeyPEM []byte, addr string) (ssh.Server, error) { - return &ssh.MockServer{ - Ctx: sshCtx, - StopFunc: func() error { - cancel() - return nil - }, - StartFunc: func() error { - <-ctx.Done() - return ctx.Err() - }, - AddAuthorizedKeyFunc: func(peer, newKey string) error { - sshKeysAdded = append(sshKeysAdded, newKey) - return nil - }, - RemoveAuthorizedKeyFunc: func(peer string) { - sshPeersRemoved = append(sshPeersRemoved, peer) - }, - }, nil - } err = engine.Start(nil, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer func() { err := engine.Stop() @@ -304,9 +282,7 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assert.Nil(t, engine.sshServer) @@ -314,19 +290,24 @@ func TestEngine_SSH(t *testing.T) { networkMap = &mgmtProto.NetworkMap{ Serial: 7, PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24", - SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}}, + SshConfig: &mgmtProto.SSHConfig{ + SshEnabled: true, + JwtConfig: &mgmtProto.JWTConfig{ + Issuer: "test-issuer", + Audience: "test-audience", + KeysLocation: "test-keys", + MaxTokenAge: 3600, + }, + }}, RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH}, RemotePeersIsEmpty: false, } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) time.Sleep(250 * time.Millisecond) assert.NotNil(t, engine.sshServer) - assert.Contains(t, sshKeysAdded, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ") // now remove peer networkMap = &mgmtProto.NetworkMap{ @@ -336,13 +317,10 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // time.Sleep(250 * time.Millisecond) assert.NotNil(t, engine.sshServer) - assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=") // now disable SSH server networkMap = &mgmtProto.NetworkMap{ @@ -354,12 +332,70 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assert.Nil(t, engine.sshServer) +} +func TestEngine_SSHUpdateLogic(t *testing.T) { + // Test that SSH server start/stop logic works based on config + engine := &Engine{ + config: &EngineConfig{ + ServerSSHAllowed: false, // Start with SSH disabled + }, + syncMsgMux: &sync.Mutex{}, + } + + // Test SSH disabled config + sshConfig := &mgmtProto.SSHConfig{SshEnabled: false} + err := engine.updateSSH(sshConfig) + assert.NoError(t, err) + assert.Nil(t, engine.sshServer) + + // Test inbound blocked + engine.config.BlockInbound = true + err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true}) + assert.NoError(t, err) + assert.Nil(t, engine.sshServer) + engine.config.BlockInbound = false + + // Test with server SSH not allowed + err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true}) + assert.NoError(t, err) + assert.Nil(t, engine.sshServer) +} + +func TestEngine_SSHServerConsistency(t *testing.T) { + + t.Run("server set only on successful creation", func(t *testing.T) { + engine := &Engine{ + config: &EngineConfig{ + ServerSSHAllowed: true, + SSHKey: []byte("test-key"), + }, + syncMsgMux: &sync.Mutex{}, + } + + engine.wgInterface = nil + + err := engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true}) + + assert.Error(t, err) + assert.Nil(t, engine.sshServer) + }) + + t.Run("cleanup handles nil gracefully", func(t *testing.T) { + engine := &Engine{ + config: &EngineConfig{ + ServerSSHAllowed: false, + }, + syncMsgMux: &sync.Mutex{}, + } + + err := engine.stopSSHServer() + assert.NoError(t, err) + assert.Nil(t, engine.sshServer) + }) } func TestEngine_UpdateNetworkMap(t *testing.T) { @@ -1589,7 +1625,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) - accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, "", err } diff --git a/client/internal/login.go b/client/internal/login.go index 257e3c3ac..f528783ef 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -124,6 +124,11 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte config.BlockLANAccess, config.BlockInbound, config.LazyConnectionEnabled, + config.EnableSSHRoot, + config.EnableSSHSFTP, + config.EnableSSHLocalPortForwarding, + config.EnableSSHRemotePortForwarding, + config.DisableSSHAuth, ) loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) return serverKey, loginResp, err @@ -150,6 +155,11 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm. config.BlockLANAccess, config.BlockInbound, config.LazyConnectionEnabled, + config.EnableSSHRoot, + config.EnableSSHSFTP, + config.EnableSSHLocalPortForwarding, + config.EnableSSHRemotePortForwarding, + config.DisableSSHAuth, ) loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels) if err != nil { diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 68afe986a..426c31e1a 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -666,7 +666,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) { } }() - if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { + if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { return false } diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go index 32a458d00..7f500c410 100644 --- a/client/internal/peer/env.go +++ b/client/internal/peer/env.go @@ -2,6 +2,7 @@ package peer import ( "os" + "runtime" "strings" ) @@ -10,5 +11,8 @@ const ( ) func isForceRelayed() bool { + if runtime.GOOS == "js" { + return true + } return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 239cce7e0..76f4f523c 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -21,9 +21,9 @@ import ( "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" relayClient "github.com/netbirdio/netbird/shared/relay/client" - "github.com/netbirdio/netbird/route" ) const eventQueueSize = 10 @@ -67,6 +67,7 @@ type State struct { BytesRx int64 Latency time.Duration RosenpassEnabled bool + SSHHostKey []byte routes map[string]struct{} } @@ -572,6 +573,22 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error { return nil } +// UpdatePeerSSHHostKey updates peer's SSH host key +func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[peerPubKey] + if !ok { + return errors.New("peer doesn't exist") + } + + peerState.SSHHostKey = sshHostKey + d.peers[peerPubKey] = peerState + + return nil +} + // FinishPeerListModifications this event invoke the notification func (d *Status) FinishPeerListModifications() { d.mux.Lock() diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index f03822089..8f467a214 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -44,24 +44,30 @@ var DefaultInterfaceBlacklist = []string{ // ConfigInput carries configuration changes to the client type ConfigInput struct { - ManagementURL string - AdminURL string - ConfigPath string - StateFilePath string - PreSharedKey *string - ServerSSHAllowed *bool - NATExternalIPs []string - CustomDNSAddress []byte - RosenpassEnabled *bool - RosenpassPermissive *bool - InterfaceName *string - WireguardPort *int - NetworkMonitor *bool - DisableAutoConnect *bool - ExtraIFaceBlackList []string - DNSRouteInterval *time.Duration - ClientCertPath string - ClientCertKeyPath string + ManagementURL string + AdminURL string + ConfigPath string + StateFilePath string + PreSharedKey *string + ServerSSHAllowed *bool + EnableSSHRoot *bool + EnableSSHSFTP *bool + EnableSSHLocalPortForwarding *bool + EnableSSHRemotePortForwarding *bool + DisableSSHAuth *bool + SSHJWTCacheTTL *int + NATExternalIPs []string + CustomDNSAddress []byte + RosenpassEnabled *bool + RosenpassPermissive *bool + InterfaceName *string + WireguardPort *int + NetworkMonitor *bool + DisableAutoConnect *bool + ExtraIFaceBlackList []string + DNSRouteInterval *time.Duration + ClientCertPath string + ClientCertKeyPath string DisableClientRoutes *bool DisableServerRoutes *bool @@ -82,18 +88,24 @@ type ConfigInput struct { // Config Configuration type type Config struct { // Wireguard private key of local peer - PrivateKey string - PreSharedKey string - ManagementURL *url.URL - AdminURL *url.URL - WgIface string - WgPort int - NetworkMonitor *bool - IFaceBlackList []string - DisableIPv6Discovery bool - RosenpassEnabled bool - RosenpassPermissive bool - ServerSSHAllowed *bool + PrivateKey string + PreSharedKey string + ManagementURL *url.URL + AdminURL *url.URL + WgIface string + WgPort int + NetworkMonitor *bool + IFaceBlackList []string + DisableIPv6Discovery bool + RosenpassEnabled bool + RosenpassPermissive bool + ServerSSHAllowed *bool + EnableSSHRoot *bool + EnableSSHSFTP *bool + EnableSSHLocalPortForwarding *bool + EnableSSHRemotePortForwarding *bool + DisableSSHAuth *bool + SSHJWTCacheTTL *int DisableClientRoutes bool DisableServerRoutes bool @@ -376,6 +388,62 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } + if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot { + if *input.EnableSSHRoot { + log.Infof("enabling SSH root login") + } else { + log.Infof("disabling SSH root login") + } + config.EnableSSHRoot = input.EnableSSHRoot + updated = true + } + + if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP { + if *input.EnableSSHSFTP { + log.Infof("enabling SSH SFTP subsystem") + } else { + log.Infof("disabling SSH SFTP subsystem") + } + config.EnableSSHSFTP = input.EnableSSHSFTP + updated = true + } + + if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding { + if *input.EnableSSHLocalPortForwarding { + log.Infof("enabling SSH local port forwarding") + } else { + log.Infof("disabling SSH local port forwarding") + } + config.EnableSSHLocalPortForwarding = input.EnableSSHLocalPortForwarding + updated = true + } + + if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding { + if *input.EnableSSHRemotePortForwarding { + log.Infof("enabling SSH remote port forwarding") + } else { + log.Infof("disabling SSH remote port forwarding") + } + config.EnableSSHRemotePortForwarding = input.EnableSSHRemotePortForwarding + updated = true + } + + if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth { + if *input.DisableSSHAuth { + log.Infof("disabling SSH authentication") + } else { + log.Infof("enabling SSH authentication") + } + config.DisableSSHAuth = input.DisableSSHAuth + updated = true + } + + if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL { + log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL) + config.SSHJWTCacheTTL = input.SSHJWTCacheTTL + updated = true + } + if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval { log.Infof("updating DNS route interval to %s (old value %s)", input.DNSRouteInterval.String(), config.DNSRouteInterval.String()) diff --git a/client/internal/profilemanager/config_test.go b/client/internal/profilemanager/config_test.go index 90bde7707..ab13cf389 100644 --- a/client/internal/profilemanager/config_test.go +++ b/client/internal/profilemanager/config_test.go @@ -193,10 +193,10 @@ func TestWireguardPortZeroExplicit(t *testing.T) { func TestWireguardPortDefaultVsExplicit(t *testing.T) { tests := []struct { - name string - wireguardPort *int - expectedPort int - description string + name string + wireguardPort *int + expectedPort int + description string }{ { name: "no port specified uses default", diff --git a/client/internal/profilemanager/profilemanager.go b/client/internal/profilemanager/profilemanager.go index fe0afae2b..c87f521cb 100644 --- a/client/internal/profilemanager/profilemanager.go +++ b/client/internal/profilemanager/profilemanager.go @@ -132,3 +132,21 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error { return nil } + +// GetLoginHint retrieves the email from the active profile to use as login_hint. +func GetLoginHint() string { + pm := NewProfileManager() + activeProf, err := pm.GetActiveProfile() + if err != nil { + log.Debugf("failed to get active profile for login hint: %v", err) + return "" + } + + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + return "" + } + + return profileState.Email +} diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 587e05c74..8d1398a7a 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -18,8 +18,8 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) const ( diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 26cf758d9..2baa0e668 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -24,7 +24,6 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" - nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/client" @@ -39,6 +38,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/client/net" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" relayClient "github.com/netbirdio/netbird/shared/relay/client" "github.com/netbirdio/netbird/version" diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index fa1c89aab..b0d377c21 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -20,8 +20,8 @@ import ( "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) // ConnectionListener export internal Listener for mobile diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 02f09b08a..7b9ae25f7 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -137,7 +137,7 @@ func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Severity.Descriptor instead. func (SystemEvent_Severity) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49, 0} + return file_daemon_proto_rawDescGZIP(), []int{51, 0} } type SystemEvent_Category int32 @@ -192,7 +192,7 @@ func (x SystemEvent_Category) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Category.Descriptor instead. func (SystemEvent_Category) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49, 1} + return file_daemon_proto_rawDescGZIP(), []int{51, 1} } type EmptyRequest struct { @@ -280,9 +280,15 @@ type LoginRequest struct { Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` // hint is used to pre-fill the email/username field during SSO authentication - Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"` + EnableSSHRoot *bool `protobuf:"varint,34,opt,name=enableSSHRoot,proto3,oneof" json:"enableSSHRoot,omitempty"` + EnableSSHSFTP *bool `protobuf:"varint,35,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"` + EnableSSHLocalPortForwarding *bool `protobuf:"varint,36,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"` + EnableSSHRemotePortForwarding *bool `protobuf:"varint,37,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"` + DisableSSHAuth *bool `protobuf:"varint,38,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"` + SshJWTCacheTTL *int32 `protobuf:"varint,39,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *LoginRequest) Reset() { @@ -547,6 +553,48 @@ func (x *LoginRequest) GetHint() string { return "" } +func (x *LoginRequest) GetEnableSSHRoot() bool { + if x != nil && x.EnableSSHRoot != nil { + return *x.EnableSSHRoot + } + return false +} + +func (x *LoginRequest) GetEnableSSHSFTP() bool { + if x != nil && x.EnableSSHSFTP != nil { + return *x.EnableSSHSFTP + } + return false +} + +func (x *LoginRequest) GetEnableSSHLocalPortForwarding() bool { + if x != nil && x.EnableSSHLocalPortForwarding != nil { + return *x.EnableSSHLocalPortForwarding + } + return false +} + +func (x *LoginRequest) GetEnableSSHRemotePortForwarding() bool { + if x != nil && x.EnableSSHRemotePortForwarding != nil { + return *x.EnableSSHRemotePortForwarding + } + return false +} + +func (x *LoginRequest) GetDisableSSHAuth() bool { + if x != nil && x.DisableSSHAuth != nil { + return *x.DisableSSHAuth + } + return false +} + +func (x *LoginRequest) GetSshJWTCacheTTL() int32 { + if x != nil && x.SshJWTCacheTTL != nil { + return *x.SshJWTCacheTTL + } + return 0 +} + type LoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` @@ -1057,24 +1105,30 @@ type GetConfigResponse struct { // preSharedKey settings value. PreSharedKey string `protobuf:"bytes,4,opt,name=preSharedKey,proto3" json:"preSharedKey,omitempty"` // adminURL settings value. - AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"` - InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"` - WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"` - Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"` - DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"` - ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` - RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - RosenpassPermissive bool `protobuf:"varint,12,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` - DisableNotifications bool `protobuf:"varint,13,opt,name=disable_notifications,json=disableNotifications,proto3" json:"disable_notifications,omitempty"` - LazyConnectionEnabled bool `protobuf:"varint,14,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` - BlockInbound bool `protobuf:"varint,15,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"` - NetworkMonitor bool `protobuf:"varint,16,opt,name=networkMonitor,proto3" json:"networkMonitor,omitempty"` - DisableDns bool `protobuf:"varint,17,opt,name=disable_dns,json=disableDns,proto3" json:"disable_dns,omitempty"` - DisableClientRoutes bool `protobuf:"varint,18,opt,name=disable_client_routes,json=disableClientRoutes,proto3" json:"disable_client_routes,omitempty"` - DisableServerRoutes bool `protobuf:"varint,19,opt,name=disable_server_routes,json=disableServerRoutes,proto3" json:"disable_server_routes,omitempty"` - BlockLanAccess bool `protobuf:"varint,20,opt,name=block_lan_access,json=blockLanAccess,proto3" json:"block_lan_access,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"` + InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"` + WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"` + Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"` + DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"` + ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` + RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` + RosenpassPermissive bool `protobuf:"varint,12,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` + DisableNotifications bool `protobuf:"varint,13,opt,name=disable_notifications,json=disableNotifications,proto3" json:"disable_notifications,omitempty"` + LazyConnectionEnabled bool `protobuf:"varint,14,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` + BlockInbound bool `protobuf:"varint,15,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"` + NetworkMonitor bool `protobuf:"varint,16,opt,name=networkMonitor,proto3" json:"networkMonitor,omitempty"` + DisableDns bool `protobuf:"varint,17,opt,name=disable_dns,json=disableDns,proto3" json:"disable_dns,omitempty"` + DisableClientRoutes bool `protobuf:"varint,18,opt,name=disable_client_routes,json=disableClientRoutes,proto3" json:"disable_client_routes,omitempty"` + DisableServerRoutes bool `protobuf:"varint,19,opt,name=disable_server_routes,json=disableServerRoutes,proto3" json:"disable_server_routes,omitempty"` + BlockLanAccess bool `protobuf:"varint,20,opt,name=block_lan_access,json=blockLanAccess,proto3" json:"block_lan_access,omitempty"` + EnableSSHRoot bool `protobuf:"varint,21,opt,name=enableSSHRoot,proto3" json:"enableSSHRoot,omitempty"` + EnableSSHSFTP bool `protobuf:"varint,24,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"` + EnableSSHLocalPortForwarding bool `protobuf:"varint,22,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"` + EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"` + DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"` + SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *GetConfigResponse) Reset() { @@ -1247,6 +1301,48 @@ func (x *GetConfigResponse) GetBlockLanAccess() bool { return false } +func (x *GetConfigResponse) GetEnableSSHRoot() bool { + if x != nil { + return x.EnableSSHRoot + } + return false +} + +func (x *GetConfigResponse) GetEnableSSHSFTP() bool { + if x != nil { + return x.EnableSSHSFTP + } + return false +} + +func (x *GetConfigResponse) GetEnableSSHLocalPortForwarding() bool { + if x != nil { + return x.EnableSSHLocalPortForwarding + } + return false +} + +func (x *GetConfigResponse) GetEnableSSHRemotePortForwarding() bool { + if x != nil { + return x.EnableSSHRemotePortForwarding + } + return false +} + +func (x *GetConfigResponse) GetDisableSSHAuth() bool { + if x != nil { + return x.DisableSSHAuth + } + return false +} + +func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 { + if x != nil { + return x.SshJWTCacheTTL + } + return 0 +} + // PeerState contains the latest state of a peer type PeerState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1267,6 +1363,7 @@ type PeerState struct { Networks []string `protobuf:"bytes,16,rep,name=networks,proto3" json:"networks,omitempty"` Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"` + SshHostKey []byte `protobuf:"bytes,19,opt,name=sshHostKey,proto3" json:"sshHostKey,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1420,6 +1517,13 @@ func (x *PeerState) GetRelayAddress() string { return "" } +func (x *PeerState) GetSshHostKey() []byte { + if x != nil { + return x.SshHostKey + } + return nil +} + // LocalPeerState contains the latest state of the local peer type LocalPeerState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1764,6 +1868,128 @@ func (x *NSGroupState) GetError() string { return "" } +// SSHSessionInfo contains information about an active SSH session +type SSHSessionInfo struct { + state protoimpl.MessageState `protogen:"open.v1"` + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + RemoteAddress string `protobuf:"bytes,2,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"` + Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"` + JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SSHSessionInfo) Reset() { + *x = SSHSessionInfo{} + mi := &file_daemon_proto_msgTypes[19] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SSHSessionInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SSHSessionInfo) ProtoMessage() {} + +func (x *SSHSessionInfo) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[19] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SSHSessionInfo.ProtoReflect.Descriptor instead. +func (*SSHSessionInfo) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{19} +} + +func (x *SSHSessionInfo) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *SSHSessionInfo) GetRemoteAddress() string { + if x != nil { + return x.RemoteAddress + } + return "" +} + +func (x *SSHSessionInfo) GetCommand() string { + if x != nil { + return x.Command + } + return "" +} + +func (x *SSHSessionInfo) GetJwtUsername() string { + if x != nil { + return x.JwtUsername + } + return "" +} + +// SSHServerState contains the latest state of the SSH server +type SSHServerState struct { + state protoimpl.MessageState `protogen:"open.v1"` + Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"` + Sessions []*SSHSessionInfo `protobuf:"bytes,2,rep,name=sessions,proto3" json:"sessions,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SSHServerState) Reset() { + *x = SSHServerState{} + mi := &file_daemon_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SSHServerState) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SSHServerState) ProtoMessage() {} + +func (x *SSHServerState) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[20] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SSHServerState.ProtoReflect.Descriptor instead. +func (*SSHServerState) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{20} +} + +func (x *SSHServerState) GetEnabled() bool { + if x != nil { + return x.Enabled + } + return false +} + +func (x *SSHServerState) GetSessions() []*SSHSessionInfo { + if x != nil { + return x.Sessions + } + return nil +} + // FullStatus contains the full state held by the Status instance type FullStatus struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1776,13 +2002,14 @@ type FullStatus struct { NumberOfForwardingRules int32 `protobuf:"varint,8,opt,name=NumberOfForwardingRules,proto3" json:"NumberOfForwardingRules,omitempty"` Events []*SystemEvent `protobuf:"bytes,7,rep,name=events,proto3" json:"events,omitempty"` LazyConnectionEnabled bool `protobuf:"varint,9,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` + SshServerState *SSHServerState `protobuf:"bytes,10,opt,name=sshServerState,proto3" json:"sshServerState,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *FullStatus) Reset() { *x = FullStatus{} - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1794,7 +2021,7 @@ func (x *FullStatus) String() string { func (*FullStatus) ProtoMessage() {} func (x *FullStatus) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[21] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1807,7 +2034,7 @@ func (x *FullStatus) ProtoReflect() protoreflect.Message { // Deprecated: Use FullStatus.ProtoReflect.Descriptor instead. func (*FullStatus) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{19} + return file_daemon_proto_rawDescGZIP(), []int{21} } func (x *FullStatus) GetManagementState() *ManagementState { @@ -1873,6 +2100,13 @@ func (x *FullStatus) GetLazyConnectionEnabled() bool { return false } +func (x *FullStatus) GetSshServerState() *SSHServerState { + if x != nil { + return x.SshServerState + } + return nil +} + // Networks type ListNetworksRequest struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1882,7 +2116,7 @@ type ListNetworksRequest struct { func (x *ListNetworksRequest) Reset() { *x = ListNetworksRequest{} - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1894,7 +2128,7 @@ func (x *ListNetworksRequest) String() string { func (*ListNetworksRequest) ProtoMessage() {} func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[22] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1907,7 +2141,7 @@ func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksRequest.ProtoReflect.Descriptor instead. func (*ListNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{20} + return file_daemon_proto_rawDescGZIP(), []int{22} } type ListNetworksResponse struct { @@ -1919,7 +2153,7 @@ type ListNetworksResponse struct { func (x *ListNetworksResponse) Reset() { *x = ListNetworksResponse{} - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1931,7 +2165,7 @@ func (x *ListNetworksResponse) String() string { func (*ListNetworksResponse) ProtoMessage() {} func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[23] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1944,7 +2178,7 @@ func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksResponse.ProtoReflect.Descriptor instead. func (*ListNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{21} + return file_daemon_proto_rawDescGZIP(), []int{23} } func (x *ListNetworksResponse) GetRoutes() []*Network { @@ -1965,7 +2199,7 @@ type SelectNetworksRequest struct { func (x *SelectNetworksRequest) Reset() { *x = SelectNetworksRequest{} - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1977,7 +2211,7 @@ func (x *SelectNetworksRequest) String() string { func (*SelectNetworksRequest) ProtoMessage() {} func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[24] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1990,7 +2224,7 @@ func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksRequest.ProtoReflect.Descriptor instead. func (*SelectNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{22} + return file_daemon_proto_rawDescGZIP(), []int{24} } func (x *SelectNetworksRequest) GetNetworkIDs() []string { @@ -2022,7 +2256,7 @@ type SelectNetworksResponse struct { func (x *SelectNetworksResponse) Reset() { *x = SelectNetworksResponse{} - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2034,7 +2268,7 @@ func (x *SelectNetworksResponse) String() string { func (*SelectNetworksResponse) ProtoMessage() {} func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[25] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2047,7 +2281,7 @@ func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksResponse.ProtoReflect.Descriptor instead. func (*SelectNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{23} + return file_daemon_proto_rawDescGZIP(), []int{25} } type IPList struct { @@ -2059,7 +2293,7 @@ type IPList struct { func (x *IPList) Reset() { *x = IPList{} - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2071,7 +2305,7 @@ func (x *IPList) String() string { func (*IPList) ProtoMessage() {} func (x *IPList) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[26] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2084,7 +2318,7 @@ func (x *IPList) ProtoReflect() protoreflect.Message { // Deprecated: Use IPList.ProtoReflect.Descriptor instead. func (*IPList) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{24} + return file_daemon_proto_rawDescGZIP(), []int{26} } func (x *IPList) GetIps() []string { @@ -2107,7 +2341,7 @@ type Network struct { func (x *Network) Reset() { *x = Network{} - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2119,7 +2353,7 @@ func (x *Network) String() string { func (*Network) ProtoMessage() {} func (x *Network) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[27] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2132,7 +2366,7 @@ func (x *Network) ProtoReflect() protoreflect.Message { // Deprecated: Use Network.ProtoReflect.Descriptor instead. func (*Network) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{25} + return file_daemon_proto_rawDescGZIP(), []int{27} } func (x *Network) GetID() string { @@ -2184,7 +2418,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2196,7 +2430,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2209,7 +2443,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{26} + return file_daemon_proto_rawDescGZIP(), []int{28} } func (x *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -2266,7 +2500,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2278,7 +2512,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2291,7 +2525,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{27} + return file_daemon_proto_rawDescGZIP(), []int{29} } func (x *ForwardingRule) GetProtocol() string { @@ -2338,7 +2572,7 @@ type ForwardingRulesResponse struct { func (x *ForwardingRulesResponse) Reset() { *x = ForwardingRulesResponse{} - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2350,7 +2584,7 @@ func (x *ForwardingRulesResponse) String() string { func (*ForwardingRulesResponse) ProtoMessage() {} func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[30] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2363,7 +2597,7 @@ func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRulesResponse.ProtoReflect.Descriptor instead. func (*ForwardingRulesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{28} + return file_daemon_proto_rawDescGZIP(), []int{30} } func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule { @@ -2387,7 +2621,7 @@ type DebugBundleRequest struct { func (x *DebugBundleRequest) Reset() { *x = DebugBundleRequest{} - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2399,7 +2633,7 @@ func (x *DebugBundleRequest) String() string { func (*DebugBundleRequest) ProtoMessage() {} func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[31] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2412,7 +2646,7 @@ func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead. func (*DebugBundleRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{29} + return file_daemon_proto_rawDescGZIP(), []int{31} } func (x *DebugBundleRequest) GetAnonymize() bool { @@ -2461,7 +2695,7 @@ type DebugBundleResponse struct { func (x *DebugBundleResponse) Reset() { *x = DebugBundleResponse{} - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2473,7 +2707,7 @@ func (x *DebugBundleResponse) String() string { func (*DebugBundleResponse) ProtoMessage() {} func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[32] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2486,7 +2720,7 @@ func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead. func (*DebugBundleResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{30} + return file_daemon_proto_rawDescGZIP(), []int{32} } func (x *DebugBundleResponse) GetPath() string { @@ -2518,7 +2752,7 @@ type GetLogLevelRequest struct { func (x *GetLogLevelRequest) Reset() { *x = GetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2530,7 +2764,7 @@ func (x *GetLogLevelRequest) String() string { func (*GetLogLevelRequest) ProtoMessage() {} func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[33] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2543,7 +2777,7 @@ func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelRequest.ProtoReflect.Descriptor instead. func (*GetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{31} + return file_daemon_proto_rawDescGZIP(), []int{33} } type GetLogLevelResponse struct { @@ -2555,7 +2789,7 @@ type GetLogLevelResponse struct { func (x *GetLogLevelResponse) Reset() { *x = GetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2567,7 +2801,7 @@ func (x *GetLogLevelResponse) String() string { func (*GetLogLevelResponse) ProtoMessage() {} func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[34] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2580,7 +2814,7 @@ func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelResponse.ProtoReflect.Descriptor instead. func (*GetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{32} + return file_daemon_proto_rawDescGZIP(), []int{34} } func (x *GetLogLevelResponse) GetLevel() LogLevel { @@ -2599,7 +2833,7 @@ type SetLogLevelRequest struct { func (x *SetLogLevelRequest) Reset() { *x = SetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2611,7 +2845,7 @@ func (x *SetLogLevelRequest) String() string { func (*SetLogLevelRequest) ProtoMessage() {} func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[35] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2624,7 +2858,7 @@ func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead. func (*SetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{33} + return file_daemon_proto_rawDescGZIP(), []int{35} } func (x *SetLogLevelRequest) GetLevel() LogLevel { @@ -2642,7 +2876,7 @@ type SetLogLevelResponse struct { func (x *SetLogLevelResponse) Reset() { *x = SetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2654,7 +2888,7 @@ func (x *SetLogLevelResponse) String() string { func (*SetLogLevelResponse) ProtoMessage() {} func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[36] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2667,7 +2901,7 @@ func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead. func (*SetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{34} + return file_daemon_proto_rawDescGZIP(), []int{36} } // State represents a daemon state entry @@ -2680,7 +2914,7 @@ type State struct { func (x *State) Reset() { *x = State{} - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2692,7 +2926,7 @@ func (x *State) String() string { func (*State) ProtoMessage() {} func (x *State) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[37] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2705,7 +2939,7 @@ func (x *State) ProtoReflect() protoreflect.Message { // Deprecated: Use State.ProtoReflect.Descriptor instead. func (*State) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{35} + return file_daemon_proto_rawDescGZIP(), []int{37} } func (x *State) GetName() string { @@ -2724,7 +2958,7 @@ type ListStatesRequest struct { func (x *ListStatesRequest) Reset() { *x = ListStatesRequest{} - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2736,7 +2970,7 @@ func (x *ListStatesRequest) String() string { func (*ListStatesRequest) ProtoMessage() {} func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[38] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2749,7 +2983,7 @@ func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead. func (*ListStatesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{36} + return file_daemon_proto_rawDescGZIP(), []int{38} } // ListStatesResponse contains a list of states @@ -2762,7 +2996,7 @@ type ListStatesResponse struct { func (x *ListStatesResponse) Reset() { *x = ListStatesResponse{} - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2774,7 +3008,7 @@ func (x *ListStatesResponse) String() string { func (*ListStatesResponse) ProtoMessage() {} func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[39] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2787,7 +3021,7 @@ func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead. func (*ListStatesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{37} + return file_daemon_proto_rawDescGZIP(), []int{39} } func (x *ListStatesResponse) GetStates() []*State { @@ -2808,7 +3042,7 @@ type CleanStateRequest struct { func (x *CleanStateRequest) Reset() { *x = CleanStateRequest{} - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2820,7 +3054,7 @@ func (x *CleanStateRequest) String() string { func (*CleanStateRequest) ProtoMessage() {} func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[40] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2833,7 +3067,7 @@ func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead. func (*CleanStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{38} + return file_daemon_proto_rawDescGZIP(), []int{40} } func (x *CleanStateRequest) GetStateName() string { @@ -2860,7 +3094,7 @@ type CleanStateResponse struct { func (x *CleanStateResponse) Reset() { *x = CleanStateResponse{} - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2872,7 +3106,7 @@ func (x *CleanStateResponse) String() string { func (*CleanStateResponse) ProtoMessage() {} func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[41] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2885,7 +3119,7 @@ func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead. func (*CleanStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{39} + return file_daemon_proto_rawDescGZIP(), []int{41} } func (x *CleanStateResponse) GetCleanedStates() int32 { @@ -2906,7 +3140,7 @@ type DeleteStateRequest struct { func (x *DeleteStateRequest) Reset() { *x = DeleteStateRequest{} - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2918,7 +3152,7 @@ func (x *DeleteStateRequest) String() string { func (*DeleteStateRequest) ProtoMessage() {} func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[42] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2931,7 +3165,7 @@ func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead. func (*DeleteStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{40} + return file_daemon_proto_rawDescGZIP(), []int{42} } func (x *DeleteStateRequest) GetStateName() string { @@ -2958,7 +3192,7 @@ type DeleteStateResponse struct { func (x *DeleteStateResponse) Reset() { *x = DeleteStateResponse{} - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[43] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2970,7 +3204,7 @@ func (x *DeleteStateResponse) String() string { func (*DeleteStateResponse) ProtoMessage() {} func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[43] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2983,7 +3217,7 @@ func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead. func (*DeleteStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{41} + return file_daemon_proto_rawDescGZIP(), []int{43} } func (x *DeleteStateResponse) GetDeletedStates() int32 { @@ -3002,7 +3236,7 @@ type SetSyncResponsePersistenceRequest struct { func (x *SetSyncResponsePersistenceRequest) Reset() { *x = SetSyncResponsePersistenceRequest{} - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3014,7 +3248,7 @@ func (x *SetSyncResponsePersistenceRequest) String() string { func (*SetSyncResponsePersistenceRequest) ProtoMessage() {} func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[44] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3027,7 +3261,7 @@ func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceRequest.ProtoReflect.Descriptor instead. func (*SetSyncResponsePersistenceRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{42} + return file_daemon_proto_rawDescGZIP(), []int{44} } func (x *SetSyncResponsePersistenceRequest) GetEnabled() bool { @@ -3045,7 +3279,7 @@ type SetSyncResponsePersistenceResponse struct { func (x *SetSyncResponsePersistenceResponse) Reset() { *x = SetSyncResponsePersistenceResponse{} - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[45] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3057,7 +3291,7 @@ func (x *SetSyncResponsePersistenceResponse) String() string { func (*SetSyncResponsePersistenceResponse) ProtoMessage() {} func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[45] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3070,7 +3304,7 @@ func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceResponse.ProtoReflect.Descriptor instead. func (*SetSyncResponsePersistenceResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{43} + return file_daemon_proto_rawDescGZIP(), []int{45} } type TCPFlags struct { @@ -3087,7 +3321,7 @@ type TCPFlags struct { func (x *TCPFlags) Reset() { *x = TCPFlags{} - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[46] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3099,7 +3333,7 @@ func (x *TCPFlags) String() string { func (*TCPFlags) ProtoMessage() {} func (x *TCPFlags) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[46] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3112,7 +3346,7 @@ func (x *TCPFlags) ProtoReflect() protoreflect.Message { // Deprecated: Use TCPFlags.ProtoReflect.Descriptor instead. func (*TCPFlags) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{44} + return file_daemon_proto_rawDescGZIP(), []int{46} } func (x *TCPFlags) GetSyn() bool { @@ -3174,7 +3408,7 @@ type TracePacketRequest struct { func (x *TracePacketRequest) Reset() { *x = TracePacketRequest{} - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[47] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3186,7 +3420,7 @@ func (x *TracePacketRequest) String() string { func (*TracePacketRequest) ProtoMessage() {} func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[47] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3199,7 +3433,7 @@ func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketRequest.ProtoReflect.Descriptor instead. func (*TracePacketRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{45} + return file_daemon_proto_rawDescGZIP(), []int{47} } func (x *TracePacketRequest) GetSourceIp() string { @@ -3277,7 +3511,7 @@ type TraceStage struct { func (x *TraceStage) Reset() { *x = TraceStage{} - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[48] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3289,7 +3523,7 @@ func (x *TraceStage) String() string { func (*TraceStage) ProtoMessage() {} func (x *TraceStage) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[48] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3302,7 +3536,7 @@ func (x *TraceStage) ProtoReflect() protoreflect.Message { // Deprecated: Use TraceStage.ProtoReflect.Descriptor instead. func (*TraceStage) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{46} + return file_daemon_proto_rawDescGZIP(), []int{48} } func (x *TraceStage) GetName() string { @@ -3343,7 +3577,7 @@ type TracePacketResponse struct { func (x *TracePacketResponse) Reset() { *x = TracePacketResponse{} - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[49] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3355,7 +3589,7 @@ func (x *TracePacketResponse) String() string { func (*TracePacketResponse) ProtoMessage() {} func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[49] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3368,7 +3602,7 @@ func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketResponse.ProtoReflect.Descriptor instead. func (*TracePacketResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{47} + return file_daemon_proto_rawDescGZIP(), []int{49} } func (x *TracePacketResponse) GetStages() []*TraceStage { @@ -3393,7 +3627,7 @@ type SubscribeRequest struct { func (x *SubscribeRequest) Reset() { *x = SubscribeRequest{} - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[50] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3405,7 +3639,7 @@ func (x *SubscribeRequest) String() string { func (*SubscribeRequest) ProtoMessage() {} func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[50] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3418,7 +3652,7 @@ func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SubscribeRequest.ProtoReflect.Descriptor instead. func (*SubscribeRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{48} + return file_daemon_proto_rawDescGZIP(), []int{50} } type SystemEvent struct { @@ -3436,7 +3670,7 @@ type SystemEvent struct { func (x *SystemEvent) Reset() { *x = SystemEvent{} - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[51] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3448,7 +3682,7 @@ func (x *SystemEvent) String() string { func (*SystemEvent) ProtoMessage() {} func (x *SystemEvent) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[51] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3461,7 +3695,7 @@ func (x *SystemEvent) ProtoReflect() protoreflect.Message { // Deprecated: Use SystemEvent.ProtoReflect.Descriptor instead. func (*SystemEvent) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49} + return file_daemon_proto_rawDescGZIP(), []int{51} } func (x *SystemEvent) GetId() string { @@ -3521,7 +3755,7 @@ type GetEventsRequest struct { func (x *GetEventsRequest) Reset() { *x = GetEventsRequest{} - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[52] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3533,7 +3767,7 @@ func (x *GetEventsRequest) String() string { func (*GetEventsRequest) ProtoMessage() {} func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[52] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3546,7 +3780,7 @@ func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsRequest.ProtoReflect.Descriptor instead. func (*GetEventsRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{50} + return file_daemon_proto_rawDescGZIP(), []int{52} } type GetEventsResponse struct { @@ -3558,7 +3792,7 @@ type GetEventsResponse struct { func (x *GetEventsResponse) Reset() { *x = GetEventsResponse{} - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[53] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3570,7 +3804,7 @@ func (x *GetEventsResponse) String() string { func (*GetEventsResponse) ProtoMessage() {} func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[53] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3583,7 +3817,7 @@ func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsResponse.ProtoReflect.Descriptor instead. func (*GetEventsResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{51} + return file_daemon_proto_rawDescGZIP(), []int{53} } func (x *GetEventsResponse) GetEvents() []*SystemEvent { @@ -3603,7 +3837,7 @@ type SwitchProfileRequest struct { func (x *SwitchProfileRequest) Reset() { *x = SwitchProfileRequest{} - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[54] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3615,7 +3849,7 @@ func (x *SwitchProfileRequest) String() string { func (*SwitchProfileRequest) ProtoMessage() {} func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[54] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3628,7 +3862,7 @@ func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileRequest.ProtoReflect.Descriptor instead. func (*SwitchProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{52} + return file_daemon_proto_rawDescGZIP(), []int{54} } func (x *SwitchProfileRequest) GetProfileName() string { @@ -3653,7 +3887,7 @@ type SwitchProfileResponse struct { func (x *SwitchProfileResponse) Reset() { *x = SwitchProfileResponse{} - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[55] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3665,7 +3899,7 @@ func (x *SwitchProfileResponse) String() string { func (*SwitchProfileResponse) ProtoMessage() {} func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[55] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3678,7 +3912,7 @@ func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileResponse.ProtoReflect.Descriptor instead. func (*SwitchProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{53} + return file_daemon_proto_rawDescGZIP(), []int{55} } type SetConfigRequest struct { @@ -3711,16 +3945,22 @@ type SetConfigRequest struct { ExtraIFaceBlacklist []string `protobuf:"bytes,24,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"` DnsLabels []string `protobuf:"bytes,25,rep,name=dns_labels,json=dnsLabels,proto3" json:"dns_labels,omitempty"` // cleanDNSLabels clean map list of DNS labels. - CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"` - DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` - Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"` + DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` + Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` + EnableSSHRoot *bool `protobuf:"varint,29,opt,name=enableSSHRoot,proto3,oneof" json:"enableSSHRoot,omitempty"` + EnableSSHSFTP *bool `protobuf:"varint,30,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"` + EnableSSHLocalPortForwarding *bool `protobuf:"varint,31,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"` + EnableSSHRemotePortForwarding *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"` + DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"` + SshJWTCacheTTL *int32 `protobuf:"varint,34,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *SetConfigRequest) Reset() { *x = SetConfigRequest{} - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[56] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3732,7 +3972,7 @@ func (x *SetConfigRequest) String() string { func (*SetConfigRequest) ProtoMessage() {} func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[56] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3745,7 +3985,7 @@ func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigRequest.ProtoReflect.Descriptor instead. func (*SetConfigRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{54} + return file_daemon_proto_rawDescGZIP(), []int{56} } func (x *SetConfigRequest) GetUsername() string { @@ -3944,6 +4184,48 @@ func (x *SetConfigRequest) GetMtu() int64 { return 0 } +func (x *SetConfigRequest) GetEnableSSHRoot() bool { + if x != nil && x.EnableSSHRoot != nil { + return *x.EnableSSHRoot + } + return false +} + +func (x *SetConfigRequest) GetEnableSSHSFTP() bool { + if x != nil && x.EnableSSHSFTP != nil { + return *x.EnableSSHSFTP + } + return false +} + +func (x *SetConfigRequest) GetEnableSSHLocalPortForwarding() bool { + if x != nil && x.EnableSSHLocalPortForwarding != nil { + return *x.EnableSSHLocalPortForwarding + } + return false +} + +func (x *SetConfigRequest) GetEnableSSHRemotePortForwarding() bool { + if x != nil && x.EnableSSHRemotePortForwarding != nil { + return *x.EnableSSHRemotePortForwarding + } + return false +} + +func (x *SetConfigRequest) GetDisableSSHAuth() bool { + if x != nil && x.DisableSSHAuth != nil { + return *x.DisableSSHAuth + } + return false +} + +func (x *SetConfigRequest) GetSshJWTCacheTTL() int32 { + if x != nil && x.SshJWTCacheTTL != nil { + return *x.SshJWTCacheTTL + } + return 0 +} + type SetConfigResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -3952,7 +4234,7 @@ type SetConfigResponse struct { func (x *SetConfigResponse) Reset() { *x = SetConfigResponse{} - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[57] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3964,7 +4246,7 @@ func (x *SetConfigResponse) String() string { func (*SetConfigResponse) ProtoMessage() {} func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[57] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3977,7 +4259,7 @@ func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigResponse.ProtoReflect.Descriptor instead. func (*SetConfigResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{55} + return file_daemon_proto_rawDescGZIP(), []int{57} } type AddProfileRequest struct { @@ -3990,7 +4272,7 @@ type AddProfileRequest struct { func (x *AddProfileRequest) Reset() { *x = AddProfileRequest{} - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[58] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4002,7 +4284,7 @@ func (x *AddProfileRequest) String() string { func (*AddProfileRequest) ProtoMessage() {} func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[58] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4015,7 +4297,7 @@ func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileRequest.ProtoReflect.Descriptor instead. func (*AddProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{56} + return file_daemon_proto_rawDescGZIP(), []int{58} } func (x *AddProfileRequest) GetUsername() string { @@ -4040,7 +4322,7 @@ type AddProfileResponse struct { func (x *AddProfileResponse) Reset() { *x = AddProfileResponse{} - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[59] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4052,7 +4334,7 @@ func (x *AddProfileResponse) String() string { func (*AddProfileResponse) ProtoMessage() {} func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[59] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4065,7 +4347,7 @@ func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileResponse.ProtoReflect.Descriptor instead. func (*AddProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{57} + return file_daemon_proto_rawDescGZIP(), []int{59} } type RemoveProfileRequest struct { @@ -4078,7 +4360,7 @@ type RemoveProfileRequest struct { func (x *RemoveProfileRequest) Reset() { *x = RemoveProfileRequest{} - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[60] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4090,7 +4372,7 @@ func (x *RemoveProfileRequest) String() string { func (*RemoveProfileRequest) ProtoMessage() {} func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[60] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4103,7 +4385,7 @@ func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileRequest.ProtoReflect.Descriptor instead. func (*RemoveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{58} + return file_daemon_proto_rawDescGZIP(), []int{60} } func (x *RemoveProfileRequest) GetUsername() string { @@ -4128,7 +4410,7 @@ type RemoveProfileResponse struct { func (x *RemoveProfileResponse) Reset() { *x = RemoveProfileResponse{} - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[61] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4140,7 +4422,7 @@ func (x *RemoveProfileResponse) String() string { func (*RemoveProfileResponse) ProtoMessage() {} func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[61] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4153,7 +4435,7 @@ func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileResponse.ProtoReflect.Descriptor instead. func (*RemoveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{59} + return file_daemon_proto_rawDescGZIP(), []int{61} } type ListProfilesRequest struct { @@ -4165,7 +4447,7 @@ type ListProfilesRequest struct { func (x *ListProfilesRequest) Reset() { *x = ListProfilesRequest{} - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[62] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4177,7 +4459,7 @@ func (x *ListProfilesRequest) String() string { func (*ListProfilesRequest) ProtoMessage() {} func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[62] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4190,7 +4472,7 @@ func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesRequest.ProtoReflect.Descriptor instead. func (*ListProfilesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{60} + return file_daemon_proto_rawDescGZIP(), []int{62} } func (x *ListProfilesRequest) GetUsername() string { @@ -4209,7 +4491,7 @@ type ListProfilesResponse struct { func (x *ListProfilesResponse) Reset() { *x = ListProfilesResponse{} - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[63] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4221,7 +4503,7 @@ func (x *ListProfilesResponse) String() string { func (*ListProfilesResponse) ProtoMessage() {} func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[63] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4234,7 +4516,7 @@ func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesResponse.ProtoReflect.Descriptor instead. func (*ListProfilesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{61} + return file_daemon_proto_rawDescGZIP(), []int{63} } func (x *ListProfilesResponse) GetProfiles() []*Profile { @@ -4254,7 +4536,7 @@ type Profile struct { func (x *Profile) Reset() { *x = Profile{} - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[64] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4266,7 +4548,7 @@ func (x *Profile) String() string { func (*Profile) ProtoMessage() {} func (x *Profile) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[64] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4279,7 +4561,7 @@ func (x *Profile) ProtoReflect() protoreflect.Message { // Deprecated: Use Profile.ProtoReflect.Descriptor instead. func (*Profile) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{62} + return file_daemon_proto_rawDescGZIP(), []int{64} } func (x *Profile) GetName() string { @@ -4304,7 +4586,7 @@ type GetActiveProfileRequest struct { func (x *GetActiveProfileRequest) Reset() { *x = GetActiveProfileRequest{} - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[65] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4316,7 +4598,7 @@ func (x *GetActiveProfileRequest) String() string { func (*GetActiveProfileRequest) ProtoMessage() {} func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[65] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4329,7 +4611,7 @@ func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileRequest.ProtoReflect.Descriptor instead. func (*GetActiveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{63} + return file_daemon_proto_rawDescGZIP(), []int{65} } type GetActiveProfileResponse struct { @@ -4342,7 +4624,7 @@ type GetActiveProfileResponse struct { func (x *GetActiveProfileResponse) Reset() { *x = GetActiveProfileResponse{} - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[66] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4354,7 +4636,7 @@ func (x *GetActiveProfileResponse) String() string { func (*GetActiveProfileResponse) ProtoMessage() {} func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[66] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4367,7 +4649,7 @@ func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileResponse.ProtoReflect.Descriptor instead. func (*GetActiveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{64} + return file_daemon_proto_rawDescGZIP(), []int{66} } func (x *GetActiveProfileResponse) GetProfileName() string { @@ -4394,7 +4676,7 @@ type LogoutRequest struct { func (x *LogoutRequest) Reset() { *x = LogoutRequest{} - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[67] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4406,7 +4688,7 @@ func (x *LogoutRequest) String() string { func (*LogoutRequest) ProtoMessage() {} func (x *LogoutRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[67] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4419,7 +4701,7 @@ func (x *LogoutRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutRequest.ProtoReflect.Descriptor instead. func (*LogoutRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{65} + return file_daemon_proto_rawDescGZIP(), []int{67} } func (x *LogoutRequest) GetProfileName() string { @@ -4444,7 +4726,7 @@ type LogoutResponse struct { func (x *LogoutResponse) Reset() { *x = LogoutResponse{} - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[68] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4456,7 +4738,7 @@ func (x *LogoutResponse) String() string { func (*LogoutResponse) ProtoMessage() {} func (x *LogoutResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[68] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4469,7 +4751,7 @@ func (x *LogoutResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutResponse.ProtoReflect.Descriptor instead. func (*LogoutResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{66} + return file_daemon_proto_rawDescGZIP(), []int{68} } type GetFeaturesRequest struct { @@ -4480,7 +4762,7 @@ type GetFeaturesRequest struct { func (x *GetFeaturesRequest) Reset() { *x = GetFeaturesRequest{} - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[69] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4492,7 +4774,7 @@ func (x *GetFeaturesRequest) String() string { func (*GetFeaturesRequest) ProtoMessage() {} func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[69] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4505,7 +4787,7 @@ func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesRequest.ProtoReflect.Descriptor instead. func (*GetFeaturesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{67} + return file_daemon_proto_rawDescGZIP(), []int{69} } type GetFeaturesResponse struct { @@ -4518,7 +4800,7 @@ type GetFeaturesResponse struct { func (x *GetFeaturesResponse) Reset() { *x = GetFeaturesResponse{} - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[70] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4530,7 +4812,7 @@ func (x *GetFeaturesResponse) String() string { func (*GetFeaturesResponse) ProtoMessage() {} func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[70] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4543,7 +4825,7 @@ func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesResponse.ProtoReflect.Descriptor instead. func (*GetFeaturesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{68} + return file_daemon_proto_rawDescGZIP(), []int{70} } func (x *GetFeaturesResponse) GetDisableProfiles() bool { @@ -4560,6 +4842,390 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool { return false } +// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer +type GetPeerSSHHostKeyRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // peer IP address or FQDN to get SSH host key for + PeerAddress string `protobuf:"bytes,1,opt,name=peerAddress,proto3" json:"peerAddress,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetPeerSSHHostKeyRequest) Reset() { + *x = GetPeerSSHHostKeyRequest{} + mi := &file_daemon_proto_msgTypes[71] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetPeerSSHHostKeyRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetPeerSSHHostKeyRequest) ProtoMessage() {} + +func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[71] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead. +func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{71} +} + +func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string { + if x != nil { + return x.PeerAddress + } + return "" +} + +// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer +type GetPeerSSHHostKeyResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname") + SshHostKey []byte `protobuf:"bytes,1,opt,name=sshHostKey,proto3" json:"sshHostKey,omitempty"` + // peer IP address + PeerIP string `protobuf:"bytes,2,opt,name=peerIP,proto3" json:"peerIP,omitempty"` + // peer FQDN + PeerFQDN string `protobuf:"bytes,3,opt,name=peerFQDN,proto3" json:"peerFQDN,omitempty"` + // indicates if the SSH host key was found + Found bool `protobuf:"varint,4,opt,name=found,proto3" json:"found,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetPeerSSHHostKeyResponse) Reset() { + *x = GetPeerSSHHostKeyResponse{} + mi := &file_daemon_proto_msgTypes[72] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetPeerSSHHostKeyResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetPeerSSHHostKeyResponse) ProtoMessage() {} + +func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[72] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead. +func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{72} +} + +func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte { + if x != nil { + return x.SshHostKey + } + return nil +} + +func (x *GetPeerSSHHostKeyResponse) GetPeerIP() string { + if x != nil { + return x.PeerIP + } + return "" +} + +func (x *GetPeerSSHHostKeyResponse) GetPeerFQDN() string { + if x != nil { + return x.PeerFQDN + } + return "" +} + +func (x *GetPeerSSHHostKeyResponse) GetFound() bool { + if x != nil { + return x.Found + } + return false +} + +// RequestJWTAuthRequest for initiating JWT authentication flow +type RequestJWTAuthRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // hint for OIDC login_hint parameter (typically email address) + Hint *string `protobuf:"bytes,1,opt,name=hint,proto3,oneof" json:"hint,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RequestJWTAuthRequest) Reset() { + *x = RequestJWTAuthRequest{} + mi := &file_daemon_proto_msgTypes[73] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RequestJWTAuthRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RequestJWTAuthRequest) ProtoMessage() {} + +func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[73] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead. +func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{73} +} + +func (x *RequestJWTAuthRequest) GetHint() string { + if x != nil && x.Hint != nil { + return *x.Hint + } + return "" +} + +// RequestJWTAuthResponse contains authentication flow information +type RequestJWTAuthResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // verification URI for user authentication + VerificationURI string `protobuf:"bytes,1,opt,name=verificationURI,proto3" json:"verificationURI,omitempty"` + // complete verification URI (with embedded user code) + VerificationURIComplete string `protobuf:"bytes,2,opt,name=verificationURIComplete,proto3" json:"verificationURIComplete,omitempty"` + // user code to enter on verification URI + UserCode string `protobuf:"bytes,3,opt,name=userCode,proto3" json:"userCode,omitempty"` + // device code for polling + DeviceCode string `protobuf:"bytes,4,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"` + // expiration time in seconds + ExpiresIn int64 `protobuf:"varint,5,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"` + // if a cached token is available, it will be returned here + CachedToken string `protobuf:"bytes,6,opt,name=cachedToken,proto3" json:"cachedToken,omitempty"` + // maximum age of JWT tokens in seconds (from management server) + MaxTokenAge int64 `protobuf:"varint,7,opt,name=maxTokenAge,proto3" json:"maxTokenAge,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RequestJWTAuthResponse) Reset() { + *x = RequestJWTAuthResponse{} + mi := &file_daemon_proto_msgTypes[74] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RequestJWTAuthResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RequestJWTAuthResponse) ProtoMessage() {} + +func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[74] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead. +func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{74} +} + +func (x *RequestJWTAuthResponse) GetVerificationURI() string { + if x != nil { + return x.VerificationURI + } + return "" +} + +func (x *RequestJWTAuthResponse) GetVerificationURIComplete() string { + if x != nil { + return x.VerificationURIComplete + } + return "" +} + +func (x *RequestJWTAuthResponse) GetUserCode() string { + if x != nil { + return x.UserCode + } + return "" +} + +func (x *RequestJWTAuthResponse) GetDeviceCode() string { + if x != nil { + return x.DeviceCode + } + return "" +} + +func (x *RequestJWTAuthResponse) GetExpiresIn() int64 { + if x != nil { + return x.ExpiresIn + } + return 0 +} + +func (x *RequestJWTAuthResponse) GetCachedToken() string { + if x != nil { + return x.CachedToken + } + return "" +} + +func (x *RequestJWTAuthResponse) GetMaxTokenAge() int64 { + if x != nil { + return x.MaxTokenAge + } + return 0 +} + +// WaitJWTTokenRequest for waiting for authentication completion +type WaitJWTTokenRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // device code from RequestJWTAuthResponse + DeviceCode string `protobuf:"bytes,1,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"` + // user code for verification + UserCode string `protobuf:"bytes,2,opt,name=userCode,proto3" json:"userCode,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WaitJWTTokenRequest) Reset() { + *x = WaitJWTTokenRequest{} + mi := &file_daemon_proto_msgTypes[75] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WaitJWTTokenRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WaitJWTTokenRequest) ProtoMessage() {} + +func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[75] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead. +func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{75} +} + +func (x *WaitJWTTokenRequest) GetDeviceCode() string { + if x != nil { + return x.DeviceCode + } + return "" +} + +func (x *WaitJWTTokenRequest) GetUserCode() string { + if x != nil { + return x.UserCode + } + return "" +} + +// WaitJWTTokenResponse contains the JWT token after authentication +type WaitJWTTokenResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // JWT token (access token or ID token) + Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` + // token type (e.g., "Bearer") + TokenType string `protobuf:"bytes,2,opt,name=tokenType,proto3" json:"tokenType,omitempty"` + // expiration time in seconds + ExpiresIn int64 `protobuf:"varint,3,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WaitJWTTokenResponse) Reset() { + *x = WaitJWTTokenResponse{} + mi := &file_daemon_proto_msgTypes[76] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WaitJWTTokenResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WaitJWTTokenResponse) ProtoMessage() {} + +func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[76] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead. +func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{76} +} + +func (x *WaitJWTTokenResponse) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + +func (x *WaitJWTTokenResponse) GetTokenType() string { + if x != nil { + return x.TokenType + } + return "" +} + +func (x *WaitJWTTokenResponse) GetExpiresIn() int64 { + if x != nil { + return x.ExpiresIn + } + return 0 +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -4570,7 +5236,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[78] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4582,7 +5248,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[78] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4595,7 +5261,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{26, 0} + return file_daemon_proto_rawDescGZIP(), []int{28, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -4617,7 +5283,7 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xe5\x0e\n" + + "\fEmptyRequest\"\xb6\x12\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -4655,7 +5321,13 @@ const file_daemon_proto_rawDesc = "" + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" + "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" + - "\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" + + "\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01\x12)\n" + + "\renableSSHRoot\x18\" \x01(\bH\x15R\renableSSHRoot\x88\x01\x01\x12)\n" + + "\renableSSHSFTP\x18# \x01(\bH\x16R\renableSSHSFTP\x88\x01\x01\x12G\n" + + "\x1cenableSSHLocalPortForwarding\x18$ \x01(\bH\x17R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" + + "\x1denableSSHRemotePortForwarding\x18% \x01(\bH\x18R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" + + "\x0edisableSSHAuth\x18& \x01(\bH\x19R\x0edisableSSHAuth\x88\x01\x01\x12+\n" + + "\x0esshJWTCacheTTL\x18' \x01(\x05H\x1aR\x0esshJWTCacheTTL\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4676,7 +5348,13 @@ const file_daemon_proto_rawDesc = "" + "\f_profileNameB\v\n" + "\t_usernameB\x06\n" + "\x04_mtuB\a\n" + - "\x05_hint\"\xb5\x01\n" + + "\x05_hintB\x10\n" + + "\x0e_enableSSHRootB\x10\n" + + "\x0e_enableSSHSFTPB\x1f\n" + + "\x1d_enableSSHLocalPortForwardingB \n" + + "\x1e_enableSSHRemotePortForwardingB\x11\n" + + "\x0f_disableSSHAuthB\x11\n" + + "\x0f_sshJWTCacheTTL\"\xb5\x01\n" + "\rLoginResponse\x12$\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + @@ -4709,7 +5387,7 @@ const file_daemon_proto_rawDesc = "" + "\fDownResponse\"P\n" + "\x10GetConfigRequest\x12 \n" + "\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" + - "\busername\x18\x02 \x01(\tR\busername\"\xb5\x06\n" + + "\busername\x18\x02 \x01(\tR\busername\"\xdb\b\n" + "\x11GetConfigResponse\x12$\n" + "\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" + "\n" + @@ -4734,7 +5412,13 @@ const file_daemon_proto_rawDesc = "" + "disableDns\x122\n" + "\x15disable_client_routes\x18\x12 \x01(\bR\x13disableClientRoutes\x122\n" + "\x15disable_server_routes\x18\x13 \x01(\bR\x13disableServerRoutes\x12(\n" + - "\x10block_lan_access\x18\x14 \x01(\bR\x0eblockLanAccess\"\xde\x05\n" + + "\x10block_lan_access\x18\x14 \x01(\bR\x0eblockLanAccess\x12$\n" + + "\renableSSHRoot\x18\x15 \x01(\bR\renableSSHRoot\x12$\n" + + "\renableSSHSFTP\x18\x18 \x01(\bR\renableSSHSFTP\x12B\n" + + "\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" + + "\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" + + "\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\x12&\n" + + "\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\"\xfe\x05\n" + "\tPeerState\x12\x0e\n" + "\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" + "\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" + @@ -4755,7 +5439,10 @@ const file_daemon_proto_rawDesc = "" + "\x10rosenpassEnabled\x18\x0f \x01(\bR\x10rosenpassEnabled\x12\x1a\n" + "\bnetworks\x18\x10 \x03(\tR\bnetworks\x123\n" + "\alatency\x18\x11 \x01(\v2\x19.google.protobuf.DurationR\alatency\x12\"\n" + - "\frelayAddress\x18\x12 \x01(\tR\frelayAddress\"\xf0\x01\n" + + "\frelayAddress\x18\x12 \x01(\tR\frelayAddress\x12\x1e\n" + + "\n" + + "sshHostKey\x18\x13 \x01(\fR\n" + + "sshHostKey\"\xf0\x01\n" + "\x0eLocalPeerState\x12\x0e\n" + "\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" + "\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12(\n" + @@ -4781,7 +5468,15 @@ const file_daemon_proto_rawDesc = "" + "\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" + "\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" + "\aenabled\x18\x03 \x01(\bR\aenabled\x12\x14\n" + - "\x05error\x18\x04 \x01(\tR\x05error\"\xef\x03\n" + + "\x05error\x18\x04 \x01(\tR\x05error\"\x8e\x01\n" + + "\x0eSSHSessionInfo\x12\x1a\n" + + "\busername\x18\x01 \x01(\tR\busername\x12$\n" + + "\rremoteAddress\x18\x02 \x01(\tR\rremoteAddress\x12\x18\n" + + "\acommand\x18\x03 \x01(\tR\acommand\x12 \n" + + "\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\"^\n" + + "\x0eSSHServerState\x12\x18\n" + + "\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" + + "\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\xaf\x04\n" + "\n" + "FullStatus\x12A\n" + "\x0fmanagementState\x18\x01 \x01(\v2\x17.daemon.ManagementStateR\x0fmanagementState\x125\n" + @@ -4793,7 +5488,9 @@ const file_daemon_proto_rawDesc = "" + "dnsServers\x128\n" + "\x17NumberOfForwardingRules\x18\b \x01(\x05R\x17NumberOfForwardingRules\x12+\n" + "\x06events\x18\a \x03(\v2\x13.daemon.SystemEventR\x06events\x124\n" + - "\x15lazyConnectionEnabled\x18\t \x01(\bR\x15lazyConnectionEnabled\"\x15\n" + + "\x15lazyConnectionEnabled\x18\t \x01(\bR\x15lazyConnectionEnabled\x12>\n" + + "\x0esshServerState\x18\n" + + " \x01(\v2\x16.daemon.SSHServerStateR\x0esshServerState\"\x15\n" + "\x13ListNetworksRequest\"?\n" + "\x14ListNetworksResponse\x12'\n" + "\x06routes\x18\x01 \x03(\v2\x0f.daemon.NetworkR\x06routes\"a\n" + @@ -4934,7 +5631,7 @@ const file_daemon_proto_rawDesc = "" + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + "\t_username\"\x17\n" + - "\x15SwitchProfileResponse\"\x8e\r\n" + + "\x15SwitchProfileResponse\"\xdf\x10\n" + "\x10SetConfigRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + "\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" + @@ -4967,7 +5664,13 @@ const file_daemon_proto_rawDesc = "" + "dns_labels\x18\x19 \x03(\tR\tdnsLabels\x12&\n" + "\x0ecleanDNSLabels\x18\x1a \x01(\bR\x0ecleanDNSLabels\x12J\n" + "\x10dnsRouteInterval\x18\x1b \x01(\v2\x19.google.protobuf.DurationH\x10R\x10dnsRouteInterval\x88\x01\x01\x12\x15\n" + - "\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01B\x13\n" + + "\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01\x12)\n" + + "\renableSSHRoot\x18\x1d \x01(\bH\x12R\renableSSHRoot\x88\x01\x01\x12)\n" + + "\renableSSHSFTP\x18\x1e \x01(\bH\x13R\renableSSHSFTP\x88\x01\x01\x12G\n" + + "\x1cenableSSHLocalPortForwarding\x18\x1f \x01(\bH\x14R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" + + "\x1denableSSHRemotePortForwarding\x18 \x01(\bH\x15R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" + + "\x0edisableSSHAuth\x18! \x01(\bH\x16R\x0edisableSSHAuth\x88\x01\x01\x12+\n" + + "\x0esshJWTCacheTTL\x18\" \x01(\x05H\x17R\x0esshJWTCacheTTL\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4985,7 +5688,13 @@ const file_daemon_proto_rawDesc = "" + "\x16_lazyConnectionEnabledB\x10\n" + "\x0e_block_inboundB\x13\n" + "\x11_dnsRouteIntervalB\x06\n" + - "\x04_mtu\"\x13\n" + + "\x04_mtuB\x10\n" + + "\x0e_enableSSHRootB\x10\n" + + "\x0e_enableSSHSFTPB\x1f\n" + + "\x1d_enableSSHLocalPortForwardingB \n" + + "\x1e_enableSSHRemotePortForwardingB\x11\n" + + "\x0f_disableSSHAuthB\x11\n" + + "\x0f_sshJWTCacheTTL\"\x13\n" + "\x11SetConfigResponse\"Q\n" + "\x11AddProfileRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + @@ -5015,7 +5724,38 @@ const file_daemon_proto_rawDesc = "" + "\x12GetFeaturesRequest\"x\n" + "\x13GetFeaturesResponse\x12)\n" + "\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" + - "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings*b\n" + + "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"<\n" + + "\x18GetPeerSSHHostKeyRequest\x12 \n" + + "\vpeerAddress\x18\x01 \x01(\tR\vpeerAddress\"\x85\x01\n" + + "\x19GetPeerSSHHostKeyResponse\x12\x1e\n" + + "\n" + + "sshHostKey\x18\x01 \x01(\fR\n" + + "sshHostKey\x12\x16\n" + + "\x06peerIP\x18\x02 \x01(\tR\x06peerIP\x12\x1a\n" + + "\bpeerFQDN\x18\x03 \x01(\tR\bpeerFQDN\x12\x14\n" + + "\x05found\x18\x04 \x01(\bR\x05found\"9\n" + + "\x15RequestJWTAuthRequest\x12\x17\n" + + "\x04hint\x18\x01 \x01(\tH\x00R\x04hint\x88\x01\x01B\a\n" + + "\x05_hint\"\x9a\x02\n" + + "\x16RequestJWTAuthResponse\x12(\n" + + "\x0fverificationURI\x18\x01 \x01(\tR\x0fverificationURI\x128\n" + + "\x17verificationURIComplete\x18\x02 \x01(\tR\x17verificationURIComplete\x12\x1a\n" + + "\buserCode\x18\x03 \x01(\tR\buserCode\x12\x1e\n" + + "\n" + + "deviceCode\x18\x04 \x01(\tR\n" + + "deviceCode\x12\x1c\n" + + "\texpiresIn\x18\x05 \x01(\x03R\texpiresIn\x12 \n" + + "\vcachedToken\x18\x06 \x01(\tR\vcachedToken\x12 \n" + + "\vmaxTokenAge\x18\a \x01(\x03R\vmaxTokenAge\"Q\n" + + "\x13WaitJWTTokenRequest\x12\x1e\n" + + "\n" + + "deviceCode\x18\x01 \x01(\tR\n" + + "deviceCode\x12\x1a\n" + + "\buserCode\x18\x02 \x01(\tR\buserCode\"h\n" + + "\x14WaitJWTTokenResponse\x12\x14\n" + + "\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" + + "\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" + + "\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" + "\bLogLevel\x12\v\n" + "\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -5024,7 +5764,7 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a2\x8f\x10\n" + + "\x05TRACE\x10\a2\x8b\x12\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -5056,7 +5796,10 @@ const file_daemon_proto_rawDesc = "" + "\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" + "\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" + "\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" + - "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00B\bZ\x06/protob\x06proto3" + "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" + + "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" + + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00B\bZ\x06/protob\x06proto3" var ( file_daemon_proto_rawDescOnce sync.Once @@ -5071,7 +5814,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 72) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 80) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel (SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity @@ -5095,155 +5838,171 @@ var file_daemon_proto_goTypes = []any{ (*ManagementState)(nil), // 19: daemon.ManagementState (*RelayState)(nil), // 20: daemon.RelayState (*NSGroupState)(nil), // 21: daemon.NSGroupState - (*FullStatus)(nil), // 22: daemon.FullStatus - (*ListNetworksRequest)(nil), // 23: daemon.ListNetworksRequest - (*ListNetworksResponse)(nil), // 24: daemon.ListNetworksResponse - (*SelectNetworksRequest)(nil), // 25: daemon.SelectNetworksRequest - (*SelectNetworksResponse)(nil), // 26: daemon.SelectNetworksResponse - (*IPList)(nil), // 27: daemon.IPList - (*Network)(nil), // 28: daemon.Network - (*PortInfo)(nil), // 29: daemon.PortInfo - (*ForwardingRule)(nil), // 30: daemon.ForwardingRule - (*ForwardingRulesResponse)(nil), // 31: daemon.ForwardingRulesResponse - (*DebugBundleRequest)(nil), // 32: daemon.DebugBundleRequest - (*DebugBundleResponse)(nil), // 33: daemon.DebugBundleResponse - (*GetLogLevelRequest)(nil), // 34: daemon.GetLogLevelRequest - (*GetLogLevelResponse)(nil), // 35: daemon.GetLogLevelResponse - (*SetLogLevelRequest)(nil), // 36: daemon.SetLogLevelRequest - (*SetLogLevelResponse)(nil), // 37: daemon.SetLogLevelResponse - (*State)(nil), // 38: daemon.State - (*ListStatesRequest)(nil), // 39: daemon.ListStatesRequest - (*ListStatesResponse)(nil), // 40: daemon.ListStatesResponse - (*CleanStateRequest)(nil), // 41: daemon.CleanStateRequest - (*CleanStateResponse)(nil), // 42: daemon.CleanStateResponse - (*DeleteStateRequest)(nil), // 43: daemon.DeleteStateRequest - (*DeleteStateResponse)(nil), // 44: daemon.DeleteStateResponse - (*SetSyncResponsePersistenceRequest)(nil), // 45: daemon.SetSyncResponsePersistenceRequest - (*SetSyncResponsePersistenceResponse)(nil), // 46: daemon.SetSyncResponsePersistenceResponse - (*TCPFlags)(nil), // 47: daemon.TCPFlags - (*TracePacketRequest)(nil), // 48: daemon.TracePacketRequest - (*TraceStage)(nil), // 49: daemon.TraceStage - (*TracePacketResponse)(nil), // 50: daemon.TracePacketResponse - (*SubscribeRequest)(nil), // 51: daemon.SubscribeRequest - (*SystemEvent)(nil), // 52: daemon.SystemEvent - (*GetEventsRequest)(nil), // 53: daemon.GetEventsRequest - (*GetEventsResponse)(nil), // 54: daemon.GetEventsResponse - (*SwitchProfileRequest)(nil), // 55: daemon.SwitchProfileRequest - (*SwitchProfileResponse)(nil), // 56: daemon.SwitchProfileResponse - (*SetConfigRequest)(nil), // 57: daemon.SetConfigRequest - (*SetConfigResponse)(nil), // 58: daemon.SetConfigResponse - (*AddProfileRequest)(nil), // 59: daemon.AddProfileRequest - (*AddProfileResponse)(nil), // 60: daemon.AddProfileResponse - (*RemoveProfileRequest)(nil), // 61: daemon.RemoveProfileRequest - (*RemoveProfileResponse)(nil), // 62: daemon.RemoveProfileResponse - (*ListProfilesRequest)(nil), // 63: daemon.ListProfilesRequest - (*ListProfilesResponse)(nil), // 64: daemon.ListProfilesResponse - (*Profile)(nil), // 65: daemon.Profile - (*GetActiveProfileRequest)(nil), // 66: daemon.GetActiveProfileRequest - (*GetActiveProfileResponse)(nil), // 67: daemon.GetActiveProfileResponse - (*LogoutRequest)(nil), // 68: daemon.LogoutRequest - (*LogoutResponse)(nil), // 69: daemon.LogoutResponse - (*GetFeaturesRequest)(nil), // 70: daemon.GetFeaturesRequest - (*GetFeaturesResponse)(nil), // 71: daemon.GetFeaturesResponse - nil, // 72: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 73: daemon.PortInfo.Range - nil, // 74: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 75: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 76: google.protobuf.Timestamp + (*SSHSessionInfo)(nil), // 22: daemon.SSHSessionInfo + (*SSHServerState)(nil), // 23: daemon.SSHServerState + (*FullStatus)(nil), // 24: daemon.FullStatus + (*ListNetworksRequest)(nil), // 25: daemon.ListNetworksRequest + (*ListNetworksResponse)(nil), // 26: daemon.ListNetworksResponse + (*SelectNetworksRequest)(nil), // 27: daemon.SelectNetworksRequest + (*SelectNetworksResponse)(nil), // 28: daemon.SelectNetworksResponse + (*IPList)(nil), // 29: daemon.IPList + (*Network)(nil), // 30: daemon.Network + (*PortInfo)(nil), // 31: daemon.PortInfo + (*ForwardingRule)(nil), // 32: daemon.ForwardingRule + (*ForwardingRulesResponse)(nil), // 33: daemon.ForwardingRulesResponse + (*DebugBundleRequest)(nil), // 34: daemon.DebugBundleRequest + (*DebugBundleResponse)(nil), // 35: daemon.DebugBundleResponse + (*GetLogLevelRequest)(nil), // 36: daemon.GetLogLevelRequest + (*GetLogLevelResponse)(nil), // 37: daemon.GetLogLevelResponse + (*SetLogLevelRequest)(nil), // 38: daemon.SetLogLevelRequest + (*SetLogLevelResponse)(nil), // 39: daemon.SetLogLevelResponse + (*State)(nil), // 40: daemon.State + (*ListStatesRequest)(nil), // 41: daemon.ListStatesRequest + (*ListStatesResponse)(nil), // 42: daemon.ListStatesResponse + (*CleanStateRequest)(nil), // 43: daemon.CleanStateRequest + (*CleanStateResponse)(nil), // 44: daemon.CleanStateResponse + (*DeleteStateRequest)(nil), // 45: daemon.DeleteStateRequest + (*DeleteStateResponse)(nil), // 46: daemon.DeleteStateResponse + (*SetSyncResponsePersistenceRequest)(nil), // 47: daemon.SetSyncResponsePersistenceRequest + (*SetSyncResponsePersistenceResponse)(nil), // 48: daemon.SetSyncResponsePersistenceResponse + (*TCPFlags)(nil), // 49: daemon.TCPFlags + (*TracePacketRequest)(nil), // 50: daemon.TracePacketRequest + (*TraceStage)(nil), // 51: daemon.TraceStage + (*TracePacketResponse)(nil), // 52: daemon.TracePacketResponse + (*SubscribeRequest)(nil), // 53: daemon.SubscribeRequest + (*SystemEvent)(nil), // 54: daemon.SystemEvent + (*GetEventsRequest)(nil), // 55: daemon.GetEventsRequest + (*GetEventsResponse)(nil), // 56: daemon.GetEventsResponse + (*SwitchProfileRequest)(nil), // 57: daemon.SwitchProfileRequest + (*SwitchProfileResponse)(nil), // 58: daemon.SwitchProfileResponse + (*SetConfigRequest)(nil), // 59: daemon.SetConfigRequest + (*SetConfigResponse)(nil), // 60: daemon.SetConfigResponse + (*AddProfileRequest)(nil), // 61: daemon.AddProfileRequest + (*AddProfileResponse)(nil), // 62: daemon.AddProfileResponse + (*RemoveProfileRequest)(nil), // 63: daemon.RemoveProfileRequest + (*RemoveProfileResponse)(nil), // 64: daemon.RemoveProfileResponse + (*ListProfilesRequest)(nil), // 65: daemon.ListProfilesRequest + (*ListProfilesResponse)(nil), // 66: daemon.ListProfilesResponse + (*Profile)(nil), // 67: daemon.Profile + (*GetActiveProfileRequest)(nil), // 68: daemon.GetActiveProfileRequest + (*GetActiveProfileResponse)(nil), // 69: daemon.GetActiveProfileResponse + (*LogoutRequest)(nil), // 70: daemon.LogoutRequest + (*LogoutResponse)(nil), // 71: daemon.LogoutResponse + (*GetFeaturesRequest)(nil), // 72: daemon.GetFeaturesRequest + (*GetFeaturesResponse)(nil), // 73: daemon.GetFeaturesResponse + (*GetPeerSSHHostKeyRequest)(nil), // 74: daemon.GetPeerSSHHostKeyRequest + (*GetPeerSSHHostKeyResponse)(nil), // 75: daemon.GetPeerSSHHostKeyResponse + (*RequestJWTAuthRequest)(nil), // 76: daemon.RequestJWTAuthRequest + (*RequestJWTAuthResponse)(nil), // 77: daemon.RequestJWTAuthResponse + (*WaitJWTTokenRequest)(nil), // 78: daemon.WaitJWTTokenRequest + (*WaitJWTTokenResponse)(nil), // 79: daemon.WaitJWTTokenResponse + nil, // 80: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 81: daemon.PortInfo.Range + nil, // 82: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 83: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 84: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 75, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 76, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 76, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 75, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration - 19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState - 18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState - 17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState - 16, // 8: daemon.FullStatus.peers:type_name -> daemon.PeerState - 20, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState - 21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent - 28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 72, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 73, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range - 29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo - 29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo - 30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule - 0, // 18: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel - 0, // 19: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 38, // 20: daemon.ListStatesResponse.states:type_name -> daemon.State - 47, // 21: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags - 49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage - 1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity - 2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 76, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 74, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry - 52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 75, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile - 27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 6, // 32: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 8, // 33: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 10, // 34: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 12, // 35: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 14, // 36: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 23, // 37: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 25, // 38: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 25, // 39: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 3, // 40: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest - 32, // 41: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 34, // 42: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 36, // 43: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 39, // 44: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 41, // 45: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 43, // 46: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 45, // 47: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest - 48, // 48: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest - 51, // 49: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest - 53, // 50: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest - 55, // 51: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest - 57, // 52: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest - 59, // 53: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest - 61, // 54: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest - 63, // 55: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest - 66, // 56: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest - 68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest - 70, // 58: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest - 5, // 59: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 7, // 60: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 9, // 61: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 11, // 62: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 13, // 63: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 15, // 64: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 24, // 65: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 26, // 66: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 26, // 67: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 31, // 68: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 33, // 69: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 35, // 70: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 37, // 71: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 40, // 72: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 42, // 73: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 44, // 74: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 46, // 75: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 50, // 76: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 52, // 77: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 54, // 78: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 56, // 79: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 58, // 80: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 60, // 81: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 62, // 82: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 64, // 83: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 67, // 84: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 69, // 85: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 71, // 86: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 59, // [59:87] is the sub-list for method output_type - 31, // [31:59] is the sub-list for method input_type - 31, // [31:31] is the sub-list for extension type_name - 31, // [31:31] is the sub-list for extension extendee - 0, // [0:31] is the sub-list for field type_name + 83, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 24, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus + 84, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 84, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 83, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 22, // 5: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo + 19, // 6: daemon.FullStatus.managementState:type_name -> daemon.ManagementState + 18, // 7: daemon.FullStatus.signalState:type_name -> daemon.SignalState + 17, // 8: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState + 16, // 9: daemon.FullStatus.peers:type_name -> daemon.PeerState + 20, // 10: daemon.FullStatus.relays:type_name -> daemon.RelayState + 21, // 11: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 54, // 12: daemon.FullStatus.events:type_name -> daemon.SystemEvent + 23, // 13: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState + 30, // 14: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 80, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 81, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 31, // 17: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo + 31, // 18: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo + 32, // 19: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule + 0, // 20: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel + 0, // 21: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel + 40, // 22: daemon.ListStatesResponse.states:type_name -> daemon.State + 49, // 23: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags + 51, // 24: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage + 1, // 25: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity + 2, // 26: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category + 84, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 82, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 54, // 29: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent + 83, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 67, // 31: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile + 29, // 32: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 4, // 33: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 6, // 34: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 8, // 35: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 10, // 36: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 12, // 37: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 14, // 38: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 25, // 39: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 27, // 40: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 27, // 41: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 3, // 42: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest + 34, // 43: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 36, // 44: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 38, // 45: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 41, // 46: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 43, // 47: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 45, // 48: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 47, // 49: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest + 50, // 50: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 53, // 51: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest + 55, // 52: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest + 57, // 53: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest + 59, // 54: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest + 61, // 55: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest + 63, // 56: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest + 65, // 57: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest + 68, // 58: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest + 70, // 59: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest + 72, // 60: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest + 74, // 61: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest + 76, // 62: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest + 78, // 63: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest + 5, // 64: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 7, // 65: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 9, // 66: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 11, // 67: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 13, // 68: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 15, // 69: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 26, // 70: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 28, // 71: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 28, // 72: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 33, // 73: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 35, // 74: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 37, // 75: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 39, // 76: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 42, // 77: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 44, // 78: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 46, // 79: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 48, // 80: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 52, // 81: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 54, // 82: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 56, // 83: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 58, // 84: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 60, // 85: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 62, // 86: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 64, // 87: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 66, // 88: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 69, // 89: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 71, // 90: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 73, // 91: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 75, // 92: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 77, // 93: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 79, // 94: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 64, // [64:95] is the sub-list for method output_type + 33, // [33:64] is the sub-list for method input_type + 33, // [33:33] is the sub-list for extension type_name + 33, // [33:33] is the sub-list for extension extendee + 0, // [0:33] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -5254,22 +6013,23 @@ func file_daemon_proto_init() { file_daemon_proto_msgTypes[1].OneofWrappers = []any{} file_daemon_proto_msgTypes[5].OneofWrappers = []any{} file_daemon_proto_msgTypes[7].OneofWrappers = []any{} - file_daemon_proto_msgTypes[26].OneofWrappers = []any{ + file_daemon_proto_msgTypes[28].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } - file_daemon_proto_msgTypes[45].OneofWrappers = []any{} - file_daemon_proto_msgTypes[46].OneofWrappers = []any{} - file_daemon_proto_msgTypes[52].OneofWrappers = []any{} + file_daemon_proto_msgTypes[47].OneofWrappers = []any{} + file_daemon_proto_msgTypes[48].OneofWrappers = []any{} file_daemon_proto_msgTypes[54].OneofWrappers = []any{} - file_daemon_proto_msgTypes[65].OneofWrappers = []any{} + file_daemon_proto_msgTypes[56].OneofWrappers = []any{} + file_daemon_proto_msgTypes[67].OneofWrappers = []any{} + file_daemon_proto_msgTypes[73].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), NumEnums: 3, - NumMessages: 72, + NumMessages: 80, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 8d1080051..bf8553706 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -84,6 +84,15 @@ service DaemonService { rpc Logout(LogoutRequest) returns (LogoutResponse) {} rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {} + + // GetPeerSSHHostKey retrieves SSH host key for a specific peer + rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {} + + // RequestJWTAuth initiates JWT authentication flow for SSH + rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {} + + // WaitJWTToken waits for JWT authentication completion + rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {} } @@ -161,6 +170,13 @@ message LoginRequest { // hint is used to pre-fill the email/username field during SSO authentication optional string hint = 33; + + optional bool enableSSHRoot = 34; + optional bool enableSSHSFTP = 35; + optional bool enableSSHLocalPortForwarding = 36; + optional bool enableSSHRemotePortForwarding = 37; + optional bool disableSSHAuth = 38; + optional int32 sshJWTCacheTTL = 39; } message LoginResponse { @@ -188,9 +204,9 @@ message UpResponse {} message StatusRequest{ bool getFullPeerStatus = 1; - bool shouldRunProbes = 2; + bool shouldRunProbes = 2; // the UI do not using this yet, but CLIs could use it to wait until the status is ready - optional bool waitForReady = 3; + optional bool waitForReady = 3; } message StatusResponse{ @@ -255,6 +271,18 @@ message GetConfigResponse { bool disable_server_routes = 19; bool block_lan_access = 20; + + bool enableSSHRoot = 21; + + bool enableSSHSFTP = 24; + + bool enableSSHLocalPortForwarding = 22; + + bool enableSSHRemotePortForwarding = 23; + + bool disableSSHAuth = 25; + + int32 sshJWTCacheTTL = 26; } // PeerState contains the latest state of a peer @@ -276,6 +304,7 @@ message PeerState { repeated string networks = 16; google.protobuf.Duration latency = 17; string relayAddress = 18; + bytes sshHostKey = 19; } // LocalPeerState contains the latest state of the local peer @@ -317,6 +346,20 @@ message NSGroupState { string error = 4; } +// SSHSessionInfo contains information about an active SSH session +message SSHSessionInfo { + string username = 1; + string remoteAddress = 2; + string command = 3; + string jwtUsername = 4; +} + +// SSHServerState contains the latest state of the SSH server +message SSHServerState { + bool enabled = 1; + repeated SSHSessionInfo sessions = 2; +} + // FullStatus contains the full state held by the Status instance message FullStatus { ManagementState managementState = 1; @@ -330,6 +373,7 @@ message FullStatus { repeated SystemEvent events = 7; bool lazyConnectionEnabled = 9; + SSHServerState sshServerState = 10; } // Networks @@ -543,56 +587,63 @@ message SwitchProfileRequest { message SwitchProfileResponse {} message SetConfigRequest { - string username = 1; - string profileName = 2; - // managementUrl to authenticate. - string managementUrl = 3; + string username = 1; + string profileName = 2; + // managementUrl to authenticate. + string managementUrl = 3; - // adminUrl to manage keys. - string adminURL = 4; + // adminUrl to manage keys. + string adminURL = 4; - optional bool rosenpassEnabled = 5; + optional bool rosenpassEnabled = 5; - optional string interfaceName = 6; + optional string interfaceName = 6; - optional int64 wireguardPort = 7; + optional int64 wireguardPort = 7; - optional string optionalPreSharedKey = 8; + optional string optionalPreSharedKey = 8; - optional bool disableAutoConnect = 9; + optional bool disableAutoConnect = 9; - optional bool serverSSHAllowed = 10; + optional bool serverSSHAllowed = 10; - optional bool rosenpassPermissive = 11; + optional bool rosenpassPermissive = 11; - optional bool networkMonitor = 12; + optional bool networkMonitor = 12; - optional bool disable_client_routes = 13; - optional bool disable_server_routes = 14; - optional bool disable_dns = 15; - optional bool disable_firewall = 16; - optional bool block_lan_access = 17; + optional bool disable_client_routes = 13; + optional bool disable_server_routes = 14; + optional bool disable_dns = 15; + optional bool disable_firewall = 16; + optional bool block_lan_access = 17; - optional bool disable_notifications = 18; + optional bool disable_notifications = 18; - optional bool lazyConnectionEnabled = 19; + optional bool lazyConnectionEnabled = 19; - optional bool block_inbound = 20; + optional bool block_inbound = 20; - repeated string natExternalIPs = 21; - bool cleanNATExternalIPs = 22; + repeated string natExternalIPs = 21; + bool cleanNATExternalIPs = 22; - bytes customDNSAddress = 23; + bytes customDNSAddress = 23; - repeated string extraIFaceBlacklist = 24; + repeated string extraIFaceBlacklist = 24; - repeated string dns_labels = 25; - // cleanDNSLabels clean map list of DNS labels. - bool cleanDNSLabels = 26; + repeated string dns_labels = 25; + // cleanDNSLabels clean map list of DNS labels. + bool cleanDNSLabels = 26; - optional google.protobuf.Duration dnsRouteInterval = 27; + optional google.protobuf.Duration dnsRouteInterval = 27; - optional int64 mtu = 28; + optional int64 mtu = 28; + + optional bool enableSSHRoot = 29; + optional bool enableSSHSFTP = 30; + optional bool enableSSHLocalPortForwarding = 31; + optional bool enableSSHRemotePortForwarding = 32; + optional bool disableSSHAuth = 33; + optional int32 sshJWTCacheTTL = 34; } message SetConfigResponse{} @@ -644,3 +695,63 @@ message GetFeaturesResponse{ bool disable_profiles = 1; bool disable_update_settings = 2; } + +// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer +message GetPeerSSHHostKeyRequest { + // peer IP address or FQDN to get SSH host key for + string peerAddress = 1; +} + +// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer +message GetPeerSSHHostKeyResponse { + // SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname") + bytes sshHostKey = 1; + // peer IP address + string peerIP = 2; + // peer FQDN + string peerFQDN = 3; + // indicates if the SSH host key was found + bool found = 4; +} + +// RequestJWTAuthRequest for initiating JWT authentication flow +message RequestJWTAuthRequest { + // hint for OIDC login_hint parameter (typically email address) + optional string hint = 1; +} + +// RequestJWTAuthResponse contains authentication flow information +message RequestJWTAuthResponse { + // verification URI for user authentication + string verificationURI = 1; + // complete verification URI (with embedded user code) + string verificationURIComplete = 2; + // user code to enter on verification URI + string userCode = 3; + // device code for polling + string deviceCode = 4; + // expiration time in seconds + int64 expiresIn = 5; + // if a cached token is available, it will be returned here + string cachedToken = 6; + // maximum age of JWT tokens in seconds (from management server) + int64 maxTokenAge = 7; +} + +// WaitJWTTokenRequest for waiting for authentication completion +message WaitJWTTokenRequest { + // device code from RequestJWTAuthResponse + string deviceCode = 1; + // user code for verification + string userCode = 2; +} + +// WaitJWTTokenResponse contains the JWT token after authentication +message WaitJWTTokenResponse { + // JWT token (access token or ID token) + string token = 1; + // token type (e.g., "Bearer") + string tokenType = 2; + // expiration time in seconds + int64 expiresIn = 3; +} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index bf7c9c7b3..b2bf716b2 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -64,6 +64,12 @@ type DaemonServiceClient interface { // Logout disconnects from the network and deletes the peer from the management server Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error) + // GetPeerSSHHostKey retrieves SSH host key for a specific peer + GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) + // RequestJWTAuth initiates JWT authentication flow for SSH + RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) + // WaitJWTToken waits for JWT authentication completion + WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) } type daemonServiceClient struct { @@ -349,6 +355,33 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe return out, nil } +func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) { + out := new(GetPeerSSHHostKeyResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) { + out := new(RequestJWTAuthResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) { + out := new(WaitJWTTokenResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -399,6 +432,12 @@ type DaemonServiceServer interface { // Logout disconnects from the network and deletes the peer from the management server Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) + // GetPeerSSHHostKey retrieves SSH host key for a specific peer + GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) + // RequestJWTAuth initiates JWT authentication flow for SSH + RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) + // WaitJWTToken waits for JWT authentication completion + WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -490,6 +529,15 @@ func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented") } +func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented") +} +func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented") +} +func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -1010,6 +1058,60 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetPeerSSHHostKeyRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/GetPeerSSHHostKey", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, req.(*GetPeerSSHHostKeyRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RequestJWTAuthRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).RequestJWTAuth(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/RequestJWTAuth", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(WaitJWTTokenRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).WaitJWTToken(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/WaitJWTToken", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -1125,6 +1227,18 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetFeatures", Handler: _DaemonService_GetFeatures_Handler, }, + { + MethodName: "GetPeerSSHHostKey", + Handler: _DaemonService_GetPeerSSHHostKey_Handler, + }, + { + MethodName: "RequestJWTAuth", + Handler: _DaemonService_RequestJWTAuth_Handler, + }, + { + MethodName: "WaitJWTToken", + Handler: _DaemonService_WaitJWTToken_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/client/server/jwt_cache.go b/client/server/jwt_cache.go new file mode 100644 index 000000000..21e170517 --- /dev/null +++ b/client/server/jwt_cache.go @@ -0,0 +1,79 @@ +package server + +import ( + "sync" + "time" + + "github.com/awnumar/memguard" + log "github.com/sirupsen/logrus" +) + +type jwtCache struct { + mu sync.RWMutex + enclave *memguard.Enclave + expiresAt time.Time + timer *time.Timer + maxTokenSize int +} + +func newJWTCache() *jwtCache { + return &jwtCache{ + maxTokenSize: 8192, + } +} + +func (c *jwtCache) store(token string, maxAge time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + + c.cleanup() + + if c.timer != nil { + c.timer.Stop() + } + + tokenBytes := []byte(token) + c.enclave = memguard.NewEnclave(tokenBytes) + + c.expiresAt = time.Now().Add(maxAge) + + var timer *time.Timer + timer = time.AfterFunc(maxAge, func() { + c.mu.Lock() + defer c.mu.Unlock() + if c.timer != timer { + return + } + c.cleanup() + c.timer = nil + log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge) + }) + c.timer = timer +} + +func (c *jwtCache) get() (string, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.enclave == nil || time.Now().After(c.expiresAt) { + return "", false + } + + buffer, err := c.enclave.Open() + if err != nil { + log.Debugf("Failed to open JWT token enclave: %v", err) + return "", false + } + defer buffer.Destroy() + + token := string(buffer.Bytes()) + return token, true +} + +// cleanup destroys the secure enclave, must be called with lock held +func (c *jwtCache) cleanup() { + if c.enclave != nil { + c.enclave = nil + } + c.expiresAt = time.Time{} +} diff --git a/client/server/network.go b/client/server/network.go index 18b16795d..bb1cce56c 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -11,8 +11,8 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/proto" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) type selectRoute struct { diff --git a/client/server/server.go b/client/server/server.go index 6699cdadc..a930e8a02 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -46,6 +46,9 @@ const ( defaultMaxRetryTime = 14 * 24 * time.Hour defaultRetryMultiplier = 1.7 + // JWT token cache TTL for the client daemon (disabled by default) + defaultJWTCacheTTL = 0 + errRestoreResidualState = "failed to restore residual state: %v" errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled" errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled" @@ -81,6 +84,8 @@ type Server struct { profileManager *profilemanager.ServiceManager profilesDisabled bool updateSettingsDisabled bool + + jwtCache *jwtCache } type oauthAuthFlow struct { @@ -100,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable profileManager: profilemanager.NewServiceManager(configFile), profilesDisabled: profilesDisabled, updateSettingsDisabled: updateSettingsDisabled, + jwtCache: newJWTCache(), } } @@ -373,6 +379,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.DisableNotifications = msg.DisableNotifications config.LazyConnectionEnabled = msg.LazyConnectionEnabled config.BlockInbound = msg.BlockInbound + config.EnableSSHRoot = msg.EnableSSHRoot + config.EnableSSHSFTP = msg.EnableSSHSFTP + config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding + config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding + if msg.DisableSSHAuth != nil { + config.DisableSSHAuth = msg.DisableSSHAuth + } + if msg.SshJWTCacheTTL != nil { + ttl := int(*msg.SshJWTCacheTTL) + config.SSHJWTCacheTTL = &ttl + } if msg.Mtu != nil { mtu := uint16(*msg.Mtu) @@ -493,7 +510,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro return nil, err } - if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) { + if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) { if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) { log.Debugf("using previous oauth flow info") return &proto.LoginResponse{ @@ -510,7 +527,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro } } - authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO()) + authInfo, err := oAuthFlow.RequestAuthInfo(ctx) if err != nil { log.Errorf("getting a request OAuth flow failed: %v", err) return nil, err @@ -1065,12 +1082,235 @@ func (s *Server) Status( fullStatus := s.statusRecorder.GetFullStatus() pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus.Events = s.statusRecorder.GetEventHistory() + + pbFullStatus.SshServerState = s.getSSHServerState() + statusResponse.FullStatus = pbFullStatus } return &statusResponse, nil } +// getSSHServerState retrieves the current SSH server state including enabled status and active sessions +func (s *Server) getSSHServerState() *proto.SSHServerState { + s.mutex.Lock() + connectClient := s.connectClient + s.mutex.Unlock() + + if connectClient == nil { + return nil + } + + engine := connectClient.Engine() + if engine == nil { + return nil + } + + enabled, sessions := engine.GetSSHServerStatus() + sshServerState := &proto.SSHServerState{ + Enabled: enabled, + } + + for _, session := range sessions { + sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{ + Username: session.Username, + RemoteAddress: session.RemoteAddress, + Command: session.Command, + JwtUsername: session.JWTUsername, + }) + } + + return sshServerState +} + +// GetPeerSSHHostKey retrieves SSH host key for a specific peer +func (s *Server) GetPeerSSHHostKey( + ctx context.Context, + req *proto.GetPeerSSHHostKeyRequest, +) (*proto.GetPeerSSHHostKeyResponse, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + s.mutex.Lock() + connectClient := s.connectClient + statusRecorder := s.statusRecorder + s.mutex.Unlock() + + if connectClient == nil { + return nil, errors.New("client not initialized") + } + + engine := connectClient.Engine() + if engine == nil { + return nil, errors.New("engine not started") + } + + peerAddress := req.GetPeerAddress() + hostKey, found := engine.GetPeerSSHKey(peerAddress) + + response := &proto.GetPeerSSHHostKeyResponse{ + Found: found, + } + + if !found { + return response, nil + } + + response.SshHostKey = hostKey + + if statusRecorder == nil { + return response, nil + } + + fullStatus := statusRecorder.GetFullStatus() + for _, peerState := range fullStatus.Peers { + if peerState.IP == peerAddress || peerState.FQDN == peerAddress { + response.PeerIP = peerState.IP + response.PeerFQDN = peerState.FQDN + break + } + } + + return response, nil +} + +// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled) +func (s *Server) getJWTCacheTTL() time.Duration { + s.mutex.Lock() + config := s.config + s.mutex.Unlock() + + if config == nil || config.SSHJWTCacheTTL == nil { + return defaultJWTCacheTTL + } + + seconds := *config.SSHJWTCacheTTL + if seconds == 0 { + log.Debug("SSH JWT cache disabled (configured to 0)") + return 0 + } + + ttl := time.Duration(seconds) * time.Second + log.Debugf("SSH JWT cache TTL set to %v from config", ttl) + return ttl +} + +// RequestJWTAuth initiates JWT authentication flow for SSH +func (s *Server) RequestJWTAuth( + ctx context.Context, + msg *proto.RequestJWTAuthRequest, +) (*proto.RequestJWTAuthResponse, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + s.mutex.Lock() + config := s.config + s.mutex.Unlock() + + if config == nil { + return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured") + } + + jwtCacheTTL := s.getJWTCacheTTL() + if jwtCacheTTL > 0 { + if cachedToken, found := s.jwtCache.get(); found { + log.Debugf("JWT token found in cache, returning cached token for SSH authentication") + + return &proto.RequestJWTAuthResponse{ + CachedToken: cachedToken, + MaxTokenAge: int64(jwtCacheTTL.Seconds()), + }, nil + } + } + + hint := "" + if msg.Hint != nil { + hint = *msg.Hint + } + + if hint == "" { + hint = profilemanager.GetLoginHint() + } + + isDesktop := isUnixRunningDesktop() + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, hint) + if err != nil { + return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err) + } + + authInfo, err := oAuthFlow.RequestAuthInfo(ctx) + if err != nil { + return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err) + } + + s.mutex.Lock() + s.oauthAuthFlow.flow = oAuthFlow + s.oauthAuthFlow.info = authInfo + s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second) + s.mutex.Unlock() + + return &proto.RequestJWTAuthResponse{ + VerificationURI: authInfo.VerificationURI, + VerificationURIComplete: authInfo.VerificationURIComplete, + UserCode: authInfo.UserCode, + DeviceCode: authInfo.DeviceCode, + ExpiresIn: int64(authInfo.ExpiresIn), + MaxTokenAge: int64(jwtCacheTTL.Seconds()), + }, nil +} + +// WaitJWTToken waits for JWT authentication completion +func (s *Server) WaitJWTToken( + ctx context.Context, + req *proto.WaitJWTTokenRequest, +) (*proto.WaitJWTTokenResponse, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + s.mutex.Lock() + oAuthFlow := s.oauthAuthFlow.flow + authInfo := s.oauthAuthFlow.info + s.mutex.Unlock() + + if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode { + return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow") + } + + tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo) + if err != nil { + return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err) + } + + token := tokenInfo.GetTokenToUse() + + jwtCacheTTL := s.getJWTCacheTTL() + if jwtCacheTTL > 0 { + s.jwtCache.store(token, jwtCacheTTL) + log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL) + } else { + log.Debug("JWT caching disabled, not storing token") + } + + s.mutex.Lock() + s.oauthAuthFlow = oauthAuthFlow{} + s.mutex.Unlock() + return &proto.WaitJWTTokenResponse{ + Token: tokenInfo.GetTokenToUse(), + TokenType: tokenInfo.TokenType, + ExpiresIn: int64(tokenInfo.ExpiresIn), + }, nil +} + +func isUnixRunningDesktop() bool { + if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { + return false + } + return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" +} + func (s *Server) runProbes(waitForProbeResult bool) { if s.connectClient == nil { return @@ -1136,25 +1376,61 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p disableServerRoutes := cfg.DisableServerRoutes blockLANAccess := cfg.BlockLANAccess + enableSSHRoot := false + if cfg.EnableSSHRoot != nil { + enableSSHRoot = *cfg.EnableSSHRoot + } + + enableSSHSFTP := false + if cfg.EnableSSHSFTP != nil { + enableSSHSFTP = *cfg.EnableSSHSFTP + } + + enableSSHLocalPortForwarding := false + if cfg.EnableSSHLocalPortForwarding != nil { + enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding + } + + enableSSHRemotePortForwarding := false + if cfg.EnableSSHRemotePortForwarding != nil { + enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding + } + + disableSSHAuth := false + if cfg.DisableSSHAuth != nil { + disableSSHAuth = *cfg.DisableSSHAuth + } + + sshJWTCacheTTL := int32(0) + if cfg.SSHJWTCacheTTL != nil { + sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL) + } + return &proto.GetConfigResponse{ - ManagementUrl: managementURL.String(), - PreSharedKey: preSharedKey, - AdminURL: adminURL.String(), - InterfaceName: cfg.WgIface, - WireguardPort: int64(cfg.WgPort), - Mtu: int64(cfg.MTU), - DisableAutoConnect: cfg.DisableAutoConnect, - ServerSSHAllowed: *cfg.ServerSSHAllowed, - RosenpassEnabled: cfg.RosenpassEnabled, - RosenpassPermissive: cfg.RosenpassPermissive, - LazyConnectionEnabled: cfg.LazyConnectionEnabled, - BlockInbound: cfg.BlockInbound, - DisableNotifications: disableNotifications, - NetworkMonitor: networkMonitor, - DisableDns: disableDNS, - DisableClientRoutes: disableClientRoutes, - DisableServerRoutes: disableServerRoutes, - BlockLanAccess: blockLANAccess, + ManagementUrl: managementURL.String(), + PreSharedKey: preSharedKey, + AdminURL: adminURL.String(), + InterfaceName: cfg.WgIface, + WireguardPort: int64(cfg.WgPort), + Mtu: int64(cfg.MTU), + DisableAutoConnect: cfg.DisableAutoConnect, + ServerSSHAllowed: *cfg.ServerSSHAllowed, + RosenpassEnabled: cfg.RosenpassEnabled, + RosenpassPermissive: cfg.RosenpassPermissive, + LazyConnectionEnabled: cfg.LazyConnectionEnabled, + BlockInbound: cfg.BlockInbound, + DisableNotifications: disableNotifications, + NetworkMonitor: networkMonitor, + DisableDns: disableDNS, + DisableClientRoutes: disableClientRoutes, + DisableServerRoutes: disableServerRoutes, + BlockLanAccess: blockLANAccess, + EnableSSHRoot: enableSSHRoot, + EnableSSHSFTP: enableSSHSFTP, + EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding, + EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding, + DisableSSHAuth: disableSSHAuth, + SshJWTCacheTTL: sshJWTCacheTTL, }, nil } @@ -1385,6 +1661,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { RosenpassEnabled: peerState.RosenpassEnabled, Networks: maps.Keys(peerState.GetRoutes()), Latency: durationpb.New(peerState.Latency), + SshHostKey: peerState.SSHHostKey, } pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) } diff --git a/client/server/server_test.go b/client/server/server_test.go index ae5f759ee..96d4c0af0 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -14,6 +14,7 @@ import ( "go.opentelemetry.io/otel" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -316,7 +317,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) - accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { return nil, "", err } diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go index 1260bcc78..8e360175d 100644 --- a/client/server/setconfig_test.go +++ b/client/server/setconfig_test.go @@ -72,6 +72,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { lazyConnectionEnabled := true blockInbound := true mtu := int64(1280) + sshJWTCacheTTL := int32(300) req := &proto.SetConfigRequest{ ProfileName: profName, @@ -102,6 +103,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { CleanDNSLabels: false, DnsRouteInterval: durationpb.New(2 * time.Minute), Mtu: &mtu, + SshJWTCacheTTL: &sshJWTCacheTTL, } _, err = s.SetConfig(ctx, req) @@ -146,6 +148,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList()) require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval) require.Equal(t, uint16(mtu), cfg.MTU) + require.NotNil(t, cfg.SSHJWTCacheTTL) + require.Equal(t, int(sshJWTCacheTTL), *cfg.SSHJWTCacheTTL) verifyAllFieldsCovered(t, req) } @@ -167,30 +171,36 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { } expectedFields := map[string]bool{ - "ManagementUrl": true, - "AdminURL": true, - "RosenpassEnabled": true, - "RosenpassPermissive": true, - "ServerSSHAllowed": true, - "InterfaceName": true, - "WireguardPort": true, - "OptionalPreSharedKey": true, - "DisableAutoConnect": true, - "NetworkMonitor": true, - "DisableClientRoutes": true, - "DisableServerRoutes": true, - "DisableDns": true, - "DisableFirewall": true, - "BlockLanAccess": true, - "DisableNotifications": true, - "LazyConnectionEnabled": true, - "BlockInbound": true, - "NatExternalIPs": true, - "CustomDNSAddress": true, - "ExtraIFaceBlacklist": true, - "DnsLabels": true, - "DnsRouteInterval": true, - "Mtu": true, + "ManagementUrl": true, + "AdminURL": true, + "RosenpassEnabled": true, + "RosenpassPermissive": true, + "ServerSSHAllowed": true, + "InterfaceName": true, + "WireguardPort": true, + "OptionalPreSharedKey": true, + "DisableAutoConnect": true, + "NetworkMonitor": true, + "DisableClientRoutes": true, + "DisableServerRoutes": true, + "DisableDns": true, + "DisableFirewall": true, + "BlockLanAccess": true, + "DisableNotifications": true, + "LazyConnectionEnabled": true, + "BlockInbound": true, + "NatExternalIPs": true, + "CustomDNSAddress": true, + "ExtraIFaceBlacklist": true, + "DnsLabels": true, + "DnsRouteInterval": true, + "Mtu": true, + "EnableSSHRoot": true, + "EnableSSHSFTP": true, + "EnableSSHLocalPortForwarding": true, + "EnableSSHRemotePortForwarding": true, + "DisableSSHAuth": true, + "SshJWTCacheTTL": true, } val := reflect.ValueOf(req).Elem() @@ -221,29 +231,35 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) { // Map of CLI flag names to their corresponding SetConfigRequest field names. // This map must be updated when adding new config-related CLI flags. flagToField := map[string]string{ - "management-url": "ManagementUrl", - "admin-url": "AdminURL", - "enable-rosenpass": "RosenpassEnabled", - "rosenpass-permissive": "RosenpassPermissive", - "allow-server-ssh": "ServerSSHAllowed", - "interface-name": "InterfaceName", - "wireguard-port": "WireguardPort", - "preshared-key": "OptionalPreSharedKey", - "disable-auto-connect": "DisableAutoConnect", - "network-monitor": "NetworkMonitor", - "disable-client-routes": "DisableClientRoutes", - "disable-server-routes": "DisableServerRoutes", - "disable-dns": "DisableDns", - "disable-firewall": "DisableFirewall", - "block-lan-access": "BlockLanAccess", - "block-inbound": "BlockInbound", - "enable-lazy-connection": "LazyConnectionEnabled", - "external-ip-map": "NatExternalIPs", - "dns-resolver-address": "CustomDNSAddress", - "extra-iface-blacklist": "ExtraIFaceBlacklist", - "extra-dns-labels": "DnsLabels", - "dns-router-interval": "DnsRouteInterval", - "mtu": "Mtu", + "management-url": "ManagementUrl", + "admin-url": "AdminURL", + "enable-rosenpass": "RosenpassEnabled", + "rosenpass-permissive": "RosenpassPermissive", + "allow-server-ssh": "ServerSSHAllowed", + "interface-name": "InterfaceName", + "wireguard-port": "WireguardPort", + "preshared-key": "OptionalPreSharedKey", + "disable-auto-connect": "DisableAutoConnect", + "network-monitor": "NetworkMonitor", + "disable-client-routes": "DisableClientRoutes", + "disable-server-routes": "DisableServerRoutes", + "disable-dns": "DisableDns", + "disable-firewall": "DisableFirewall", + "block-lan-access": "BlockLanAccess", + "block-inbound": "BlockInbound", + "enable-lazy-connection": "LazyConnectionEnabled", + "external-ip-map": "NatExternalIPs", + "dns-resolver-address": "CustomDNSAddress", + "extra-iface-blacklist": "ExtraIFaceBlacklist", + "extra-dns-labels": "DnsLabels", + "dns-router-interval": "DnsRouteInterval", + "mtu": "Mtu", + "enable-ssh-root": "EnableSSHRoot", + "enable-ssh-sftp": "EnableSSHSFTP", + "enable-ssh-local-port-forwarding": "EnableSSHLocalPortForwarding", + "enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForwarding", + "disable-ssh-auth": "DisableSSHAuth", + "ssh-jwt-cache-ttl": "SshJWTCacheTTL", } // SetConfigRequest fields that don't have CLI flags (settable only via UI or other means). diff --git a/client/server/state_generic.go b/client/server/state_generic.go index e6c7bdd44..980ba0cda 100644 --- a/client/server/state_generic.go +++ b/client/server/state_generic.go @@ -6,9 +6,11 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/ssh/config" ) func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) + mgr.RegisterState(&config.ShutdownState{}) } diff --git a/client/server/state_linux.go b/client/server/state_linux.go index 087628907..019477d8e 100644 --- a/client/server/state_linux.go +++ b/client/server/state_linux.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/ssh/config" ) func registerStates(mgr *statemanager.Manager) { @@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&systemops.ShutdownState{}) mgr.RegisterState(&nftables.ShutdownState{}) mgr.RegisterState(&iptables.ShutdownState{}) + mgr.RegisterState(&config.ShutdownState{}) } diff --git a/client/ssh/client.go b/client/ssh/client.go deleted file mode 100644 index afba347f8..000000000 --- a/client/ssh/client.go +++ /dev/null @@ -1,118 +0,0 @@ -//go:build !js - -package ssh - -import ( - "fmt" - "net" - "os" - "time" - - "golang.org/x/crypto/ssh" - "golang.org/x/term" -) - -// Client wraps crypto/ssh Client to simplify usage -type Client struct { - client *ssh.Client -} - -// Close closes the wrapped SSH Client -func (c *Client) Close() error { - return c.client.Close() -} - -// OpenTerminal starts an interactive terminal session with the remote SSH server -func (c *Client) OpenTerminal() error { - session, err := c.client.NewSession() - if err != nil { - return fmt.Errorf("failed to open new session: %v", err) - } - defer func() { - err := session.Close() - if err != nil { - return - } - }() - - fd := int(os.Stdout.Fd()) - state, err := term.MakeRaw(fd) - if err != nil { - return fmt.Errorf("failed to run raw terminal: %s", err) - } - defer func() { - err := term.Restore(fd, state) - if err != nil { - return - } - }() - - w, h, err := term.GetSize(fd) - if err != nil { - return fmt.Errorf("terminal get size: %s", err) - } - - modes := ssh.TerminalModes{ - ssh.ECHO: 1, - ssh.TTY_OP_ISPEED: 14400, - ssh.TTY_OP_OSPEED: 14400, - } - - terminal := os.Getenv("TERM") - if terminal == "" { - terminal = "xterm-256color" - } - if err := session.RequestPty(terminal, h, w, modes); err != nil { - return fmt.Errorf("failed requesting pty session with xterm: %s", err) - } - - session.Stdout = os.Stdout - session.Stderr = os.Stderr - session.Stdin = os.Stdin - - if err := session.Shell(); err != nil { - return fmt.Errorf("failed to start login shell on the remote host: %s", err) - } - - if err := session.Wait(); err != nil { - if e, ok := err.(*ssh.ExitError); ok { - if e.ExitStatus() == 130 { - return nil - } - } - return fmt.Errorf("failed running SSH session: %s", err) - } - - return nil -} - -// DialWithKey connects to the remote SSH server with a provided private key file (PEM). -func DialWithKey(addr, user string, privateKey []byte) (*Client, error) { - - signer, err := ssh.ParsePrivateKey(privateKey) - if err != nil { - return nil, err - } - - config := &ssh.ClientConfig{ - User: user, - Timeout: 5 * time.Second, - Auth: []ssh.AuthMethod{ - ssh.PublicKeys(signer), - }, - HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }), - } - - return Dial("tcp", addr, config) -} - -// Dial connects to the remote SSH server. -func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) { - client, err := ssh.Dial(network, addr, config) - if err != nil { - return nil, err - } - return &Client{ - client: client, - }, nil -} diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go new file mode 100644 index 000000000..882056374 --- /dev/null +++ b/client/ssh/client/client.go @@ -0,0 +1,699 @@ +package client + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" + "golang.org/x/term" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/detection" +) + +const ( + // DefaultDaemonAddr is the default address for the NetBird daemon + DefaultDaemonAddr = "unix:///var/run/netbird.sock" + // DefaultDaemonAddrWindows is the default address for the NetBird daemon on Windows + DefaultDaemonAddrWindows = "tcp://127.0.0.1:41731" +) + +// Client wraps crypto/ssh Client for simplified SSH operations +type Client struct { + client *ssh.Client + terminalState *term.State + terminalFd int + + windowsStdoutMode uint32 // nolint:unused + windowsStdinMode uint32 // nolint:unused +} + +func (c *Client) Close() error { + return c.client.Close() +} + +func (c *Client) OpenTerminal(ctx context.Context) error { + session, err := c.client.NewSession() + if err != nil { + return fmt.Errorf("new session: %w", err) + } + defer func() { + if err := session.Close(); err != nil { + log.Debugf("session close error: %v", err) + } + }() + + if err := c.setupTerminalMode(ctx, session); err != nil { + return err + } + + c.setupSessionIO(session) + + if err := session.Shell(); err != nil { + return fmt.Errorf("start shell: %w", err) + } + + return c.waitForSession(ctx, session) +} + +// setupSessionIO connects session streams to local terminal +func (c *Client) setupSessionIO(session *ssh.Session) { + session.Stdout = os.Stdout + session.Stderr = os.Stderr + session.Stdin = os.Stdin +} + +// waitForSession waits for the session to complete with context cancellation +func (c *Client) waitForSession(ctx context.Context, session *ssh.Session) error { + done := make(chan error, 1) + go func() { + done <- session.Wait() + }() + + defer c.restoreTerminal() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + return c.handleSessionError(err) + } +} + +// handleSessionError processes session termination errors +func (c *Client) handleSessionError(err error) error { + if err == nil { + return nil + } + + var e *ssh.ExitError + var em *ssh.ExitMissingError + if !errors.As(err, &e) && !errors.As(err, &em) { + return fmt.Errorf("session wait: %w", err) + } + + return nil +} + +// restoreTerminal restores the terminal to its original state +func (c *Client) restoreTerminal() { + if c.terminalState != nil { + _ = term.Restore(c.terminalFd, c.terminalState) + c.terminalState = nil + c.terminalFd = 0 + } + + if err := c.restoreWindowsConsoleState(); err != nil { + log.Debugf("restore Windows console state: %v", err) + } +} + +// ExecuteCommand executes a command on the remote host and returns the output +func (c *Client) ExecuteCommand(ctx context.Context, command string) ([]byte, error) { + session, cleanup, err := c.createSession(ctx) + if err != nil { + return nil, err + } + defer cleanup() + + output, err := session.CombinedOutput(command) + if err != nil { + var e *ssh.ExitError + var em *ssh.ExitMissingError + if !errors.As(err, &e) && !errors.As(err, &em) { + return output, fmt.Errorf("execute command: %w", err) + } + } + + return output, nil +} + +// ExecuteCommandWithIO executes a command with interactive I/O connected to local terminal +func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error { + session, cleanup, err := c.createSession(ctx) + if err != nil { + return fmt.Errorf("create session: %w", err) + } + defer cleanup() + + c.setupSessionIO(session) + + if err := session.Start(command); err != nil { + return fmt.Errorf("start command: %w", err) + } + + done := make(chan error, 1) + go func() { + done <- session.Wait() + }() + + select { + case <-ctx.Done(): + _ = session.Signal(ssh.SIGTERM) + select { + case <-done: + return ctx.Err() + case <-time.After(100 * time.Millisecond): + return ctx.Err() + } + case err := <-done: + return c.handleCommandError(err) + } +} + +// ExecuteCommandWithPTY executes a command with a pseudo-terminal for interactive sessions +func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error { + session, cleanup, err := c.createSession(ctx) + if err != nil { + return fmt.Errorf("create session: %w", err) + } + defer cleanup() + + if err := c.setupTerminalMode(ctx, session); err != nil { + return fmt.Errorf("setup terminal mode: %w", err) + } + + c.setupSessionIO(session) + + if err := session.Start(command); err != nil { + return fmt.Errorf("start command: %w", err) + } + + defer c.restoreTerminal() + + done := make(chan error, 1) + go func() { + done <- session.Wait() + }() + + select { + case <-ctx.Done(): + _ = session.Signal(ssh.SIGTERM) + select { + case <-done: + return ctx.Err() + case <-time.After(100 * time.Millisecond): + return ctx.Err() + } + case err := <-done: + return c.handleCommandError(err) + } +} + +// handleCommandError processes command execution errors +func (c *Client) handleCommandError(err error) error { + if err == nil { + return nil + } + + var e *ssh.ExitError + var em *ssh.ExitMissingError + if errors.As(err, &e) || errors.As(err, &em) { + return err + } + + return fmt.Errorf("execute command: %w", err) +} + +// setupContextCancellation sets up context cancellation for a session +func (c *Client) setupContextCancellation(ctx context.Context, session *ssh.Session) func() { + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + _ = session.Signal(ssh.SIGTERM) + _ = session.Close() + case <-done: + } + }() + return func() { close(done) } +} + +// createSession creates a new SSH session with context cancellation setup +func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error) { + session, err := c.client.NewSession() + if err != nil { + return nil, nil, fmt.Errorf("new session: %w", err) + } + + cancel := c.setupContextCancellation(ctx, session) + cleanup := func() { + cancel() + _ = session.Close() + } + + return session, cleanup, nil +} + +// getDefaultDaemonAddr returns the daemon address from environment or default for the OS +func getDefaultDaemonAddr() string { + if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" { + return addr + } + if runtime.GOOS == "windows" { + return DefaultDaemonAddrWindows + } + return DefaultDaemonAddr +} + +// DialOptions contains options for SSH connections +type DialOptions struct { + KnownHostsFile string + IdentityFile string + DaemonAddr string + SkipCachedToken bool + InsecureSkipVerify bool +} + +// Dial connects to the given ssh server with specified options +func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) { + daemonAddr := opts.DaemonAddr + if daemonAddr == "" { + daemonAddr = getDefaultDaemonAddr() + } + opts.DaemonAddr = daemonAddr + + hostKeyCallback, err := createHostKeyCallback(opts) + if err != nil { + return nil, fmt.Errorf("create host key callback: %w", err) + } + + config := &ssh.ClientConfig{ + User: user, + Timeout: 30 * time.Second, + HostKeyCallback: hostKeyCallback, + } + + if opts.IdentityFile != "" { + authMethod, err := createSSHKeyAuth(opts.IdentityFile) + if err != nil { + return nil, fmt.Errorf("create SSH key auth: %w", err) + } + config.Auth = append(config.Auth, authMethod) + } + + return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken) +} + +// dialSSH establishes an SSH connection without JWT authentication +func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) { + dialer := &net.Dialer{} + conn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, fmt.Errorf("dial %s: %w", addr, err) + } + + clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + if closeErr := conn.Close(); closeErr != nil { + log.Debugf("connection close after handshake failure: %v", closeErr) + } + return nil, fmt.Errorf("ssh handshake: %w", err) + } + + client := ssh.NewClient(clientConn, chans, reqs) + return &Client{ + client: client, + }, nil +} + +// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection +func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("parse address %s: %w", addr, err) + } + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("parse port %s: %w", portStr, err) + } + + dialer := &net.Dialer{Timeout: detection.Timeout} + serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port) + if err != nil { + return nil, fmt.Errorf("SSH server detection failed: %w", err) + } + + if !serverType.RequiresJWT() { + return dialSSH(ctx, network, addr, config) + } + + jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout) + defer cancel() + + jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache) + if err != nil { + return nil, fmt.Errorf("request JWT token: %w", err) + } + + configWithJWT := nbssh.AddJWTAuth(config, jwtToken) + return dialSSH(ctx, network, addr, configWithJWT) +} + +// requestJWTToken requests a JWT token from the NetBird daemon +func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) { + hint := profilemanager.GetLoginHint() + + conn, err := connectToDaemon(daemonAddr) + if err != nil { + return "", fmt.Errorf("connect to daemon: %w", err) + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint) +} + +// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon +func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error { + conn, err := connectToDaemon(daemonAddr) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + log.Debugf("daemon connection close error: %v", err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + verifier := nbssh.NewDaemonHostKeyVerifier(client) + callback := nbssh.CreateHostKeyCallback(verifier) + return callback(hostname, remote, key) +} + +func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) { + addr := strings.TrimPrefix(daemonAddr, "tcp://") + + conn, err := grpc.NewClient( + addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err) + return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err) + } + + return conn, nil +} + +// getKnownHostsFiles returns paths to known_hosts files in order of preference +func getKnownHostsFiles() []string { + var files []string + + // User's known_hosts file (highest priority) + if homeDir, err := os.UserHomeDir(); err == nil { + userKnownHosts := filepath.Join(homeDir, ".ssh", "known_hosts") + files = append(files, userKnownHosts) + } + + // NetBird managed known_hosts files + if runtime.GOOS == "windows" { + programData := os.Getenv("PROGRAMDATA") + if programData == "" { + programData = `C:\ProgramData` + } + netbirdKnownHosts := filepath.Join(programData, "ssh", "ssh_known_hosts.d", "99-netbird") + files = append(files, netbirdKnownHosts) + } else { + files = append(files, "/etc/ssh/ssh_known_hosts.d/99-netbird") + files = append(files, "/etc/ssh/ssh_known_hosts") + } + + return files +} + +// createHostKeyCallback creates a host key verification callback +func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) { + if opts.InsecureSkipVerify { + return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode + } + + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil { + return nil + } + return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile) + }, nil +} + +func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error { + if daemonAddr == "" { + return fmt.Errorf("no daemon address") + } + return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr) +} + +func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error { + knownHostsFiles := getKnownHostsFilesList(knownHostsFile) + hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles) + + for _, callback := range hostKeyCallbacks { + if err := callback(hostname, remote, key); err == nil { + return nil + } + } + return fmt.Errorf("host key verification failed: key for %s not found in any known_hosts file", hostname) +} + +func getKnownHostsFilesList(knownHostsFile string) []string { + if knownHostsFile != "" { + return []string{knownHostsFile} + } + return getKnownHostsFiles() +} + +func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback { + var hostKeyCallbacks []ssh.HostKeyCallback + for _, file := range knownHostsFiles { + if callback, err := knownhosts.New(file); err == nil { + hostKeyCallbacks = append(hostKeyCallbacks, callback) + } + } + return hostKeyCallbacks +} + +// createSSHKeyAuth creates SSH key authentication from a private key file +func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) { + keyData, err := os.ReadFile(keyFile) + if err != nil { + return nil, fmt.Errorf("read SSH key file %s: %w", keyFile, err) + } + + signer, err := ssh.ParsePrivateKey(keyData) + if err != nil { + return nil, fmt.Errorf("parse SSH private key: %w", err) + } + + return ssh.PublicKeys(signer), nil +} + +// LocalPortForward sets up local port forwarding, binding to localAddr and forwarding to remoteAddr +func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr string) error { + localListener, err := net.Listen("tcp", localAddr) + if err != nil { + return fmt.Errorf("listen on %s: %w", localAddr, err) + } + + go func() { + defer func() { + if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Debugf("local listener close error: %v", err) + } + }() + for { + localConn, err := localListener.Accept() + if err != nil { + if ctx.Err() != nil { + return + } + continue + } + + go c.handleLocalForward(localConn, remoteAddr) + } + }() + + <-ctx.Done() + if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Debugf("local listener close error: %v", err) + } + return ctx.Err() +} + +// handleLocalForward handles a single local port forwarding connection +func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { + defer func() { + if err := localConn.Close(); err != nil { + log.Debugf("local connection close error: %v", err) + } + }() + + channel, err := c.client.Dial("tcp", remoteAddr) + if err != nil { + if strings.Contains(err.Error(), "administratively prohibited") { + _, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n") + } else { + log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err) + } + return + } + defer func() { + if err := channel.Close(); err != nil { + log.Debugf("remote channel close error: %v", err) + } + }() + + go func() { + if _, err := io.Copy(channel, localConn); err != nil { + log.Debugf("local forward copy error (local->remote): %v", err) + } + }() + + if _, err := io.Copy(localConn, channel); err != nil { + log.Debugf("local forward copy error (remote->local): %v", err) + } +} + +// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr +func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error { + host, port, err := c.parseRemoteAddress(remoteAddr) + if err != nil { + return fmt.Errorf("parse remote address: %w", err) + } + + req := c.buildTCPIPForwardRequest(host, port) + if err := c.sendTCPIPForwardRequest(req); err != nil { + return fmt.Errorf("setup remote forward: %w", err) + } + + go c.handleRemoteForwardChannels(ctx, localAddr) + + <-ctx.Done() + + if err := c.cancelTCPIPForwardRequest(req); err != nil { + return fmt.Errorf("cancel tcpip-forward: %w", err) + } + return ctx.Err() +} + +// parseRemoteAddress parses host and port from remote address string +func (c *Client) parseRemoteAddress(remoteAddr string) (string, uint32, error) { + host, portStr, err := net.SplitHostPort(remoteAddr) + if err != nil { + return "", 0, fmt.Errorf("parse remote address %s: %w", remoteAddr, err) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return "", 0, fmt.Errorf("parse remote port %s: %w", portStr, err) + } + + return host, uint32(port), nil +} + +// buildTCPIPForwardRequest creates a tcpip-forward request message +func (c *Client) buildTCPIPForwardRequest(host string, port uint32) tcpipForwardMsg { + return tcpipForwardMsg{ + Host: host, + Port: port, + } +} + +// sendTCPIPForwardRequest sends the tcpip-forward request to establish remote port forwarding +func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error { + ok, _, err := c.client.SendRequest("tcpip-forward", true, ssh.Marshal(&req)) + if err != nil { + return fmt.Errorf("send tcpip-forward request: %w", err) + } + if !ok { + return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)") + } + return nil +} + +// cancelTCPIPForwardRequest cancels the tcpip-forward request +func (c *Client) cancelTCPIPForwardRequest(req tcpipForwardMsg) error { + _, _, err := c.client.SendRequest("cancel-tcpip-forward", true, ssh.Marshal(&req)) + if err != nil { + return fmt.Errorf("send cancel-tcpip-forward request: %w", err) + } + return nil +} + +// handleRemoteForwardChannels handles incoming forwarded-tcpip channels +func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr string) { + // Get the channel once - subsequent calls return nil! + channelRequests := c.client.HandleChannelOpen("forwarded-tcpip") + if channelRequests == nil { + log.Debugf("forwarded-tcpip channel type already being handled") + return + } + + for { + select { + case <-ctx.Done(): + return + case newChan := <-channelRequests: + if newChan != nil { + go c.handleRemoteForwardChannel(newChan, localAddr) + } + } + } +} + +// handleRemoteForwardChannel handles a single forwarded-tcpip channel +func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) { + channel, reqs, err := newChan.Accept() + if err != nil { + return + } + defer func() { + if err := channel.Close(); err != nil { + log.Debugf("remote channel close error: %v", err) + } + }() + + go ssh.DiscardRequests(reqs) + + localConn, err := net.Dial("tcp", localAddr) + if err != nil { + return + } + defer func() { + if err := localConn.Close(); err != nil { + log.Debugf("local connection close error: %v", err) + } + }() + + go func() { + if _, err := io.Copy(localConn, channel); err != nil { + log.Debugf("remote forward copy error (remote->local): %v", err) + } + }() + + if _, err := io.Copy(channel, localConn); err != nil { + log.Debugf("remote forward copy error (local->remote): %v", err) + } +} + +// tcpipForwardMsg represents the structure for tcpip-forward requests +type tcpipForwardMsg struct { + Host string + Port uint32 +} diff --git a/client/ssh/client/client_test.go b/client/ssh/client/client_test.go new file mode 100644 index 000000000..e38e02a86 --- /dev/null +++ b/client/ssh/client/client_test.go @@ -0,0 +1,512 @@ +package client + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "os/user" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" + + "github.com/netbirdio/netbird/client/ssh" + sshserver "github.com/netbirdio/netbird/client/ssh/server" + "github.com/netbirdio/netbird/client/ssh/testutil" +) + +// TestMain handles package-level setup and cleanup +func TestMain(m *testing.M) { + // Guard against infinite recursion when test binary is called as "netbird ssh exec" + // This happens when running tests as non-privileged user with fallback + if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" { + // Just exit with error to break the recursion + fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n") + os.Exit(1) + } + + // Run tests + code := m.Run() + + // Cleanup any created test users + testutil.CleanupTestUsers() + + os.Exit(code) +} + +func TestSSHClient_DialWithKey(t *testing.T) { + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Create and start server + serverConfig := &sshserver.Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := sshserver.New(serverConfig) + server.SetAllowRootLogin(true) // Allow root/admin login for tests + + serverAddr := sshserver.StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Test Dial + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + currentUser := testutil.GetTestUsername(t) + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + + // Verify client is connected + assert.NotNil(t, client.client) +} + +func TestSSHClient_CommandExecution(t *testing.T) { + if runtime.GOOS == "windows" && testutil.IsCI() { + t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues") + } + + server, _, client := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + t.Run("ExecuteCommand captures output", func(t *testing.T) { + output, err := client.ExecuteCommand(ctx, "echo hello") + assert.NoError(t, err) + assert.Contains(t, string(output), "hello") + }) + + t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) { + err := client.ExecuteCommandWithIO(ctx, "echo world") + assert.NoError(t, err) + }) + + t.Run("commands with flags work", func(t *testing.T) { + output, err := client.ExecuteCommand(ctx, "echo -n test_flag") + assert.NoError(t, err) + assert.Equal(t, "test_flag", strings.TrimSpace(string(output))) + }) + + t.Run("non-zero exit codes don't return errors", func(t *testing.T) { + var testCmd string + if runtime.GOOS == "windows" { + testCmd = "echo hello | Select-String notfound" + } else { + testCmd = "echo 'hello' | grep 'notfound'" + } + _, err := client.ExecuteCommand(ctx, testCmd) + assert.NoError(t, err) + }) +} + +func TestSSHClient_ConnectionHandling(t *testing.T) { + server, serverAddr, _ := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Generate client key for multiple connections + + const numClients = 3 + clients := make([]*Client, numClients) + + currentUser := testutil.GetTestUsername(t) + for i := 0; i < numClients; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + cancel() + require.NoError(t, err, "Client %d should connect successfully", i) + clients[i] = client + } + + for i, client := range clients { + err := client.Close() + assert.NoError(t, err, "Client %d should close without error", i) + } +} + +func TestSSHClient_ContextCancellation(t *testing.T) { + server, serverAddr, _ := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + t.Run("connection with short timeout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + currentUser := testutil.GetTestUsername(t) + _, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + if err != nil { + // Check for actual timeout-related errors rather than string matching + assert.True(t, + errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) || + strings.Contains(err.Error(), "timeout"), + "Expected timeout-related error, got: %v", err) + } + }) + + t.Run("command execution cancellation", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + currentUser := testutil.GetTestUsername(t) + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + if err := client.Close(); err != nil { + t.Logf("client close error: %v", err) + } + }() + + cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cmdCancel() + + err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10") + if err != nil { + var exitMissingErr *cryptossh.ExitMissingError + isValidCancellation := errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) || + errors.As(err, &exitMissingErr) + assert.True(t, isValidCancellation, "Should handle command cancellation properly") + } + }) +} + +func TestSSHClient_NoAuthMode(t *testing.T) { + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + serverConfig := &sshserver.Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := sshserver.New(serverConfig) + server.SetAllowRootLogin(true) // Allow root/admin login for tests + + serverAddr := sshserver.StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + currentUser := testutil.GetTestUsername(t) + + t.Run("any key succeeds in no-auth mode", func(t *testing.T) { + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + assert.NoError(t, err) + if client != nil { + require.NoError(t, client.Close(), "Client should close without error") + } + }) +} + +func TestSSHClient_TerminalState(t *testing.T) { + server, _, client := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + + assert.Nil(t, client.terminalState) + assert.Equal(t, 0, client.terminalFd) + + client.restoreTerminal() + assert.Nil(t, client.terminalState) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := client.OpenTerminal(ctx) + // In test environment without a real terminal, this may complete quickly or timeout + // Both behaviors are acceptable for testing terminal state management + if err != nil { + if runtime.GOOS == "windows" { + assert.True(t, + strings.Contains(err.Error(), "context deadline exceeded") || + strings.Contains(err.Error(), "console"), + "Should timeout or have console error on Windows") + } else { + // On Unix systems in test environment, we may get various errors + // including timeouts or terminal-related errors + assert.True(t, + strings.Contains(err.Error(), "context deadline exceeded") || + strings.Contains(err.Error(), "terminal") || + strings.Contains(err.Error(), "pty"), + "Expected timeout or terminal-related error, got: %v", err) + } + } +} + +func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Client) { + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + serverConfig := &sshserver.Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := sshserver.New(serverConfig) + server.SetAllowRootLogin(true) // Allow root/admin login for tests + + serverAddr := sshserver.StartTestServer(t, server) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + currentUser := testutil.GetTestUsername(t) + client, err := Dial(ctx, serverAddr, currentUser, DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + + return server, serverAddr, client +} + +func TestSSHClient_PortForwarding(t *testing.T) { + server, _, client := setupTestSSHServerAndClient(t) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + + t.Run("local forwarding times out gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := client.LocalPortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080") + assert.Error(t, err) + assert.True(t, + errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) || + strings.Contains(err.Error(), "connection"), + "Expected context or connection error") + }) + + t.Run("remote forwarding denied", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err := client.RemotePortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080") + assert.Error(t, err) + assert.True(t, + strings.Contains(err.Error(), "denied") || + strings.Contains(err.Error(), "disabled"), + "Should be denied by default") + }) + + t.Run("invalid addresses fail", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err := client.LocalPortForward(ctx, "invalid:address", "127.0.0.1:8080") + assert.Error(t, err) + + err = client.LocalPortForward(ctx, "127.0.0.1:0", "invalid:address") + assert.Error(t, err) + }) +} + +func TestSSHClient_PortForwardingDataTransfer(t *testing.T) { + if testing.Short() { + t.Skip("Skipping data transfer test in short mode") + } + + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + serverConfig := &sshserver.Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := sshserver.New(serverConfig) + server.SetAllowLocalPortForwarding(true) + server.SetAllowRootLogin(true) // Allow root/admin login for tests + + serverAddr := sshserver.StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Port forwarding requires the actual current user, not test user + realUser, err := getRealCurrentUser() + require.NoError(t, err) + + // Skip if running as system account that can't do port forwarding + if testutil.IsSystemAccount(realUser) { + t.Skipf("Skipping port forwarding test - running as system account: %s", realUser) + } + + client, err := Dial(ctx, serverAddr, realUser, DialOptions{ + InsecureSkipVerify: true, // Skip host key verification for test + }) + require.NoError(t, err) + defer func() { + if err := client.Close(); err != nil { + t.Logf("client close error: %v", err) + } + }() + + testServer, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + if err := testServer.Close(); err != nil { + t.Logf("test server close error: %v", err) + } + }() + + testServerAddr := testServer.Addr().String() + expectedResponse := "Hello, World!" + + go func() { + for { + conn, err := testServer.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer func() { + if err := c.Close(); err != nil { + t.Logf("connection close error: %v", err) + } + }() + buf := make([]byte, 1024) + if _, err := c.Read(buf); err != nil { + t.Logf("connection read error: %v", err) + return + } + if _, err := c.Write([]byte(expectedResponse)); err != nil { + t.Logf("connection write error: %v", err) + } + }(conn) + } + }() + + localListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + localAddr := localListener.Addr().String() + if err := localListener.Close(); err != nil { + t.Logf("local listener close error: %v", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go func() { + err := client.LocalPortForward(ctx, localAddr, testServerAddr) + if err != nil && !errors.Is(err, context.Canceled) { + if isWindowsPrivilegeError(err) { + t.Logf("Port forward failed due to Windows privilege restrictions: %v", err) + } else { + t.Logf("Port forward error: %v", err) + } + } + }() + + time.Sleep(100 * time.Millisecond) + + conn, err := net.DialTimeout("tcp", localAddr, 2*time.Second) + require.NoError(t, err) + defer func() { + if err := conn.Close(); err != nil { + t.Logf("connection close error: %v", err) + } + }() + + _, err = conn.Write([]byte("test")) + require.NoError(t, err) + + if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Logf("set read deadline error: %v", err) + } + response := make([]byte, len(expectedResponse)) + n, err := io.ReadFull(conn, response) + require.NoError(t, err) + assert.Equal(t, len(expectedResponse), n) + assert.Equal(t, expectedResponse, string(response)) +} + +// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding +func getRealCurrentUser() (string, error) { + if runtime.GOOS == "windows" { + if currentUser, err := user.Current(); err == nil { + return currentUser.Username, nil + } + } + + if username := os.Getenv("USER"); username != "" { + return username, nil + } + + if currentUser, err := user.Current(); err == nil { + return currentUser.Username, nil + } + + return "", fmt.Errorf("unable to determine current user") +} + +// isWindowsPrivilegeError checks if an error is related to Windows privilege restrictions +func isWindowsPrivilegeError(err error) bool { + if err == nil { + return false + } + + errStr := strings.ToLower(err.Error()) + return strings.Contains(errStr, "ntstatus=0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD + strings.Contains(errStr, "0xc0000041") || // STATUS_PRIVILEGE_NOT_HELD (LsaRegisterLogonProcess) + strings.Contains(errStr, "0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD (LsaLogonUser) + strings.Contains(errStr, "privilege") || + strings.Contains(errStr, "access denied") || + strings.Contains(errStr, "user authentication failed") +} diff --git a/client/ssh/client/terminal_unix.go b/client/ssh/client/terminal_unix.go new file mode 100644 index 000000000..aaa3418f9 --- /dev/null +++ b/client/ssh/client/terminal_unix.go @@ -0,0 +1,127 @@ +//go:build !windows + +package client + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "golang.org/x/term" +) + +func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error { + stdinFd := int(os.Stdin.Fd()) + + if !term.IsTerminal(stdinFd) { + return c.setupNonTerminalMode(ctx, session) + } + + fd := int(os.Stdin.Fd()) + + state, err := term.MakeRaw(fd) + if err != nil { + return c.setupNonTerminalMode(ctx, session) + } + + if err := c.setupTerminal(session, fd); err != nil { + if restoreErr := term.Restore(fd, state); restoreErr != nil { + log.Debugf("restore terminal state: %v", restoreErr) + } + return err + } + + c.terminalState = state + c.terminalFd = fd + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + + go func() { + defer signal.Stop(sigChan) + select { + case <-ctx.Done(): + if err := term.Restore(fd, state); err != nil { + log.Debugf("restore terminal state: %v", err) + } + case sig := <-sigChan: + if err := term.Restore(fd, state); err != nil { + log.Debugf("restore terminal state: %v", err) + } + signal.Reset(sig) + s, ok := sig.(syscall.Signal) + if !ok { + log.Debugf("signal %v is not a syscall.Signal: %T", sig, sig) + return + } + if err := syscall.Kill(syscall.Getpid(), s); err != nil { + log.Debugf("kill process with signal %v: %v", s, err) + } + } + }() + + return nil +} + +func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error { + return nil +} + +// restoreWindowsConsoleState is a no-op on Unix systems +func (c *Client) restoreWindowsConsoleState() error { + return nil +} + +func (c *Client) setupTerminal(session *ssh.Session, fd int) error { + w, h, err := term.GetSize(fd) + if err != nil { + return fmt.Errorf("get terminal size: %w", err) + } + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + // Ctrl+C + ssh.VINTR: 3, + // Ctrl+\ + ssh.VQUIT: 28, + // Backspace + ssh.VERASE: 127, + // Ctrl+U + ssh.VKILL: 21, + // Ctrl+D + ssh.VEOF: 4, + ssh.VEOL: 0, + ssh.VEOL2: 0, + // Ctrl+Q + ssh.VSTART: 17, + // Ctrl+S + ssh.VSTOP: 19, + // Ctrl+Z + ssh.VSUSP: 26, + // Ctrl+O + ssh.VDISCARD: 15, + // Ctrl+R + ssh.VREPRINT: 18, + // Ctrl+W + ssh.VWERASE: 23, + // Ctrl+V + ssh.VLNEXT: 22, + } + + terminal := os.Getenv("TERM") + if terminal == "" { + terminal = "xterm-256color" + } + + if err := session.RequestPty(terminal, h, w, modes); err != nil { + return fmt.Errorf("request pty: %w", err) + } + + return nil +} diff --git a/client/ssh/client/terminal_windows.go b/client/ssh/client/terminal_windows.go new file mode 100644 index 000000000..462438317 --- /dev/null +++ b/client/ssh/client/terminal_windows.go @@ -0,0 +1,265 @@ +package client + +import ( + "context" + "errors" + "fmt" + "os" + "syscall" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" +) + +const ( + enableProcessedInput = 0x0001 + enableLineInput = 0x0002 + enableEchoInput = 0x0004 // Input mode: ENABLE_ECHO_INPUT + enableVirtualTerminalProcessing = 0x0004 // Output mode: ENABLE_VIRTUAL_TERMINAL_PROCESSING (same value, different mode) + enableVirtualTerminalInput = 0x0200 +) + +var ( + kernel32 = syscall.NewLazyDLL("kernel32.dll") + procGetConsoleMode = kernel32.NewProc("GetConsoleMode") + procSetConsoleMode = kernel32.NewProc("SetConsoleMode") + procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") +) + +// ConsoleUnavailableError indicates that Windows console handles are not available +// (e.g., in CI environments where stdout/stdin are redirected) +type ConsoleUnavailableError struct { + Operation string + Err error +} + +func (e *ConsoleUnavailableError) Error() string { + return fmt.Sprintf("console unavailable for %s: %v", e.Operation, e.Err) +} + +func (e *ConsoleUnavailableError) Unwrap() error { + return e.Err +} + +type coord struct { + x, y int16 +} + +type smallRect struct { + left, top, right, bottom int16 +} + +type consoleScreenBufferInfo struct { + size coord + cursorPosition coord + attributes uint16 + window smallRect + maximumWindowSize coord +} + +func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) error { + if err := c.saveWindowsConsoleState(); err != nil { + var consoleErr *ConsoleUnavailableError + if errors.As(err, &consoleErr) { + log.Debugf("console unavailable, not requesting PTY: %v", err) + return nil + } + return fmt.Errorf("save console state: %w", err) + } + + if err := c.enableWindowsVirtualTerminal(); err != nil { + var consoleErr *ConsoleUnavailableError + if errors.As(err, &consoleErr) { + log.Debugf("virtual terminal unavailable: %v", err) + } else { + return fmt.Errorf("failed to enable virtual terminal: %w", err) + } + } + + w, h := c.getWindowsConsoleSize() + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + ssh.ICRNL: 1, + ssh.OPOST: 1, + ssh.ONLCR: 1, + ssh.ISIG: 1, + ssh.ICANON: 1, + ssh.VINTR: 3, // Ctrl+C + ssh.VQUIT: 28, // Ctrl+\ + ssh.VERASE: 127, // Backspace + ssh.VKILL: 21, // Ctrl+U + ssh.VEOF: 4, // Ctrl+D + ssh.VEOL: 0, + ssh.VEOL2: 0, + ssh.VSTART: 17, // Ctrl+Q + ssh.VSTOP: 19, // Ctrl+S + ssh.VSUSP: 26, // Ctrl+Z + ssh.VDISCARD: 15, // Ctrl+O + ssh.VWERASE: 23, // Ctrl+W + ssh.VLNEXT: 22, // Ctrl+V + ssh.VREPRINT: 18, // Ctrl+R + } + + if err := session.RequestPty("xterm-256color", h, w, modes); err != nil { + if restoreErr := c.restoreWindowsConsoleState(); restoreErr != nil { + log.Debugf("restore Windows console state: %v", restoreErr) + } + return fmt.Errorf("request pty: %w", err) + } + + return nil +} + +func (c *Client) saveWindowsConsoleState() error { + defer func() { + if r := recover(); r != nil { + log.Debugf("panic in saveWindowsConsoleState: %v", r) + } + }() + + stdout := syscall.Handle(os.Stdout.Fd()) + stdin := syscall.Handle(os.Stdin.Fd()) + + var stdoutMode, stdinMode uint32 + + ret, _, err := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&stdoutMode))) + if ret == 0 { + log.Debugf("failed to get stdout console mode: %v", err) + return &ConsoleUnavailableError{ + Operation: "get stdout console mode", + Err: err, + } + } + + ret, _, err = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&stdinMode))) + if ret == 0 { + log.Debugf("failed to get stdin console mode: %v", err) + return &ConsoleUnavailableError{ + Operation: "get stdin console mode", + Err: err, + } + } + + c.terminalFd = 1 + c.windowsStdoutMode = stdoutMode + c.windowsStdinMode = stdinMode + + log.Debugf("saved Windows console state - stdout: 0x%04x, stdin: 0x%04x", stdoutMode, stdinMode) + return nil +} + +func (c *Client) enableWindowsVirtualTerminal() (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic in enableWindowsVirtualTerminal: %v", r) + } + }() + + stdout := syscall.Handle(os.Stdout.Fd()) + stdin := syscall.Handle(os.Stdin.Fd()) + var mode uint32 + + ret, _, winErr := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&mode))) + if ret == 0 { + return &ConsoleUnavailableError{ + Operation: "get stdout console mode for VT", + Err: winErr, + } + } + + mode |= enableVirtualTerminalProcessing + ret, _, winErr = procSetConsoleMode.Call(uintptr(stdout), uintptr(mode)) + if ret == 0 { + return &ConsoleUnavailableError{ + Operation: "enable virtual terminal processing", + Err: winErr, + } + } + + ret, _, winErr = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&mode))) + if ret == 0 { + return &ConsoleUnavailableError{ + Operation: "get stdin console mode for VT", + Err: winErr, + } + } + + mode &= ^uint32(enableLineInput | enableEchoInput | enableProcessedInput) + mode |= enableVirtualTerminalInput + ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(mode)) + if ret == 0 { + return &ConsoleUnavailableError{ + Operation: "set stdin raw mode", + Err: winErr, + } + } + + log.Debugf("enabled Windows virtual terminal processing") + return nil +} + +func (c *Client) getWindowsConsoleSize() (int, int) { + defer func() { + if r := recover(); r != nil { + log.Debugf("panic in getWindowsConsoleSize: %v", r) + } + }() + + stdout := syscall.Handle(os.Stdout.Fd()) + var csbi consoleScreenBufferInfo + + ret, _, err := procGetConsoleScreenBufferInfo.Call(uintptr(stdout), uintptr(unsafe.Pointer(&csbi))) + if ret == 0 { + log.Debugf("failed to get console buffer info, using defaults: %v", err) + return 80, 24 + } + + width := int(csbi.window.right - csbi.window.left + 1) + height := int(csbi.window.bottom - csbi.window.top + 1) + + log.Debugf("Windows console size: %dx%d", width, height) + return width, height +} + +func (c *Client) restoreWindowsConsoleState() error { + var err error + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic in restoreWindowsConsoleState: %v", r) + } + }() + + if c.terminalFd != 1 { + return nil + } + + stdout := syscall.Handle(os.Stdout.Fd()) + stdin := syscall.Handle(os.Stdin.Fd()) + + ret, _, winErr := procSetConsoleMode.Call(uintptr(stdout), uintptr(c.windowsStdoutMode)) + if ret == 0 { + log.Debugf("failed to restore stdout console mode: %v", winErr) + if err == nil { + err = fmt.Errorf("restore stdout console mode: %w", winErr) + } + } + + ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(c.windowsStdinMode)) + if ret == 0 { + log.Debugf("failed to restore stdin console mode: %v", winErr) + if err == nil { + err = fmt.Errorf("restore stdin console mode: %w", winErr) + } + } + + c.terminalFd = 0 + c.windowsStdoutMode = 0 + c.windowsStdinMode = 0 + + log.Debugf("restored Windows console state") + return err +} diff --git a/client/ssh/common.go b/client/ssh/common.go new file mode 100644 index 000000000..3beb12806 --- /dev/null +++ b/client/ssh/common.go @@ -0,0 +1,171 @@ +package ssh + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + + "github.com/netbirdio/netbird/client/proto" +) + +const ( + NetBirdSSHConfigFile = "99-netbird.conf" + + UnixSSHConfigDir = "/etc/ssh/ssh_config.d" + WindowsSSHConfigDir = "ssh/ssh_config.d" +) + +var ( + // ErrPeerNotFound indicates the peer was not found in the network + ErrPeerNotFound = errors.New("peer not found in network") + // ErrNoStoredKey indicates the peer has no stored SSH host key + ErrNoStoredKey = errors.New("peer has no stored SSH host key") +) + +// HostKeyVerifier provides SSH host key verification +type HostKeyVerifier interface { + VerifySSHHostKey(peerAddress string, key []byte) error +} + +// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon +type DaemonHostKeyVerifier struct { + client proto.DaemonServiceClient +} + +// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier +func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier { + return &DaemonHostKeyVerifier{ + client: client, + } +} + +// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon +func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{ + PeerAddress: peerAddress, + }) + if err != nil { + return err + } + + if !response.GetFound() { + return ErrPeerNotFound + } + + storedKeyData := response.GetSshHostKey() + + return VerifyHostKey(storedKeyData, presentedKey, peerAddress) +} + +// RequestJWTToken requests or retrieves a JWT token for SSH authentication +func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) { + req := &proto.RequestJWTAuthRequest{} + if hint != "" { + req.Hint = &hint + } + authResponse, err := client.RequestJWTAuth(ctx, req) + if err != nil { + return "", fmt.Errorf("request JWT auth: %w", err) + } + + if useCache && authResponse.CachedToken != "" { + log.Debug("Using cached authentication token") + return authResponse.CachedToken, nil + } + + if stderr != nil { + _, _ = fmt.Fprintln(stderr, "SSH authentication required.") + _, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete) + if authResponse.UserCode != "" { + _, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode) + } + _, _ = fmt.Fprintln(stderr, "Waiting for authentication...") + } + + tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{ + DeviceCode: authResponse.DeviceCode, + UserCode: authResponse.UserCode, + }) + if err != nil { + return "", fmt.Errorf("wait for JWT token: %w", err) + } + + if stdout != nil { + _, _ = fmt.Fprintln(stdout, "Authentication successful!") + } + return tokenResponse.Token, nil +} + +// VerifyHostKey verifies an SSH host key against stored peer key data. +// Returns nil only if the presented key matches the stored key. +// Returns ErrNoStoredKey if storedKeyData is empty. +// Returns an error if the keys don't match or if parsing fails. +func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error { + if len(storedKeyData) == 0 { + return ErrNoStoredKey + } + + storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData) + if err != nil { + return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err) + } + + if !bytes.Equal(presentedKey, storedPubKey.Marshal()) { + return fmt.Errorf("SSH host key mismatch for %s", peerAddress) + } + + return nil +} + +// AddJWTAuth prepends JWT password authentication to existing auth methods. +// This ensures JWT auth is tried first while preserving any existing auth methods. +func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig { + configWithJWT := *config + configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...) + return &configWithJWT +} + +// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier. +// It tries multiple addresses (hostname, IP) for the peer before failing. +func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback { + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + addresses := buildAddressList(hostname, remote) + presentedKey := key.Marshal() + + for _, addr := range addresses { + if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil { + if errors.Is(err, ErrPeerNotFound) { + // Try other addresses for this peer + continue + } + return err + } + // Verified + return nil + } + + return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname) + } +} + +// buildAddressList creates a list of addresses to check for host key verification. +// It includes the original hostname and extracts the host part from the remote address if different. +func buildAddressList(hostname string, remote net.Addr) []string { + addresses := []string{hostname} + if host, _, err := net.SplitHostPort(remote.String()); err == nil { + if host != hostname { + addresses = append(addresses, host) + } + } + return addresses +} diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go new file mode 100644 index 000000000..03a136de3 --- /dev/null +++ b/client/ssh/config/manager.go @@ -0,0 +1,282 @@ +package config + +import ( + "context" + "fmt" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" + + log "github.com/sirupsen/logrus" + + nbssh "github.com/netbirdio/netbird/client/ssh" +) + +const ( + EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG" + + EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG" + + MaxPeersForSSHConfig = 200 + + fileWriteTimeout = 2 * time.Second +) + +func isSSHConfigDisabled() bool { + value := os.Getenv(EnvDisableSSHConfig) + if value == "" { + return false + } + + disabled, err := strconv.ParseBool(value) + if err != nil { + return true + } + return disabled +} + +func isSSHConfigForced() bool { + value := os.Getenv(EnvForceSSHConfig) + if value == "" { + return false + } + + forced, err := strconv.ParseBool(value) + if err != nil { + return true + } + return forced +} + +// shouldGenerateSSHConfig checks if SSH config should be generated based on peer count +func shouldGenerateSSHConfig(peerCount int) bool { + if isSSHConfigDisabled() { + return false + } + + if isSSHConfigForced() { + return true + } + + return peerCount <= MaxPeersForSSHConfig +} + +// writeFileWithTimeout writes data to a file with a timeout +func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error { + ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- os.WriteFile(filename, data, perm) + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename) + } +} + +// Manager handles SSH client configuration for NetBird peers +type Manager struct { + sshConfigDir string + sshConfigFile string +} + +// PeerSSHInfo represents a peer's SSH configuration information +type PeerSSHInfo struct { + Hostname string + IP string + FQDN string +} + +// New creates a new SSH config manager +func New() *Manager { + sshConfigDir := getSystemSSHConfigDir() + return &Manager{ + sshConfigDir: sshConfigDir, + sshConfigFile: nbssh.NetBirdSSHConfigFile, + } +} + +// getSystemSSHConfigDir returns platform-specific SSH configuration directory +func getSystemSSHConfigDir() string { + if runtime.GOOS == "windows" { + return getWindowsSSHConfigDir() + } + return nbssh.UnixSSHConfigDir +} + +func getWindowsSSHConfigDir() string { + programData := os.Getenv("PROGRAMDATA") + if programData == "" { + programData = `C:\ProgramData` + } + return filepath.Join(programData, nbssh.WindowsSSHConfigDir) +} + +// SetupSSHClientConfig creates SSH client configuration for NetBird peers +func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error { + if !shouldGenerateSSHConfig(len(peers)) { + m.logSkipReason(len(peers)) + return nil + } + + sshConfig, err := m.buildSSHConfig(peers) + if err != nil { + return fmt.Errorf("build SSH config: %w", err) + } + return m.writeSSHConfig(sshConfig) +} + +func (m *Manager) logSkipReason(peerCount int) { + if isSSHConfigDisabled() { + log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig) + } else { + log.Infof("SSH config generation skipped: too many peers (%d > %d). Use %s=true to force.", + peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig) + } +} + +func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) { + sshConfig := m.buildConfigHeader() + + var allHostPatterns []string + for _, peer := range peers { + hostPatterns := m.buildHostPatterns(peer) + allHostPatterns = append(allHostPatterns, hostPatterns...) + } + + if len(allHostPatterns) > 0 { + peerConfig, err := m.buildPeerConfig(allHostPatterns) + if err != nil { + return "", err + } + sshConfig += peerConfig + } + + return sshConfig, nil +} + +func (m *Manager) buildConfigHeader() string { + return "# NetBird SSH client configuration\n" + + "# Generated automatically - do not edit manually\n" + + "#\n" + + "# To disable SSH config management, use:\n" + + "# netbird service reconfigure --service-env NB_DISABLE_SSH_CONFIG=true\n" + + "#\n\n" +} + +func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) { + uniquePatterns := make(map[string]bool) + var deduplicatedPatterns []string + for _, pattern := range allHostPatterns { + if !uniquePatterns[pattern] { + uniquePatterns[pattern] = true + deduplicatedPatterns = append(deduplicatedPatterns, pattern) + } + } + + execPath, err := m.getNetBirdExecutablePath() + if err != nil { + return "", fmt.Errorf("get NetBird executable path: %w", err) + } + + hostLine := strings.Join(deduplicatedPatterns, " ") + config := fmt.Sprintf("Host %s\n", hostLine) + + if runtime.GOOS == "windows" { + config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) + } else { + config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath) + } + config += " PreferredAuthentications password,publickey,keyboard-interactive\n" + config += " PasswordAuthentication yes\n" + config += " PubkeyAuthentication yes\n" + config += " BatchMode no\n" + config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath) + config += " StrictHostKeyChecking no\n" + + if runtime.GOOS == "windows" { + config += " UserKnownHostsFile NUL\n" + } else { + config += " UserKnownHostsFile /dev/null\n" + } + + config += " CheckHostIP no\n" + config += " LogLevel ERROR\n\n" + + return config, nil +} + +func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string { + var hostPatterns []string + if peer.IP != "" { + hostPatterns = append(hostPatterns, peer.IP) + } + if peer.FQDN != "" { + hostPatterns = append(hostPatterns, peer.FQDN) + } + if peer.Hostname != "" && peer.Hostname != peer.FQDN { + hostPatterns = append(hostPatterns, peer.Hostname) + } + return hostPatterns +} + +func (m *Manager) writeSSHConfig(sshConfig string) error { + sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) + + if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil { + return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err) + } + + if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil { + return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err) + } + + log.Infof("Created NetBird SSH client config: %s", sshConfigPath) + return nil +} + +// RemoveSSHClientConfig removes NetBird SSH configuration +func (m *Manager) RemoveSSHClientConfig() error { + sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) + err := os.Remove(sshConfigPath) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err) + } + if err == nil { + log.Infof("Removed NetBird SSH config: %s", sshConfigPath) + } + return nil +} + +func (m *Manager) getNetBirdExecutablePath() (string, error) { + execPath, err := os.Executable() + if err != nil { + return "", fmt.Errorf("retrieve executable path: %w", err) + } + + realPath, err := filepath.EvalSymlinks(execPath) + if err != nil { + log.Debugf("symlink resolution failed: %v", err) + return execPath, nil + } + + return realPath, nil +} + +// GetSSHConfigDir returns the SSH config directory path +func (m *Manager) GetSSHConfigDir() string { + return m.sshConfigDir +} + +// GetSSHConfigFile returns the SSH config file name +func (m *Manager) GetSSHConfigFile() string { + return m.sshConfigFile +} diff --git a/client/ssh/config/manager_test.go b/client/ssh/config/manager_test.go new file mode 100644 index 000000000..dc3ad95b3 --- /dev/null +++ b/client/ssh/config/manager_test.go @@ -0,0 +1,159 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestManager_SetupSSHClientConfig(t *testing.T) { + // Create temporary directory for test + tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") + require.NoError(t, err) + defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }() + + // Override manager paths to use temp directory + manager := &Manager{ + sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), + sshConfigFile: "99-netbird.conf", + } + + // Test SSH config generation with peers + peers := []PeerSSHInfo{ + { + Hostname: "peer1", + IP: "100.125.1.1", + FQDN: "peer1.nb.internal", + }, + { + Hostname: "peer2", + IP: "100.125.1.2", + FQDN: "peer2.nb.internal", + }, + } + + err = manager.SetupSSHClientConfig(peers) + require.NoError(t, err) + + // Read generated config + configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) + content, err := os.ReadFile(configPath) + require.NoError(t, err) + + configStr := string(content) + + // Verify the basic SSH config structure exists + assert.Contains(t, configStr, "# NetBird SSH client configuration") + assert.Contains(t, configStr, "Generated automatically - do not edit manually") + + // Check that peer hostnames are included + assert.Contains(t, configStr, "100.125.1.1") + assert.Contains(t, configStr, "100.125.1.2") + assert.Contains(t, configStr, "peer1.nb.internal") + assert.Contains(t, configStr, "peer2.nb.internal") + + // Check platform-specific UserKnownHostsFile + if runtime.GOOS == "windows" { + assert.Contains(t, configStr, "UserKnownHostsFile NUL") + } else { + assert.Contains(t, configStr, "UserKnownHostsFile /dev/null") + } +} + +func TestGetSystemSSHConfigDir(t *testing.T) { + configDir := getSystemSSHConfigDir() + + // Path should not be empty + assert.NotEmpty(t, configDir) + + // Should be an absolute path + assert.True(t, filepath.IsAbs(configDir)) + + // On Unix systems, should start with /etc + // On Windows, should contain ProgramData + if runtime.GOOS == "windows" { + assert.Contains(t, strings.ToLower(configDir), "programdata") + } else { + assert.Contains(t, configDir, "/etc/ssh") + } +} + +func TestManager_PeerLimit(t *testing.T) { + // Create temporary directory for test + tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") + require.NoError(t, err) + defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }() + + // Override manager paths to use temp directory + manager := &Manager{ + sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), + sshConfigFile: "99-netbird.conf", + } + + // Generate many peers (more than limit) + var peers []PeerSSHInfo + for i := 0; i < MaxPeersForSSHConfig+10; i++ { + peers = append(peers, PeerSSHInfo{ + Hostname: fmt.Sprintf("peer%d", i), + IP: fmt.Sprintf("100.125.1.%d", i%254+1), + FQDN: fmt.Sprintf("peer%d.nb.internal", i), + }) + } + + // Test that SSH config generation is skipped when too many peers + err = manager.SetupSSHClientConfig(peers) + require.NoError(t, err) + + // Config should not be created due to peer limit + configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) + _, err = os.Stat(configPath) + assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers") +} + +func TestManager_ForcedSSHConfig(t *testing.T) { + // Set force environment variable + t.Setenv(EnvForceSSHConfig, "true") + + // Create temporary directory for test + tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") + require.NoError(t, err) + defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }() + + // Override manager paths to use temp directory + manager := &Manager{ + sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), + sshConfigFile: "99-netbird.conf", + } + + // Generate many peers (more than limit) + var peers []PeerSSHInfo + for i := 0; i < MaxPeersForSSHConfig+10; i++ { + peers = append(peers, PeerSSHInfo{ + Hostname: fmt.Sprintf("peer%d", i), + IP: fmt.Sprintf("100.125.1.%d", i%254+1), + FQDN: fmt.Sprintf("peer%d.nb.internal", i), + }) + } + + // Test that SSH config generation is forced despite many peers + err = manager.SetupSSHClientConfig(peers) + require.NoError(t, err) + + // Config should be created despite peer limit due to force flag + configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) + _, err = os.Stat(configPath) + require.NoError(t, err, "SSH config should be created when forced") + + // Verify config contains peer hostnames + content, err := os.ReadFile(configPath) + require.NoError(t, err) + configStr := string(content) + assert.Contains(t, configStr, "peer0.nb.internal") + assert.Contains(t, configStr, "peer1.nb.internal") +} diff --git a/client/ssh/config/shutdown_state.go b/client/ssh/config/shutdown_state.go new file mode 100644 index 000000000..22f0e0678 --- /dev/null +++ b/client/ssh/config/shutdown_state.go @@ -0,0 +1,22 @@ +package config + +// ShutdownState represents SSH configuration state that needs to be cleaned up. +type ShutdownState struct { + SSHConfigDir string + SSHConfigFile string +} + +// Name returns the state name for the state manager. +func (s *ShutdownState) Name() string { + return "ssh_config_state" +} + +// Cleanup removes SSH client configuration files. +func (s *ShutdownState) Cleanup() error { + manager := &Manager{ + sshConfigDir: s.SSHConfigDir, + sshConfigFile: s.SSHConfigFile, + } + + return manager.RemoveSSHClientConfig() +} diff --git a/client/ssh/detection/detection.go b/client/ssh/detection/detection.go new file mode 100644 index 000000000..487f4665a --- /dev/null +++ b/client/ssh/detection/detection.go @@ -0,0 +1,99 @@ +package detection + +import ( + "bufio" + "context" + "net" + "strconv" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // ServerIdentifier is the base response for NetBird SSH servers + ServerIdentifier = "NetBird-SSH-Server" + // ProxyIdentifier is the base response for NetBird SSH proxy + ProxyIdentifier = "NetBird-SSH-Proxy" + // JWTRequiredMarker is appended to responses when JWT is required + JWTRequiredMarker = "NetBird-JWT-Required" + + // Timeout is the timeout for SSH server detection + Timeout = 5 * time.Second +) + +type ServerType string + +const ( + ServerTypeNetBirdJWT ServerType = "netbird-jwt" + ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt" + ServerTypeRegular ServerType = "regular" +) + +// Dialer provides network connection capabilities +type Dialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// RequiresJWT checks if the server type requires JWT authentication +func (s ServerType) RequiresJWT() bool { + return s == ServerTypeNetBirdJWT +} + +// ExitCode returns the exit code for the detect command +func (s ServerType) ExitCode() int { + switch s { + case ServerTypeNetBirdJWT: + return 0 + case ServerTypeNetBirdNoJWT: + return 1 + case ServerTypeRegular: + return 2 + default: + return 2 + } +} + +// DetectSSHServerType detects SSH server type using the provided dialer +func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) { + targetAddr := net.JoinHostPort(host, strconv.Itoa(port)) + + conn, err := dialer.DialContext(ctx, "tcp", targetAddr) + if err != nil { + log.Debugf("SSH connection failed for detection: %v", err) + return ServerTypeRegular, nil + } + defer conn.Close() + + if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil { + log.Debugf("set read deadline: %v", err) + return ServerTypeRegular, nil + } + + reader := bufio.NewReader(conn) + serverBanner, err := reader.ReadString('\n') + if err != nil { + log.Debugf("read SSH banner: %v", err) + return ServerTypeRegular, nil + } + + serverBanner = strings.TrimSpace(serverBanner) + log.Debugf("SSH server banner: %s", serverBanner) + + if !strings.HasPrefix(serverBanner, "SSH-") { + log.Debugf("Invalid SSH banner") + return ServerTypeRegular, nil + } + + if !strings.Contains(serverBanner, ServerIdentifier) { + log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier) + return ServerTypeRegular, nil + } + + if strings.Contains(serverBanner, JWTRequiredMarker) { + return ServerTypeNetBirdJWT, nil + } + + return ServerTypeNetBirdNoJWT, nil +} diff --git a/client/ssh/login.go b/client/ssh/login.go deleted file mode 100644 index cb2615e55..000000000 --- a/client/ssh/login.go +++ /dev/null @@ -1,53 +0,0 @@ -//go:build !js - -package ssh - -import ( - "fmt" - "net" - "net/netip" - "os" - "os/exec" - "runtime" - - "github.com/netbirdio/netbird/util" -) - -func isRoot() bool { - return os.Geteuid() == 0 -} - -func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) { - if !isRoot() { - shell := getUserShell(user) - if shell == "" { - shell = "/bin/sh" - } - - return shell, []string{"-l"}, nil - } - - loginPath, err = exec.LookPath("login") - if err != nil { - return "", nil, err - } - - addrPort, err := netip.ParseAddrPort(remoteAddr.String()) - if err != nil { - return "", nil, err - } - - switch runtime.GOOS { - case "linux": - if util.FileExists("/etc/arch-release") && !util.FileExists("/etc/pam.d/remote") { - return loginPath, []string{"-f", user, "-p"}, nil - } - return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil - case "darwin": - return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), user}, nil - case "freebsd": - return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil - default: - return "", nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS) - } -} diff --git a/client/ssh/lookup.go b/client/ssh/lookup.go deleted file mode 100644 index 9a7f6ff2e..000000000 --- a/client/ssh/lookup.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !darwin -// +build !darwin - -package ssh - -import "os/user" - -func userNameLookup(username string) (*user.User, error) { - if username == "" || (username == "root" && !isRoot()) { - return user.Current() - } - - return user.Lookup(username) -} diff --git a/client/ssh/lookup_darwin.go b/client/ssh/lookup_darwin.go deleted file mode 100644 index 913d049dc..000000000 --- a/client/ssh/lookup_darwin.go +++ /dev/null @@ -1,51 +0,0 @@ -//go:build darwin -// +build darwin - -package ssh - -import ( - "bytes" - "fmt" - "os/exec" - "os/user" - "strings" -) - -func userNameLookup(username string) (*user.User, error) { - if username == "" || (username == "root" && !isRoot()) { - return user.Current() - } - - var userObject *user.User - userObject, err := user.Lookup(username) - if err != nil && err.Error() == user.UnknownUserError(username).Error() { - return idUserNameLookup(username) - } else if err != nil { - return nil, err - } - - return userObject, nil -} - -func idUserNameLookup(username string) (*user.User, error) { - cmd := exec.Command("id", "-P", username) - out, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("error while retrieving user with id -P command, error: %v", err) - } - colon := ":" - - if !bytes.Contains(out, []byte(username+colon)) { - return nil, fmt.Errorf("unable to find user in returned string") - } - // netbird:********:501:20::0:0:netbird:/Users/netbird:/bin/zsh - parts := strings.SplitN(string(out), colon, 10) - userObject := &user.User{ - Username: parts[0], - Uid: parts[2], - Gid: parts[3], - Name: parts[7], - HomeDir: parts[8], - } - return userObject, nil -} diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go new file mode 100644 index 000000000..bc8a84b89 --- /dev/null +++ b/client/ssh/proxy/proxy.go @@ -0,0 +1,392 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + cryptossh "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/version" +) + +const ( + // sshConnectionTimeout is the timeout for SSH TCP connection establishment + sshConnectionTimeout = 120 * time.Second + // sshHandshakeTimeout is the timeout for SSH handshake completion + sshHandshakeTimeout = 30 * time.Second + + jwtAuthErrorMsg = "JWT authentication: %w" +) + +type SSHProxy struct { + daemonAddr string + targetHost string + targetPort int + stderr io.Writer + conn *grpc.ClientConn + daemonClient proto.DaemonServiceClient +} + +func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) { + grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://") + grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, fmt.Errorf("connect to daemon: %w", err) + } + + return &SSHProxy{ + daemonAddr: daemonAddr, + targetHost: targetHost, + targetPort: targetPort, + stderr: stderr, + conn: grpcConn, + daemonClient: proto.NewDaemonServiceClient(grpcConn), + }, nil +} + +func (p *SSHProxy) Close() error { + if p.conn != nil { + return p.conn.Close() + } + return nil +} + +func (p *SSHProxy) Connect(ctx context.Context) error { + hint := profilemanager.GetLoginHint() + + jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint) + if err != nil { + return fmt.Errorf(jwtAuthErrorMsg, err) + } + + return p.runProxySSHServer(ctx, jwtToken) +} + +func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error { + serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion()) + + sshServer := &ssh.Server{ + Handler: func(s ssh.Session) { + p.handleSSHSession(ctx, s, jwtToken) + }, + ChannelHandlers: map[string]ssh.ChannelHandler{ + "session": ssh.DefaultSessionHandler, + "direct-tcpip": p.directTCPIPHandler, + }, + SubsystemHandlers: map[string]ssh.SubsystemHandler{ + "sftp": func(s ssh.Session) { + p.sftpSubsystemHandler(s, jwtToken) + }, + }, + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": p.tcpipForwardHandler, + "cancel-tcpip-forward": p.cancelTcpipForwardHandler, + }, + Version: serverVersion, + } + + hostKey, err := generateHostKey() + if err != nil { + return fmt.Errorf("generate host key: %w", err) + } + sshServer.HostSigners = []ssh.Signer{hostKey} + + conn := &stdioConn{ + stdin: os.Stdin, + stdout: os.Stdout, + } + + sshServer.HandleConn(conn) + + return nil +} + +func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) { + targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort)) + + sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken) + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err) + return + } + defer func() { _ = sshClient.Close() }() + + serverSession, err := sshClient.NewSession() + if err != nil { + _, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err) + return + } + defer func() { _ = serverSession.Close() }() + + serverSession.Stdin = session + serverSession.Stdout = session + serverSession.Stderr = session.Stderr() + + ptyReq, winCh, isPty := session.Pty() + if isPty { + if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil { + log.Debugf("PTY request to backend: %v", err) + } + + go func() { + for win := range winCh { + if err := serverSession.WindowChange(win.Height, win.Width); err != nil { + log.Debugf("window change: %v", err) + } + } + }() + } + + if len(session.Command()) > 0 { + if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil { + log.Debugf("run command: %v", err) + p.handleProxyExitCode(session, err) + } + return + } + + if err = serverSession.Shell(); err != nil { + log.Debugf("start shell: %v", err) + return + } + if err := serverSession.Wait(); err != nil { + log.Debugf("session wait: %v", err) + p.handleProxyExitCode(session, err) + } +} + +func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) { + var exitErr *cryptossh.ExitError + if errors.As(err, &exitErr) { + if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil { + log.Debugf("set exit status: %v", exitErr) + } + } +} + +func generateHostKey() (ssh.Signer, error) { + keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + if err != nil { + return nil, fmt.Errorf("generate ED25519 key: %w", err) + } + + signer, err := cryptossh.ParsePrivateKey(keyPEM) + if err != nil { + return nil, fmt.Errorf("parse private key: %w", err) + } + + return signer, nil +} + +type stdioConn struct { + stdin io.Reader + stdout io.Writer + closed bool + mu sync.Mutex +} + +func (c *stdioConn) Read(b []byte) (n int, err error) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return 0, io.EOF + } + c.mu.Unlock() + return c.stdin.Read(b) +} + +func (c *stdioConn) Write(b []byte) (n int, err error) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return 0, io.ErrClosedPipe + } + c.mu.Unlock() + return c.stdout.Write(b) +} + +func (c *stdioConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func (c *stdioConn) LocalAddr() net.Addr { + return &net.UnixAddr{Name: "stdio", Net: "unix"} +} + +func (c *stdioConn) RemoteAddr() net.Addr { + return &net.UnixAddr{Name: "stdio", Net: "unix"} +} + +func (c *stdioConn) SetDeadline(_ time.Time) error { + return nil +} + +func (c *stdioConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (c *stdioConn) SetWriteDeadline(_ time.Time) error { + return nil +} + +func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) { + _ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy") +} + +func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) { + ctx, cancel := context.WithCancel(s.Context()) + defer cancel() + + targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort)) + + sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken) + if err != nil { + _, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err) + _ = s.Exit(1) + return + } + defer func() { + if err := sshClient.Close(); err != nil { + log.Debugf("close SSH client: %v", err) + } + }() + + serverSession, err := sshClient.NewSession() + if err != nil { + _, _ = fmt.Fprintf(s, "create server session: %v\n", err) + _ = s.Exit(1) + return + } + defer func() { + if err := serverSession.Close(); err != nil { + log.Debugf("close server session: %v", err) + } + }() + + stdin, stdout, err := p.setupSFTPPipes(serverSession) + if err != nil { + log.Debugf("setup SFTP pipes: %v", err) + _ = s.Exit(1) + return + } + + if err := serverSession.RequestSubsystem("sftp"); err != nil { + _, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err) + _ = s.Exit(1) + return + } + + p.runSFTPBridge(ctx, s, stdin, stdout, serverSession) +} + +func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) { + stdin, err := serverSession.StdinPipe() + if err != nil { + return nil, nil, fmt.Errorf("get stdin pipe: %w", err) + } + + stdout, err := serverSession.StdoutPipe() + if err != nil { + return nil, nil, fmt.Errorf("get stdout pipe: %w", err) + } + + return stdin, stdout, nil +} + +func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) { + copyErrCh := make(chan error, 2) + + go func() { + _, err := io.Copy(stdin, s) + if err != nil { + log.Debugf("SFTP client to server copy: %v", err) + } + if err := stdin.Close(); err != nil { + log.Debugf("close stdin: %v", err) + } + copyErrCh <- err + }() + + go func() { + _, err := io.Copy(s, stdout) + if err != nil { + log.Debugf("SFTP server to client copy: %v", err) + } + copyErrCh <- err + }() + + go func() { + <-ctx.Done() + if err := serverSession.Close(); err != nil { + log.Debugf("force close server session on context cancellation: %v", err) + } + }() + + for i := 0; i < 2; i++ { + if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) { + log.Debugf("SFTP copy error: %v", err) + } + } + + if err := serverSession.Wait(); err != nil { + log.Debugf("SFTP session ended: %v", err) + } +} + +func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) { + return false, []byte("port forwarding not supported in proxy") +} + +func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) { + return true, nil +} + +func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) { + config := &cryptossh.ClientConfig{ + User: user, + Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)}, + Timeout: sshHandshakeTimeout, + HostKeyCallback: p.verifyHostKey, + } + + dialer := &net.Dialer{ + Timeout: sshConnectionTimeout, + } + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, fmt.Errorf("connect to server: %w", err) + } + + clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config) + if err != nil { + _ = conn.Close() + return nil, fmt.Errorf("SSH handshake: %w", err) + } + + return cryptossh.NewClient(clientConn, chans, reqs), nil +} + +func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error { + verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient) + callback := nbssh.CreateHostKeyCallback(verifier) + return callback(hostname, remote, key) +} diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go new file mode 100644 index 000000000..c5036da37 --- /dev/null +++ b/client/ssh/proxy/proxy_test.go @@ -0,0 +1,367 @@ +package proxy + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net" + "net/http" + "net/http/httptest" + "os" + "runtime" + "strconv" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/netbirdio/netbird/client/proto" + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/server" + "github.com/netbirdio/netbird/client/ssh/testutil" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" +) + +func TestMain(m *testing.M) { + if len(os.Args) > 2 && os.Args[1] == "ssh" { + if os.Args[2] == "exec" { + if len(os.Args) > 3 { + cmd := os.Args[3] + if cmd == "echo" && len(os.Args) > 4 { + fmt.Fprintln(os.Stdout, os.Args[4]) + os.Exit(0) + } + } + fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args) + os.Exit(1) + } + } + + code := m.Run() + + testutil.CleanupTestUsers() + + os.Exit(code) +} + +func TestSSHProxy_verifyHostKey(t *testing.T) { + t.Run("calls daemon to verify host key", func(t *testing.T) { + mockDaemon := startMockDaemon(t) + defer mockDaemon.stop() + + grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer func() { _ = grpcConn.Close() }() + + proxy := &SSHProxy{ + daemonAddr: mockDaemon.addr, + daemonClient: proto.NewDaemonServiceClient(grpcConn), + } + + testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + testPubKey, err := nbssh.GeneratePublicKey(testKey) + require.NoError(t, err) + + mockDaemon.setHostKey("test-host", testPubKey) + + err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey)) + assert.NoError(t, err) + }) + + t.Run("rejects unknown host key", func(t *testing.T) { + mockDaemon := startMockDaemon(t) + defer mockDaemon.stop() + + grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer func() { _ = grpcConn.Close() }() + + proxy := &SSHProxy{ + daemonAddr: mockDaemon.addr, + daemonClient: proto.NewDaemonServiceClient(grpcConn), + } + + unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey) + require.NoError(t, err) + + err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "peer unknown-host not found in network") + }) +} + +func TestSSHProxy_Connect(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // TODO: Windows test times out - user switching and command execution tested on Linux + if runtime.GOOS == "windows" { + t.Skip("Skipping on Windows - covered by Linux tests") + } + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + jwksServer, privateKey, jwksURL := setupJWKSServer(t) + defer jwksServer.Close() + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + hostPubKey, err := nbssh.GeneratePublicKey(hostKey) + require.NoError(t, err) + + serverConfig := &server.Config{ + HostKeyPEM: hostKey, + JWT: &server.JWTConfig{ + Issuer: issuer, + Audience: audience, + KeysLocation: jwksURL, + }, + } + sshServer := server.New(serverConfig) + sshServer.SetAllowRootLogin(true) + + sshServerAddr := server.StartTestServer(t, sshServer) + defer func() { _ = sshServer.Stop() }() + + mockDaemon := startMockDaemon(t) + defer mockDaemon.stop() + + host, portStr, err := net.SplitHostPort(sshServerAddr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + mockDaemon.setHostKey(host, hostPubKey) + + validToken := generateValidJWT(t, privateKey, issuer, audience) + mockDaemon.setJWTToken(validToken) + + proxyInstance, err := New(mockDaemon.addr, host, port, nil) + require.NoError(t, err) + + clientConn, proxyConn := net.Pipe() + defer func() { _ = clientConn.Close() }() + + origStdin := os.Stdin + origStdout := os.Stdout + defer func() { + os.Stdin = origStdin + os.Stdout = origStdout + }() + + stdinReader, stdinWriter, err := os.Pipe() + require.NoError(t, err) + stdoutReader, stdoutWriter, err := os.Pipe() + require.NoError(t, err) + + os.Stdin = stdinReader + os.Stdout = stdoutWriter + + go func() { + _, _ = io.Copy(stdinWriter, proxyConn) + }() + go func() { + _, _ = io.Copy(proxyConn, stdoutReader) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + connectErrCh := make(chan error, 1) + go func() { + connectErrCh <- proxyInstance.Connect(ctx) + }() + + sshConfig := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: []cryptossh.AuthMethod{}, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 3 * time.Second, + } + + sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig) + require.NoError(t, err, "Should connect to proxy server") + defer func() { _ = sshClientConn.Close() }() + + sshClient := cryptossh.NewClient(sshClientConn, chans, reqs) + + session, err := sshClient.NewSession() + require.NoError(t, err, "Should create session through full proxy to backend") + + outputCh := make(chan []byte, 1) + errCh := make(chan error, 1) + go func() { + output, err := session.Output("echo hello-from-proxy") + outputCh <- output + errCh <- err + }() + + select { + case output := <-outputCh: + err := <-errCh + require.NoError(t, err, "Command should execute successfully through proxy") + assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy") + case <-time.After(3 * time.Second): + t.Fatal("Command execution timed out") + } + + _ = session.Close() + _ = sshClient.Close() + _ = clientConn.Close() + cancel() +} + +type mockDaemonServer struct { + proto.UnimplementedDaemonServiceServer + hostKeys map[string][]byte + jwtToken string +} + +func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) { + key, found := m.hostKeys[req.PeerAddress] + return &proto.GetPeerSSHHostKeyResponse{ + Found: found, + SshHostKey: key, + }, nil +} + +func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) { + return &proto.RequestJWTAuthResponse{ + CachedToken: m.jwtToken, + }, nil +} + +func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) { + return &proto.WaitJWTTokenResponse{ + Token: m.jwtToken, + }, nil +} + +type mockDaemon struct { + addr string + server *grpc.Server + impl *mockDaemonServer +} + +func startMockDaemon(t *testing.T) *mockDaemon { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + impl := &mockDaemonServer{ + hostKeys: make(map[string][]byte), + jwtToken: "test-jwt-token", + } + + grpcServer := grpc.NewServer() + proto.RegisterDaemonServiceServer(grpcServer, impl) + + go func() { + _ = grpcServer.Serve(listener) + }() + + return &mockDaemon{ + addr: listener.Addr().String(), + server: grpcServer, + impl: impl, + } +} + +func (m *mockDaemon) setHostKey(addr string, pubKey []byte) { + m.impl.hostKeys[addr] = pubKey +} + +func (m *mockDaemon) setJWTToken(token string) { + m.impl.jwtToken = token +} + +func (m *mockDaemon) stop() { + if m.server != nil { + m.server.Stop() + } +} + +func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey { + t.Helper() + pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes) + require.NoError(t, err) + return pubKey +} + +func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) { + t.Helper() + privateKey, jwksJSON := generateTestJWKS(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write(jwksJSON); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + })) + + return server, privateKey, server.URL +} + +func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) { + t.Helper() + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + publicKey := &privateKey.PublicKey + n := publicKey.N.Bytes() + e := publicKey.E + + jwk := nbjwt.JSONWebKey{ + Kty: "RSA", + Kid: "test-key-id", + Use: "sig", + N: base64.RawURLEncoding.EncodeToString(n), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()), + } + + jwks := nbjwt.Jwks{ + Keys: []nbjwt.JSONWebKey{jwk}, + } + + jwksJSON, err := json.Marshal(jwks) + require.NoError(t, err) + + return privateKey, jwksJSON +} + +func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string { + t.Helper() + claims := jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = "test-key-id" + + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + + return tokenString +} diff --git a/client/ssh/server.go b/client/ssh/server.go deleted file mode 100644 index 8c5db2547..000000000 --- a/client/ssh/server.go +++ /dev/null @@ -1,280 +0,0 @@ -//go:build !js - -package ssh - -import ( - "fmt" - "io" - "net" - "os" - "os/exec" - "os/user" - "runtime" - "strings" - "sync" - "time" - - "github.com/creack/pty" - "github.com/gliderlabs/ssh" - log "github.com/sirupsen/logrus" -) - -// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server -const DefaultSSHPort = 44338 - -// TerminalTimeout is the timeout for terminal session to be ready -const TerminalTimeout = 10 * time.Second - -// TerminalBackoffDelay is the delay between terminal session readiness checks -const TerminalBackoffDelay = 500 * time.Millisecond - -// DefaultSSHServer is a function that creates DefaultServer -func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) { - return newDefaultServer(hostKeyPEM, addr) -} - -// Server is an interface of SSH server -type Server interface { - // Stop stops SSH server. - Stop() error - // Start starts SSH server. Blocking - Start() error - // RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys - RemoveAuthorizedKey(peer string) - // AddAuthorizedKey add a given peer key to server authorized keys - AddAuthorizedKey(peer, newKey string) error -} - -// DefaultServer is the embedded NetBird SSH server -type DefaultServer struct { - listener net.Listener - // authorizedKeys is ssh pub key indexed by peer WireGuard public key - authorizedKeys map[string]ssh.PublicKey - mu sync.Mutex - hostKeyPEM []byte - sessions []ssh.Session -} - -// newDefaultServer creates new server with provided host key -func newDefaultServer(hostKeyPEM []byte, addr string) (*DefaultServer, error) { - ln, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - allowedKeys := make(map[string]ssh.PublicKey) - return &DefaultServer{listener: ln, mu: sync.Mutex{}, hostKeyPEM: hostKeyPEM, authorizedKeys: allowedKeys, sessions: make([]ssh.Session, 0)}, nil -} - -// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys -func (srv *DefaultServer) RemoveAuthorizedKey(peer string) { - srv.mu.Lock() - defer srv.mu.Unlock() - - delete(srv.authorizedKeys, peer) -} - -// AddAuthorizedKey add a given peer key to server authorized keys -func (srv *DefaultServer) AddAuthorizedKey(peer, newKey string) error { - srv.mu.Lock() - defer srv.mu.Unlock() - - parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey)) - if err != nil { - return err - } - - srv.authorizedKeys[peer] = parsedKey - return nil -} - -// Stop stops SSH server. -func (srv *DefaultServer) Stop() error { - srv.mu.Lock() - defer srv.mu.Unlock() - err := srv.listener.Close() - if err != nil { - return err - } - for _, session := range srv.sessions { - err := session.Close() - if err != nil { - log.Warnf("failed closing SSH session from %v", err) - } - } - - return nil -} - -func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { - srv.mu.Lock() - defer srv.mu.Unlock() - - for _, allowed := range srv.authorizedKeys { - if ssh.KeysEqual(allowed, key) { - return true - } - } - - return false -} - -func prepareUserEnv(user *user.User, shell string) []string { - return []string{ - fmt.Sprint("SHELL=" + shell), - fmt.Sprint("USER=" + user.Username), - fmt.Sprint("HOME=" + user.HomeDir), - } -} - -func acceptEnv(s string) bool { - split := strings.Split(s, "=") - if len(split) != 2 { - return false - } - return split[0] == "TERM" || split[0] == "LANG" || strings.HasPrefix(split[0], "LC_") -} - -// sessionHandler handles SSH session post auth -func (srv *DefaultServer) sessionHandler(session ssh.Session) { - srv.mu.Lock() - srv.sessions = append(srv.sessions, session) - srv.mu.Unlock() - - defer func() { - err := session.Close() - if err != nil { - return - } - }() - - log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String()) - - localUser, err := userNameLookup(session.User()) - if err != nil { - _, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint - err = session.Exit(1) - if err != nil { - return - } - log.Warnf("failed SSH session from %v, user %s", session.RemoteAddr(), session.User()) - return - } - - ptyReq, winCh, isPty := session.Pty() - if isPty { - loginCmd, loginArgs, err := getLoginCmd(localUser.Username, session.RemoteAddr()) - if err != nil { - log.Warnf("failed logging-in user %s from remote IP %s", localUser.Username, session.RemoteAddr().String()) - return - } - cmd := exec.Command(loginCmd, loginArgs...) - go func() { - <-session.Context().Done() - if cmd.Process == nil { - return - } - err := cmd.Process.Kill() - if err != nil { - log.Debugf("failed killing SSH process %v", err) - return - } - }() - cmd.Dir = localUser.HomeDir - cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) - cmd.Env = append(cmd.Env, prepareUserEnv(localUser, getUserShell(localUser.Uid))...) - for _, v := range session.Environ() { - if acceptEnv(v) { - cmd.Env = append(cmd.Env, v) - } - } - - log.Debugf("Login command: %s", cmd.String()) - file, err := pty.Start(cmd) - if err != nil { - log.Errorf("failed starting SSH server: %v", err) - } - - go func() { - for win := range winCh { - setWinSize(file, win.Width, win.Height) - } - }() - - srv.stdInOut(file, session) - - err = cmd.Wait() - if err != nil { - return - } - } else { - _, err := io.WriteString(session, "only PTY is supported.\n") - if err != nil { - return - } - err = session.Exit(1) - if err != nil { - return - } - } - log.Debugf("SSH session ended") -} - -func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) { - go func() { - // stdin - _, err := io.Copy(file, session) - if err != nil { - _ = session.Exit(1) - return - } - }() - - // AWS Linux 2 machines need some time to open the terminal so we need to wait for it - timer := time.NewTimer(TerminalTimeout) - for { - select { - case <-timer.C: - _, _ = session.Write([]byte("Reached timeout while opening connection\n")) - _ = session.Exit(1) - return - default: - // stdout - writtenBytes, err := io.Copy(session, file) - if err != nil && writtenBytes != 0 { - _ = session.Exit(0) - return - } - time.Sleep(TerminalBackoffDelay) - } - } -} - -// Start starts SSH server. Blocking -func (srv *DefaultServer) Start() error { - log.Infof("starting SSH server on addr: %s", srv.listener.Addr().String()) - - publicKeyOption := ssh.PublicKeyAuth(srv.publicKeyHandler) - hostKeyPEM := ssh.HostKeyPEM(srv.hostKeyPEM) - err := ssh.Serve(srv.listener, srv.sessionHandler, publicKeyOption, hostKeyPEM) - if err != nil { - return err - } - - return nil -} - -func getUserShell(userID string) string { - if runtime.GOOS == "linux" { - output, _ := exec.Command("getent", "passwd", userID).Output() - line := strings.SplitN(string(output), ":", 10) - if len(line) > 6 { - return strings.TrimSpace(line[6]) - } - } - - shell := os.Getenv("SHELL") - if shell == "" { - shell = "/bin/sh" - } - return shell -} diff --git a/client/ssh/server/command_execution.go b/client/ssh/server/command_execution.go new file mode 100644 index 000000000..7a01ce4f6 --- /dev/null +++ b/client/ssh/server/command_execution.go @@ -0,0 +1,206 @@ +package server + +import ( + "errors" + "fmt" + "io" + "os" + "os/exec" + "time" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// handleCommand executes an SSH command with privilege validation +func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) { + hasPty := winCh != nil + + commandType := "command" + if hasPty { + commandType = "Pty command" + } + + logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command())) + + execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty) + if err != nil { + logger.Errorf("%s creation failed: %v", commandType, err) + + errorMsg := fmt.Sprintf("Cannot create %s - platform may not support user switching", commandType) + if hasPty { + errorMsg += " with Pty" + } + errorMsg += "\n" + + if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil { + logger.Debugf(errWriteSession, writeErr) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return + } + + if !hasPty { + if s.executeCommand(logger, session, execCmd, cleanup) { + logger.Debugf("%s execution completed", commandType) + } + return + } + + defer cleanup() + + ptyReq, _, _ := session.Pty() + if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) { + logger.Debugf("%s execution completed", commandType) + } +} + +func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) { + localUser := privilegeResult.User + if localUser == nil { + return nil, nil, errors.New("no user in privilege result") + } + + // If PTY requested but su doesn't support --pty, skip su and use executor + // This ensures PTY functionality is provided (executor runs within our allocated PTY) + if hasPty && !s.suSupportsPty { + log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality") + cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) + if err != nil { + return nil, nil, fmt.Errorf("create command with privileges: %w", err) + } + cmd.Env = s.prepareCommandEnv(localUser, session) + return cmd, cleanup, nil + } + + // Try su first for system integration (PAM/audit) when privileged + cmd, err := s.createSuCommand(session, localUser, hasPty) + if err != nil || privilegeResult.UsedFallback { + log.Debugf("su command failed, falling back to executor: %v", err) + cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) + if err != nil { + return nil, nil, fmt.Errorf("create command with privileges: %w", err) + } + cmd.Env = s.prepareCommandEnv(localUser, session) + return cmd, cleanup, nil + } + + cmd.Env = s.prepareCommandEnv(localUser, session) + return cmd, func() {}, nil +} + +// executeCommand executes the command and handles I/O and exit codes +func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, cleanup func()) bool { + defer cleanup() + + s.setupProcessGroup(execCmd) + + stdinPipe, err := execCmd.StdinPipe() + if err != nil { + logger.Errorf("create stdin pipe: %v", err) + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + execCmd.Stdout = session + execCmd.Stderr = session.Stderr() + + if execCmd.Dir != "" { + if _, err := os.Stat(execCmd.Dir); err != nil { + logger.Warnf("working directory does not exist: %s (%v)", execCmd.Dir, err) + execCmd.Dir = "/" + } + } + + if err := execCmd.Start(); err != nil { + logger.Errorf("command start failed: %v", err) + // no user message for exec failure, just exit + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + go s.handleCommandIO(logger, stdinPipe, session) + return s.waitForCommandCleanup(logger, session, execCmd) +} + +// handleCommandIO manages stdin/stdout copying in a goroutine +func (s *Server) handleCommandIO(logger *log.Entry, stdinPipe io.WriteCloser, session ssh.Session) { + defer func() { + if err := stdinPipe.Close(); err != nil { + logger.Debugf("stdin pipe close error: %v", err) + } + }() + if _, err := io.Copy(stdinPipe, session); err != nil { + logger.Debugf("stdin copy error: %v", err) + } +} + +// waitForCommandCleanup waits for command completion with session disconnect handling +func (s *Server) waitForCommandCleanup(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool { + ctx := session.Context() + done := make(chan error, 1) + go func() { + done <- execCmd.Wait() + }() + + select { + case <-ctx.Done(): + logger.Debugf("session cancelled, terminating command") + s.killProcessGroup(execCmd) + + select { + case err := <-done: + logger.Tracef("command terminated after session cancellation: %v", err) + case <-time.After(5 * time.Second): + logger.Warnf("command did not terminate within 5 seconds after session cancellation") + } + + if err := session.Exit(130); err != nil { + logSessionExitError(logger, err) + } + return false + + case err := <-done: + return s.handleCommandCompletion(logger, session, err) + } +} + +// handleCommandCompletion handles command completion +func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, err error) bool { + if err != nil { + logger.Debugf("command execution failed: %v", err) + s.handleSessionExit(session, err, logger) + return false + } + + s.handleSessionExit(session, nil, logger) + return true +} + +// handleSessionExit handles command errors and sets appropriate exit codes +func (s *Server) handleSessionExit(session ssh.Session, err error, logger *log.Entry) { + if err == nil { + if err := session.Exit(0); err != nil { + logSessionExitError(logger, err) + } + return + } + + var exitError *exec.ExitError + if errors.As(err, &exitError) { + if err := session.Exit(exitError.ExitCode()); err != nil { + logSessionExitError(logger, err) + } + } else { + logger.Debugf("non-exit error in command execution: %v", err) + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + } +} diff --git a/client/ssh/server/command_execution_js.go b/client/ssh/server/command_execution_js.go new file mode 100644 index 000000000..6473f8273 --- /dev/null +++ b/client/ssh/server/command_execution_js.go @@ -0,0 +1,52 @@ +//go:build js + +package server + +import ( + "context" + "errors" + "os/exec" + "os/user" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform") + +// createSuCommand is not supported on JS/WASM +func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) { + return nil, errNotSupported +} + +// createExecutorCommand is not supported on JS/WASM +func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) { + return nil, nil, errNotSupported +} + +// prepareCommandEnv is not supported on JS/WASM +func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string { + return nil +} + +// setupProcessGroup is not supported on JS/WASM +func (s *Server) setupProcessGroup(_ *exec.Cmd) { +} + +// killProcessGroup is not supported on JS/WASM +func (s *Server) killProcessGroup(*exec.Cmd) { +} + +// detectSuPtySupport always returns false on JS/WASM +func (s *Server) detectSuPtySupport(context.Context) bool { + return false +} + +// executeCommandWithPty is not supported on JS/WASM +func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + logger.Errorf("PTY command execution not supported on JS/WASM") + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false +} diff --git a/client/ssh/server/command_execution_unix.go b/client/ssh/server/command_execution_unix.go new file mode 100644 index 000000000..da059fed9 --- /dev/null +++ b/client/ssh/server/command_execution_unix.go @@ -0,0 +1,329 @@ +//go:build unix + +package server + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "os/exec" + "os/user" + "strings" + "sync" + "syscall" + "time" + + "github.com/creack/pty" + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// ptyManager manages Pty file operations with thread safety +type ptyManager struct { + file *os.File + mu sync.RWMutex + closed bool + closeErr error + once sync.Once +} + +func newPtyManager(file *os.File) *ptyManager { + return &ptyManager{file: file} +} + +func (pm *ptyManager) Close() error { + pm.once.Do(func() { + pm.mu.Lock() + pm.closed = true + pm.closeErr = pm.file.Close() + pm.mu.Unlock() + }) + pm.mu.RLock() + defer pm.mu.RUnlock() + return pm.closeErr +} + +func (pm *ptyManager) Setsize(ws *pty.Winsize) error { + pm.mu.RLock() + defer pm.mu.RUnlock() + if pm.closed { + return errors.New("pty is closed") + } + return pty.Setsize(pm.file, ws) +} + +func (pm *ptyManager) File() *os.File { + return pm.file +} + +// detectSuPtySupport checks if su supports the --pty flag +func (s *Server) detectSuPtySupport(ctx context.Context) bool { + ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + + cmd := exec.CommandContext(ctx, "su", "--help") + output, err := cmd.CombinedOutput() + if err != nil { + log.Debugf("su --help failed (may not support --help): %v", err) + return false + } + + supported := strings.Contains(string(output), "--pty") + log.Debugf("su --pty support detected: %v", supported) + return supported +} + +// createSuCommand creates a command using su -l -c for privilege switching +func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { + suPath, err := exec.LookPath("su") + if err != nil { + return nil, fmt.Errorf("su command not available: %w", err) + } + + command := session.RawCommand() + if command == "" { + return nil, fmt.Errorf("no command specified for su execution") + } + + args := []string{"-l"} + if hasPty && s.suSupportsPty { + args = append(args, "--pty") + } + args = append(args, localUser.Username, "-c", command) + + cmd := exec.CommandContext(session.Context(), suPath, args...) + cmd.Dir = localUser.HomeDir + + return cmd, nil +} + +// getShellCommandArgs returns the shell command and arguments for executing a command string +func (s *Server) getShellCommandArgs(shell, cmdString string) []string { + if cmdString == "" { + return []string{shell, "-l"} + } + return []string{shell, "-l", "-c", cmdString} +} + +// prepareCommandEnv prepares environment variables for command execution on Unix +func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { + env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) + env = append(env, prepareSSHEnv(session)...) + for _, v := range session.Environ() { + if acceptEnv(v) { + env = append(env, v) + } + } + return env +} + +// executeCommandWithPty executes a command with PTY allocation +func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + termType := ptyReq.Term + if termType == "" { + termType = "xterm-256color" + } + execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType)) + + return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) +} + +func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session) + if err != nil { + logger.Errorf("Pty command creation failed: %v", err) + errorMsg := "User switching failed - login command not available\r\n" + if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil { + logger.Debugf(errWriteSession, writeErr) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + logger.Infof("starting interactive shell: %s", execCmd.Path) + return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) +} + +// runPtyCommand runs a command with PTY management (common code for interactive and command execution) +func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq) + if err != nil { + logger.Errorf("Pty start failed: %v", err) + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + ptyMgr := newPtyManager(ptmx) + defer func() { + if err := ptyMgr.Close(); err != nil { + logger.Debugf("Pty close error: %v", err) + } + }() + + go s.handlePtyWindowResize(logger, session, ptyMgr, winCh) + s.handlePtyIO(logger, session, ptyMgr) + s.waitForPtyCompletion(logger, session, execCmd, ptyMgr) + return true +} + +func (s *Server) startPtyCommandWithSize(execCmd *exec.Cmd, ptyReq ssh.Pty) (*os.File, error) { + winSize := &pty.Winsize{ + Cols: uint16(ptyReq.Window.Width), + Rows: uint16(ptyReq.Window.Height), + } + if winSize.Cols == 0 { + winSize.Cols = 80 + } + if winSize.Rows == 0 { + winSize.Rows = 24 + } + + ptmx, err := pty.StartWithSize(execCmd, winSize) + if err != nil { + return nil, fmt.Errorf("start Pty: %w", err) + } + + return ptmx, nil +} + +func (s *Server) handlePtyWindowResize(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, winCh <-chan ssh.Window) { + for { + select { + case <-session.Context().Done(): + return + case win, ok := <-winCh: + if !ok { + return + } + if err := ptyMgr.Setsize(&pty.Winsize{Rows: uint16(win.Height), Cols: uint16(win.Width)}); err != nil { + logger.Debugf("Pty resize to %dx%d: %v", win.Width, win.Height, err) + } + } + } +} + +func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager) { + ptmx := ptyMgr.File() + + go func() { + if _, err := io.Copy(ptmx, session); err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { + logger.Warnf("Pty input copy error: %v", err) + } + } + }() + + go func() { + defer func() { + if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { + logger.Debugf("session close error: %v", err) + } + }() + if _, err := io.Copy(session, ptmx); err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { + logger.Warnf("Pty output copy error: %v", err) + } + } + }() +} + +func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager) { + ctx := session.Context() + done := make(chan error, 1) + go func() { + done <- execCmd.Wait() + }() + + select { + case <-ctx.Done(): + s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done) + case err := <-done: + s.handlePtyCommandCompletion(logger, session, err) + } +} + +func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager, done <-chan error) { + logger.Debugf("Pty session cancelled, terminating command") + if err := ptyMgr.Close(); err != nil { + logger.Debugf("Pty close during session cancellation: %v", err) + } + + s.killProcessGroup(execCmd) + + select { + case err := <-done: + if err != nil { + logger.Debugf("Pty command terminated after session cancellation with error: %v", err) + } else { + logger.Debugf("Pty command terminated after session cancellation") + } + case <-time.After(5 * time.Second): + logger.Warnf("Pty command did not terminate within 5 seconds after session cancellation") + } + + if err := session.Exit(130); err != nil { + logSessionExitError(logger, err) + } +} + +func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) { + if err != nil { + logger.Debugf("Pty command execution failed: %v", err) + s.handleSessionExit(session, err, logger) + return + } + + // Normal completion + logger.Debugf("Pty command completed successfully") + if err := session.Exit(0); err != nil { + logSessionExitError(logger, err) + } +} + +func (s *Server) setupProcessGroup(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } +} + +func (s *Server) killProcessGroup(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + + logger := log.WithField("pid", cmd.Process.Pid) + pgid := cmd.Process.Pid + + if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil { + logger.Debugf("kill process group SIGTERM: %v", err) + return + } + + const gracePeriod = 500 * time.Millisecond + const checkInterval = 50 * time.Millisecond + + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + timeout := time.After(gracePeriod) + + for { + select { + case <-timeout: + if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil { + logger.Debugf("kill process group SIGKILL: %v", err) + } + return + case <-ticker.C: + if err := syscall.Kill(-pgid, 0); err != nil { + return + } + } + } +} diff --git a/client/ssh/server/command_execution_windows.go b/client/ssh/server/command_execution_windows.go new file mode 100644 index 000000000..37b3ae0ee --- /dev/null +++ b/client/ssh/server/command_execution_windows.go @@ -0,0 +1,430 @@ +package server + +import ( + "context" + "fmt" + "os" + "os/exec" + "os/user" + "path/filepath" + "strings" + "unsafe" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + + "github.com/netbirdio/netbird/client/ssh/server/winpty" +) + +// getUserEnvironment retrieves the Windows environment for the target user. +// Follows OpenSSH's resilient approach with graceful degradation on failures. +func (s *Server) getUserEnvironment(username, domain string) ([]string, error) { + userToken, err := s.getUserToken(username, domain) + if err != nil { + return nil, fmt.Errorf("get user token: %w", err) + } + defer func() { + if err := windows.CloseHandle(userToken); err != nil { + log.Debugf("close user token: %v", err) + } + }() + + return s.getUserEnvironmentWithToken(userToken, username, domain) +} + +// getUserEnvironmentWithToken retrieves the Windows environment using an existing token. +func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) { + userProfile, err := s.loadUserProfile(userToken, username, domain) + if err != nil { + log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err) + userProfile = fmt.Sprintf("C:\\Users\\%s", username) + } + + envMap := make(map[string]string) + + if err := s.loadSystemEnvironment(envMap); err != nil { + log.Debugf("failed to load system environment from registry: %v", err) + } + + s.setUserEnvironmentVariables(envMap, userProfile, username, domain) + + var env []string + for key, value := range envMap { + env = append(env, key+"="+value) + } + + return env, nil +} + +// getUserToken creates a user token for the specified user. +func (s *Server) getUserToken(username, domain string) (windows.Handle, error) { + privilegeDropper := NewPrivilegeDropper() + token, err := privilegeDropper.createToken(username, domain) + if err != nil { + return 0, fmt.Errorf("generate S4U user token: %w", err) + } + return token, nil +} + +// loadUserProfile loads the Windows user profile and returns the profile path. +func (s *Server) loadUserProfile(userToken windows.Handle, username, domain string) (string, error) { + usernamePtr, err := windows.UTF16PtrFromString(username) + if err != nil { + return "", fmt.Errorf("convert username to UTF-16: %w", err) + } + + var domainUTF16 *uint16 + if domain != "" && domain != "." { + domainUTF16, err = windows.UTF16PtrFromString(domain) + if err != nil { + return "", fmt.Errorf("convert domain to UTF-16: %w", err) + } + } + + type profileInfo struct { + dwSize uint32 + dwFlags uint32 + lpUserName *uint16 + lpProfilePath *uint16 + lpDefaultPath *uint16 + lpServerName *uint16 + lpPolicyPath *uint16 + hProfile windows.Handle + } + + const PI_NOUI = 0x00000001 + + profile := profileInfo{ + dwSize: uint32(unsafe.Sizeof(profileInfo{})), + dwFlags: PI_NOUI, + lpUserName: usernamePtr, + lpServerName: domainUTF16, + } + + userenv := windows.NewLazySystemDLL("userenv.dll") + loadUserProfileW := userenv.NewProc("LoadUserProfileW") + + ret, _, err := loadUserProfileW.Call( + uintptr(userToken), + uintptr(unsafe.Pointer(&profile)), + ) + + if ret == 0 { + return "", fmt.Errorf("LoadUserProfileW: %w", err) + } + + if profile.lpProfilePath == nil { + return "", fmt.Errorf("LoadUserProfileW returned null profile path") + } + + profilePath := windows.UTF16PtrToString(profile.lpProfilePath) + return profilePath, nil +} + +// loadSystemEnvironment loads system-wide environment variables from registry. +func (s *Server) loadSystemEnvironment(envMap map[string]string) error { + key, err := registry.OpenKey(registry.LOCAL_MACHINE, + `SYSTEM\CurrentControlSet\Control\Session Manager\Environment`, + registry.QUERY_VALUE) + if err != nil { + return fmt.Errorf("open system environment registry key: %w", err) + } + defer func() { + if err := key.Close(); err != nil { + log.Debugf("close registry key: %v", err) + } + }() + + return s.readRegistryEnvironment(key, envMap) +} + +// readRegistryEnvironment reads environment variables from a registry key. +func (s *Server) readRegistryEnvironment(key registry.Key, envMap map[string]string) error { + names, err := key.ReadValueNames(0) + if err != nil { + return fmt.Errorf("read registry value names: %w", err) + } + + for _, name := range names { + value, valueType, err := key.GetStringValue(name) + if err != nil { + log.Debugf("failed to read registry value %s: %v", name, err) + continue + } + + finalValue := s.expandRegistryValue(value, valueType, name) + s.setEnvironmentVariable(envMap, name, finalValue) + } + + return nil +} + +// expandRegistryValue expands registry values if they contain environment variables. +func (s *Server) expandRegistryValue(value string, valueType uint32, name string) string { + if valueType != registry.EXPAND_SZ { + return value + } + + sourcePtr := windows.StringToUTF16Ptr(value) + expandedBuffer := make([]uint16, 1024) + expandedLen, err := windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer))) + if err != nil { + log.Debugf("failed to expand environment string for %s: %v", name, err) + return value + } + + // If buffer was too small, retry with larger buffer + if expandedLen > uint32(len(expandedBuffer)) { + expandedBuffer = make([]uint16, expandedLen) + expandedLen, err = windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer))) + if err != nil { + log.Debugf("failed to expand environment string for %s on retry: %v", name, err) + return value + } + } + + if expandedLen > 0 && expandedLen <= uint32(len(expandedBuffer)) { + return windows.UTF16ToString(expandedBuffer[:expandedLen-1]) + } + return value +} + +// setEnvironmentVariable sets an environment variable with special handling for PATH. +func (s *Server) setEnvironmentVariable(envMap map[string]string, name, value string) { + upperName := strings.ToUpper(name) + + if upperName == "PATH" { + if existing, exists := envMap["PATH"]; exists && existing != value { + envMap["PATH"] = existing + ";" + value + } else { + envMap["PATH"] = value + } + } else { + envMap[upperName] = value + } +} + +// setUserEnvironmentVariables sets critical user-specific environment variables. +func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfile, username, domain string) { + envMap["USERPROFILE"] = userProfile + + if len(userProfile) >= 2 && userProfile[1] == ':' { + envMap["HOMEDRIVE"] = userProfile[:2] + envMap["HOMEPATH"] = userProfile[2:] + } + + envMap["APPDATA"] = filepath.Join(userProfile, "AppData", "Roaming") + envMap["LOCALAPPDATA"] = filepath.Join(userProfile, "AppData", "Local") + + tempDir := filepath.Join(userProfile, "AppData", "Local", "Temp") + envMap["TEMP"] = tempDir + envMap["TMP"] = tempDir + + envMap["USERNAME"] = username + if domain != "" && domain != "." { + envMap["USERDOMAIN"] = domain + envMap["USERDNSDOMAIN"] = domain + } + + systemVars := []string{ + "PROCESSOR_ARCHITECTURE", "PROCESSOR_IDENTIFIER", "PROCESSOR_LEVEL", "PROCESSOR_REVISION", + "SYSTEMDRIVE", "SYSTEMROOT", "WINDIR", "COMPUTERNAME", "OS", "PATHEXT", + "PROGRAMFILES", "PROGRAMDATA", "ALLUSERSPROFILE", "COMSPEC", + } + + for _, sysVar := range systemVars { + if sysValue := os.Getenv(sysVar); sysValue != "" { + envMap[sysVar] = sysValue + } + } +} + +// prepareCommandEnv prepares environment variables for command execution on Windows +func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { + username, domain := s.parseUsername(localUser.Username) + userEnv, err := s.getUserEnvironment(username, domain) + if err != nil { + log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err) + env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) + env = append(env, prepareSSHEnv(session)...) + for _, v := range session.Environ() { + if acceptEnv(v) { + env = append(env, v) + } + } + return env + } + + env := userEnv + env = append(env, prepareSSHEnv(session)...) + for _, v := range session.Environ() { + if acceptEnv(v) { + env = append(env, v) + } + } + return env +} + +func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + if privilegeResult.User == nil { + logger.Errorf("no user in privilege result") + return false + } + + cmd := session.Command() + shell := getUserShell(privilegeResult.User.Uid) + + if len(cmd) == 0 { + logger.Infof("starting interactive shell: %s", shell) + } else { + logger.Infof("executing command: %s", safeLogCommand(cmd)) + } + + s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd) + return true +} + +// getShellCommandArgs returns the shell command and arguments for executing a command string +func (s *Server) getShellCommandArgs(shell, cmdString string) []string { + if cmdString == "" { + return []string{shell, "-NoLogo"} + } + return []string{shell, "-Command", cmdString} +} + +func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) { + logger.Info("starting interactive shell") + s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand()) +} + +type PtyExecutionRequest struct { + Shell string + Command string + Width int + Height int + Username string + Domain string +} + +func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error { + log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d", + req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height) + + privilegeDropper := NewPrivilegeDropper() + userToken, err := privilegeDropper.createToken(req.Username, req.Domain) + if err != nil { + return fmt.Errorf("create user token: %w", err) + } + defer func() { + if err := windows.CloseHandle(userToken); err != nil { + log.Debugf("close user token: %v", err) + } + }() + + server := &Server{} + userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain) + if err != nil { + log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err) + userEnv = os.Environ() + } + + workingDir := getUserHomeFromEnv(userEnv) + if workingDir == "" { + workingDir = fmt.Sprintf(`C:\Users\%s`, req.Username) + } + + ptyConfig := winpty.PtyConfig{ + Shell: req.Shell, + Command: req.Command, + Width: req.Width, + Height: req.Height, + WorkingDir: workingDir, + } + + userConfig := winpty.UserConfig{ + Token: userToken, + Environment: userEnv, + } + + log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir) + return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig) +} + +func getUserHomeFromEnv(env []string) string { + for _, envVar := range env { + if len(envVar) > 12 && envVar[:12] == "USERPROFILE=" { + return envVar[12:] + } + } + return "" +} + +func (s *Server) setupProcessGroup(_ *exec.Cmd) { + // Windows doesn't support process groups in the same way as Unix + // Process creation groups are handled differently +} + +func (s *Server) killProcessGroup(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + + logger := log.WithField("pid", cmd.Process.Pid) + + if err := cmd.Process.Kill(); err != nil { + logger.Debugf("kill process failed: %v", err) + } +} + +// detectSuPtySupport always returns false on Windows as su is not available +func (s *Server) detectSuPtySupport(context.Context) bool { + return false +} + +// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty +func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { + command := session.RawCommand() + if command == "" { + logger.Error("no command specified for PTY execution") + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command) +} + +// executeConPtyCommand executes a command using ConPty (common for interactive and command execution) +func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool { + localUser := privilegeResult.User + if localUser == nil { + logger.Errorf("no user in privilege result") + return false + } + + username, domain := s.parseUsername(localUser.Username) + shell := getUserShell(localUser.Uid) + + req := PtyExecutionRequest{ + Shell: shell, + Command: command, + Width: ptyReq.Window.Width, + Height: ptyReq.Window.Height, + Username: username, + Domain: domain, + } + + if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil { + logger.Errorf("ConPty execution failed: %v", err) + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false + } + + logger.Debug("ConPty execution completed") + return true +} diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go new file mode 100644 index 000000000..34ffccfd2 --- /dev/null +++ b/client/ssh/server/compatibility_test.go @@ -0,0 +1,722 @@ +package server + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "fmt" + "io" + "net" + "os" + "os/exec" + "runtime" + "strings" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/testutil" +) + +// TestMain handles package-level setup and cleanup +func TestMain(m *testing.M) { + // Guard against infinite recursion when test binary is called as "netbird ssh exec" + // This happens when running tests as non-privileged user with fallback + if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" { + // Just exit with error to break the recursion + fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n") + os.Exit(1) + } + + // Run tests + code := m.Run() + + // Cleanup any created test users + testutil.CleanupTestUsers() + + os.Exit(code) +} + +// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client +func TestSSHServerCompatibility(t *testing.T) { + if testing.Short() { + t.Skip("Skipping SSH compatibility tests in short mode") + } + + // Check if ssh binary is available + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + // Set up SSH server - use our existing key generation for server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Generate OpenSSH-compatible keys for client + clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Create temporary key files for SSH client + clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH) + defer cleanupKey() + + // Extract host and port from server address + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + // Get appropriate user for SSH connection (handle system accounts) + username := testutil.GetTestUsername(t) + + t.Run("basic command execution", func(t *testing.T) { + testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username) + }) + + t.Run("interactive command", func(t *testing.T) { + testSSHInteractiveCommand(t, host, portStr, clientKeyFile) + }) + + t.Run("port forwarding", func(t *testing.T) { + testSSHPortForwarding(t, host, portStr, clientKeyFile) + }) +} + +// testSSHCommandExecutionWithUser tests basic command execution with system SSH client using specified user. +func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username string) { + cmd := exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "echo", "hello_world") + + output, err := cmd.CombinedOutput() + + if err != nil { + t.Logf("SSH command failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + assert.Contains(t, string(output), "hello_world", "SSH command should execute successfully") +} + +// testSSHInteractiveCommand tests interactive shell session. +func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host)) + + stdin, err := cmd.StdinPipe() + if err != nil { + t.Skipf("Cannot create stdin pipe: %v", err) + return + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + t.Skipf("Cannot create stdout pipe: %v", err) + return + } + + err = cmd.Start() + if err != nil { + t.Logf("Cannot start SSH session: %v", err) + return + } + + go func() { + defer func() { + if err := stdin.Close(); err != nil { + t.Logf("stdin close error: %v", err) + } + }() + time.Sleep(100 * time.Millisecond) + if _, err := stdin.Write([]byte("echo interactive_test\n")); err != nil { + t.Logf("stdin write error: %v", err) + } + time.Sleep(100 * time.Millisecond) + if _, err := stdin.Write([]byte("exit\n")); err != nil { + t.Logf("stdin write error: %v", err) + } + }() + + output, err := io.ReadAll(stdout) + if err != nil { + t.Logf("Cannot read SSH output: %v", err) + } + + err = cmd.Wait() + if err != nil { + t.Logf("SSH interactive session error: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + assert.Contains(t, string(output), "interactive_test", "Interactive SSH session should work") +} + +// testSSHPortForwarding tests port forwarding compatibility. +func testSSHPortForwarding(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + testServer, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer testServer.Close() + + testServerAddr := testServer.Addr().String() + expectedResponse := "HTTP/1.1 200 OK\r\nContent-Length: 21\r\n\r\nCompatibility Test OK" + + go func() { + for { + conn, err := testServer.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer func() { + if err := c.Close(); err != nil { + t.Logf("test server connection close error: %v", err) + } + }() + buf := make([]byte, 1024) + if _, err := c.Read(buf); err != nil { + t.Logf("Test server read error: %v", err) + } + if _, err := c.Write([]byte(expectedResponse)); err != nil { + t.Logf("Test server write error: %v", err) + } + }(conn) + } + }() + + localListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + localAddr := localListener.Addr().String() + localListener.Close() + + _, localPort, err := net.SplitHostPort(localAddr) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + forwardSpec := fmt.Sprintf("%s:%s", localPort, testServerAddr) + cmd := exec.CommandContext(ctx, "ssh", + "-i", keyFile, + "-p", port, + "-L", forwardSpec, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-N", + fmt.Sprintf("%s@%s", username, host)) + + err = cmd.Start() + if err != nil { + t.Logf("Cannot start SSH port forwarding: %v", err) + return + } + + defer func() { + if cmd.Process != nil { + if err := cmd.Process.Kill(); err != nil { + t.Logf("process kill error: %v", err) + } + } + if err := cmd.Wait(); err != nil { + t.Logf("process wait after kill: %v", err) + } + }() + + time.Sleep(500 * time.Millisecond) + + conn, err := net.DialTimeout("tcp", localAddr, 3*time.Second) + if err != nil { + t.Logf("Cannot connect to forwarded port: %v", err) + return + } + defer func() { + if err := conn.Close(); err != nil { + t.Logf("forwarded connection close error: %v", err) + } + }() + + request := "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" + _, err = conn.Write([]byte(request)) + require.NoError(t, err) + + if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil { + log.Debugf("failed to set read deadline: %v", err) + } + response := make([]byte, len(expectedResponse)) + n, err := io.ReadFull(conn, response) + if err != nil { + t.Logf("Cannot read forwarded response: %v", err) + return + } + + assert.Equal(t, len(expectedResponse), n, "Should read expected number of bytes") + assert.Equal(t, expectedResponse, string(response), "Should get correct HTTP response through SSH port forwarding") +} + +// isSSHClientAvailable checks if the ssh binary is available +func isSSHClientAvailable() bool { + _, err := exec.LookPath("ssh") + return err == nil +} + +// generateOpenSSHKey generates an ED25519 key in OpenSSH format that the system SSH client can use. +func generateOpenSSHKey(t *testing.T) ([]byte, []byte, error) { + // Check if ssh-keygen is available + if _, err := exec.LookPath("ssh-keygen"); err != nil { + // Fall back to our existing key generation and try to convert + return generateOpenSSHKeyFallback() + } + + // Create temporary file for ssh-keygen + tempFile, err := os.CreateTemp("", "ssh_keygen_*") + if err != nil { + return nil, nil, fmt.Errorf("create temp file: %w", err) + } + keyPath := tempFile.Name() + tempFile.Close() + + // Remove the temp file so ssh-keygen can create it + if err := os.Remove(keyPath); err != nil { + t.Logf("failed to remove key file: %v", err) + } + + // Clean up temp files + defer func() { + if err := os.Remove(keyPath); err != nil { + t.Logf("failed to cleanup key file: %v", err) + } + if err := os.Remove(keyPath + ".pub"); err != nil { + t.Logf("failed to cleanup public key file: %v", err) + } + }() + + // Generate key using ssh-keygen + cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", keyPath, "-N", "", "-q") + output, err := cmd.CombinedOutput() + if err != nil { + return nil, nil, fmt.Errorf("ssh-keygen failed: %w, output: %s", err, string(output)) + } + + // Read private key + privKeyBytes, err := os.ReadFile(keyPath) + if err != nil { + return nil, nil, fmt.Errorf("read private key: %w", err) + } + + // Read public key + pubKeyBytes, err := os.ReadFile(keyPath + ".pub") + if err != nil { + return nil, nil, fmt.Errorf("read public key: %w", err) + } + + return privKeyBytes, pubKeyBytes, nil +} + +// generateOpenSSHKeyFallback falls back to generating keys using our existing method +func generateOpenSSHKeyFallback() ([]byte, []byte, error) { + // Generate shared.ED25519 key pair using our existing method + _, privKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("generate key: %w", err) + } + + // Convert to SSH format + sshPrivKey, err := ssh.NewSignerFromKey(privKey) + if err != nil { + return nil, nil, fmt.Errorf("create signer: %w", err) + } + + // For the fallback, just use our PKCS#8 format and hope it works + // This won't be in OpenSSH format but might still work with some SSH clients + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + if err != nil { + return nil, nil, fmt.Errorf("generate fallback key: %w", err) + } + + // Get public key in SSH format + sshPubKey := ssh.MarshalAuthorizedKey(sshPrivKey.PublicKey()) + + return hostKey, sshPubKey, nil +} + +// createTempKeyFileFromBytes creates a temporary SSH private key file from raw bytes +func createTempKeyFileFromBytes(t *testing.T, keyBytes []byte) (string, func()) { + t.Helper() + + tempFile, err := os.CreateTemp("", "ssh_test_key_*") + require.NoError(t, err) + + _, err = tempFile.Write(keyBytes) + require.NoError(t, err) + + err = tempFile.Close() + require.NoError(t, err) + + // Set proper permissions for SSH key (readable by owner only) + err = os.Chmod(tempFile.Name(), 0600) + require.NoError(t, err) + + cleanup := func() { + _ = os.Remove(tempFile.Name()) + } + + return tempFile.Name(), cleanup +} + +// createTempKeyFile creates a temporary SSH private key file (for backward compatibility) +func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) { + return createTempKeyFileFromBytes(t, privateKey) +} + +// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility +func TestSSHServerFeatureCompatibility(t *testing.T) { + if testing.Short() { + t.Skip("Skipping SSH feature compatibility tests in short mode") + } + + if runtime.GOOS == "windows" && testutil.IsCI() { + t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues") + } + + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + // Test various SSH features + testCases := []struct { + name string + testFunc func(t *testing.T, host, port, keyFile string) + description string + }{ + { + name: "command_with_flags", + testFunc: testCommandWithFlags, + description: "Commands with flags should work like standard SSH", + }, + { + name: "environment_variables", + testFunc: testEnvironmentVariables, + description: "Environment variables should be available", + }, + { + name: "exit_codes", + testFunc: testExitCodes, + description: "Exit codes should be properly handled", + }, + } + + // Set up SSH server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey) + defer cleanupKey() + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.testFunc(t, host, portStr, clientKeyFile) + }) + } +} + +// testCommandWithFlags tests that commands with flags work properly +func testCommandWithFlags(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + // Test ls with flags + cmd := exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "ls", "-la", "/tmp") + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Command with flags failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + // Should not be empty and should not contain error messages + assert.NotEmpty(t, string(output), "ls -la should produce output") + assert.NotContains(t, strings.ToLower(string(output)), "command not found", "Command should be executed") +} + +// testEnvironmentVariables tests that environment is properly set up +func testEnvironmentVariables(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + cmd := exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "echo", "$HOME") + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Environment test failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + // HOME environment variable should be available + homeOutput := strings.TrimSpace(string(output)) + assert.NotEmpty(t, homeOutput, "HOME environment variable should be set") + assert.NotEqual(t, "$HOME", homeOutput, "Environment variable should be expanded") +} + +// testExitCodes tests that exit codes are properly handled +func testExitCodes(t *testing.T, host, port, keyFile string) { + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + // Test successful command (exit code 0) + cmd := exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "true") // always succeeds + + err := cmd.Run() + assert.NoError(t, err, "Command with exit code 0 should succeed") + + // Test failing command (exit code 1) + cmd = exec.Command("ssh", + "-i", keyFile, + "-p", port, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + "false") // always fails + + err = cmd.Run() + assert.Error(t, err, "Command with exit code 1 should fail") + + // Check if it's the right kind of error + if exitError, ok := err.(*exec.ExitError); ok { + assert.Equal(t, 1, exitError.ExitCode(), "Exit code should be preserved") + } +} + +// TestSSHServerSecurityFeatures tests security-related SSH features +func TestSSHServerSecurityFeatures(t *testing.T) { + if testing.Short() { + t.Skip("Skipping SSH security tests in short mode") + } + + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + // Set up SSH server with specific security settings + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey) + defer cleanupKey() + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + t.Run("key_authentication", func(t *testing.T) { + // Test that key authentication works + cmd := exec.Command("ssh", + "-i", clientKeyFile, + "-p", portStr, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-o", "PasswordAuthentication=no", + fmt.Sprintf("%s@%s", username, host), + "echo", "auth_success") + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Key authentication failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + assert.Contains(t, string(output), "auth_success", "Key authentication should work") + }) + + t.Run("any_key_accepted_in_no_auth_mode", func(t *testing.T) { + // Create a different key that shouldn't be accepted + wrongKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + wrongKeyFile, cleanupWrongKey := createTempKeyFile(t, wrongKey) + defer cleanupWrongKey() + + // Test that wrong key is rejected + cmd := exec.Command("ssh", + "-i", wrongKeyFile, + "-p", portStr, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-o", "PasswordAuthentication=no", + fmt.Sprintf("%s@%s", username, host), + "echo", "should_not_work") + + err = cmd.Run() + assert.NoError(t, err, "Any key should work in no-auth mode") + }) +} + +// TestCrossPlatformCompatibility tests cross-platform behavior +func TestCrossPlatformCompatibility(t *testing.T) { + if testing.Short() { + t.Skip("Skipping cross-platform compatibility tests in short mode") + } + + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + // Get appropriate user for SSH connection + username := testutil.GetTestUsername(t) + + // Set up SSH server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey) + defer cleanupKey() + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + // Test platform-specific commands + var testCommand string + + switch runtime.GOOS { + case "windows": + testCommand = "echo %OS%" + default: + testCommand = "uname" + } + + cmd := exec.Command("ssh", + "-i", clientKeyFile, + "-p", portStr, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + fmt.Sprintf("%s@%s", username, host), + testCommand) + + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Platform-specific command failed: %v", err) + t.Logf("Output: %s", string(output)) + return + } + + outputStr := strings.TrimSpace(string(output)) + t.Logf("Platform command output: %s", outputStr) + assert.NotEmpty(t, outputStr, "Platform-specific command should produce output") +} diff --git a/client/ssh/server/executor_unix.go b/client/ssh/server/executor_unix.go new file mode 100644 index 000000000..8adc824ef --- /dev/null +++ b/client/ssh/server/executor_unix.go @@ -0,0 +1,253 @@ +//go:build unix + +package server + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "runtime" + "strings" + "syscall" + + log "github.com/sirupsen/logrus" +) + +// Exit codes for executor process communication +const ( + ExitCodeSuccess = 0 + ExitCodePrivilegeDropFail = 10 + ExitCodeShellExecFail = 11 + ExitCodeValidationFail = 12 +) + +// ExecutorConfig holds configuration for the executor process +type ExecutorConfig struct { + UID uint32 + GID uint32 + Groups []uint32 + WorkingDir string + Shell string + Command string + PTY bool +} + +// PrivilegeDropper handles secure privilege dropping in child processes +type PrivilegeDropper struct{} + +// NewPrivilegeDropper creates a new privilege dropper +func NewPrivilegeDropper() *PrivilegeDropper { + return &PrivilegeDropper{} +} + +// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping +func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config ExecutorConfig) (*exec.Cmd, error) { + netbirdPath, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("get netbird executable path: %w", err) + } + + if err := pd.validatePrivileges(config.UID, config.GID); err != nil { + return nil, fmt.Errorf("invalid privileges: %w", err) + } + + args := []string{ + "ssh", "exec", + "--uid", fmt.Sprintf("%d", config.UID), + "--gid", fmt.Sprintf("%d", config.GID), + "--working-dir", config.WorkingDir, + "--shell", config.Shell, + } + + for _, group := range config.Groups { + args = append(args, "--groups", fmt.Sprintf("%d", group)) + } + + if config.PTY { + args = append(args, "--pty") + } + + if config.Command != "" { + args = append(args, "--cmd", config.Command) + } + + // Log executor args safely - show all args except hide the command value + safeArgs := make([]string, len(args)) + copy(safeArgs, args) + for i := 0; i < len(safeArgs)-1; i++ { + if safeArgs[i] == "--cmd" { + cmdParts := strings.Fields(safeArgs[i+1]) + safeArgs[i+1] = safeLogCommand(cmdParts) + break + } + } + log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs) + return exec.CommandContext(ctx, netbirdPath, args...), nil +} + +// DropPrivileges performs privilege dropping with thread locking for security +func (pd *PrivilegeDropper) DropPrivileges(targetUID, targetGID uint32, supplementaryGroups []uint32) error { + if err := pd.validatePrivileges(targetUID, targetGID); err != nil { + return fmt.Errorf("invalid privileges: %w", err) + } + + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + originalUID := os.Geteuid() + originalGID := os.Getegid() + + if originalUID != int(targetUID) || originalGID != int(targetGID) { + if err := pd.setGroupsAndIDs(targetUID, targetGID, supplementaryGroups); err != nil { + return fmt.Errorf("set groups and IDs: %w", err) + } + } + + if err := pd.validatePrivilegeDropSuccess(targetUID, targetGID, originalUID, originalGID); err != nil { + return err + } + + log.Tracef("successfully dropped privileges to UID=%d, GID=%d", targetUID, targetGID) + return nil +} + +// setGroupsAndIDs sets the supplementary groups, GID, and UID +func (pd *PrivilegeDropper) setGroupsAndIDs(targetUID, targetGID uint32, supplementaryGroups []uint32) error { + groups := make([]int, len(supplementaryGroups)) + for i, g := range supplementaryGroups { + groups[i] = int(g) + } + + if runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" { + if len(groups) == 0 || groups[0] != int(targetGID) { + groups = append([]int{int(targetGID)}, groups...) + } + } + + if err := syscall.Setgroups(groups); err != nil { + return fmt.Errorf("setgroups to %v: %w", groups, err) + } + + if err := syscall.Setgid(int(targetGID)); err != nil { + return fmt.Errorf("setgid to %d: %w", targetGID, err) + } + + if err := syscall.Setuid(int(targetUID)); err != nil { + return fmt.Errorf("setuid to %d: %w", targetUID, err) + } + + return nil +} + +// validatePrivilegeDropSuccess validates that privilege dropping was successful +func (pd *PrivilegeDropper) validatePrivilegeDropSuccess(targetUID, targetGID uint32, originalUID, originalGID int) error { + if err := pd.validatePrivilegeDropReversibility(targetUID, targetGID, originalUID, originalGID); err != nil { + return err + } + + if err := pd.validateCurrentPrivileges(targetUID, targetGID); err != nil { + return err + } + + return nil +} + +// validatePrivilegeDropReversibility ensures privileges cannot be restored +func (pd *PrivilegeDropper) validatePrivilegeDropReversibility(targetUID, targetGID uint32, originalUID, originalGID int) error { + if originalGID != int(targetGID) { + if err := syscall.Setegid(originalGID); err == nil { + return fmt.Errorf("privilege drop validation failed: able to restore original GID %d", originalGID) + } + } + if originalUID != int(targetUID) { + if err := syscall.Seteuid(originalUID); err == nil { + return fmt.Errorf("privilege drop validation failed: able to restore original UID %d", originalUID) + } + } + return nil +} + +// validateCurrentPrivileges validates the current UID and GID match the target +func (pd *PrivilegeDropper) validateCurrentPrivileges(targetUID, targetGID uint32) error { + currentUID := os.Geteuid() + if currentUID != int(targetUID) { + return fmt.Errorf("privilege drop validation failed: current UID %d, expected %d", currentUID, targetUID) + } + + currentGID := os.Getegid() + if currentGID != int(targetGID) { + return fmt.Errorf("privilege drop validation failed: current GID %d, expected %d", currentGID, targetGID) + } + + return nil +} + +// ExecuteWithPrivilegeDrop executes a command with privilege dropping, using exit codes to signal specific failures +func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config ExecutorConfig) { + log.Tracef("dropping privileges to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups) + + // TODO: Implement Pty support for executor path + if config.PTY { + config.PTY = false + } + + if err := pd.DropPrivileges(config.UID, config.GID, config.Groups); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "privilege drop failed: %v\n", err) + os.Exit(ExitCodePrivilegeDropFail) + } + + if config.WorkingDir != "" { + if err := os.Chdir(config.WorkingDir); err != nil { + log.Debugf("failed to change to working directory %s, continuing with current directory: %v", config.WorkingDir, err) + } + } + + var execCmd *exec.Cmd + if config.Command == "" { + os.Exit(ExitCodeSuccess) + } + + execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command) + execCmd.Stdin = os.Stdin + execCmd.Stdout = os.Stdout + execCmd.Stderr = os.Stderr + + cmdParts := strings.Fields(config.Command) + safeCmd := safeLogCommand(cmdParts) + log.Tracef("executing %s -c %s", execCmd.Path, safeCmd) + if err := execCmd.Run(); err != nil { + var exitError *exec.ExitError + if errors.As(err, &exitError) { + // Normal command exit with non-zero code - not an SSH execution error + log.Tracef("command exited with code %d", exitError.ExitCode()) + os.Exit(exitError.ExitCode()) + } + + // Actual execution failure (command not found, permission denied, etc.) + log.Debugf("command execution failed: %v", err) + os.Exit(ExitCodeShellExecFail) + } + + os.Exit(ExitCodeSuccess) +} + +// validatePrivileges validates that privilege dropping to the target UID/GID is allowed +func (pd *PrivilegeDropper) validatePrivileges(uid, gid uint32) error { + currentUID := uint32(os.Geteuid()) + currentGID := uint32(os.Getegid()) + + // Allow same-user operations (no privilege dropping needed) + if uid == currentUID && gid == currentGID { + return nil + } + + // Only root can drop privileges to other users + if currentUID != 0 { + return fmt.Errorf("cannot drop privileges from non-root user (UID %d) to UID %d", currentUID, uid) + } + + // Root can drop to any user (including root itself) + return nil +} diff --git a/client/ssh/server/executor_unix_test.go b/client/ssh/server/executor_unix_test.go new file mode 100644 index 000000000..0c5108f57 --- /dev/null +++ b/client/ssh/server/executor_unix_test.go @@ -0,0 +1,262 @@ +//go:build unix + +package server + +import ( + "context" + "fmt" + "os" + "os/exec" + "os/user" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) { + pd := NewPrivilegeDropper() + + currentUID := uint32(os.Geteuid()) + currentGID := uint32(os.Getegid()) + + tests := []struct { + name string + uid uint32 + gid uint32 + wantErr bool + }{ + { + name: "same user - no privilege drop needed", + uid: currentUID, + gid: currentGID, + wantErr: false, + }, + { + name: "non-root to different user should fail", + uid: currentUID + 1, // Use a different UID to ensure it's actually different + gid: currentGID + 1, // Use a different GID to ensure it's actually different + wantErr: currentUID != 0, // Only fail if current user is not root + }, + { + name: "root can drop to any user", + uid: 1000, + gid: 1000, + wantErr: false, + }, + { + name: "root can stay as root", + uid: 0, + gid: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip non-root tests when running as root, and root tests when not root + if tt.name == "non-root to different user should fail" && currentUID == 0 { + t.Skip("Skipping non-root test when running as root") + } + if (tt.name == "root can drop to any user" || tt.name == "root can stay as root") && currentUID != 0 { + t.Skip("Skipping root test when not running as root") + } + + err := pd.validatePrivileges(tt.uid, tt.gid) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) { + pd := NewPrivilegeDropper() + + config := ExecutorConfig{ + UID: 1000, + GID: 1000, + Groups: []uint32{1000, 1001}, + WorkingDir: "/home/testuser", + Shell: "/bin/bash", + Command: "ls -la", + } + + cmd, err := pd.CreateExecutorCommand(context.Background(), config) + require.NoError(t, err) + require.NotNil(t, cmd) + + // Verify the command is calling netbird ssh exec + assert.Contains(t, cmd.Args, "ssh") + assert.Contains(t, cmd.Args, "exec") + assert.Contains(t, cmd.Args, "--uid") + assert.Contains(t, cmd.Args, "1000") + assert.Contains(t, cmd.Args, "--gid") + assert.Contains(t, cmd.Args, "1000") + assert.Contains(t, cmd.Args, "--groups") + assert.Contains(t, cmd.Args, "1000") + assert.Contains(t, cmd.Args, "1001") + assert.Contains(t, cmd.Args, "--working-dir") + assert.Contains(t, cmd.Args, "/home/testuser") + assert.Contains(t, cmd.Args, "--shell") + assert.Contains(t, cmd.Args, "/bin/bash") + assert.Contains(t, cmd.Args, "--cmd") + assert.Contains(t, cmd.Args, "ls -la") +} + +func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) { + pd := NewPrivilegeDropper() + + config := ExecutorConfig{ + UID: 1000, + GID: 1000, + Groups: []uint32{1000}, + WorkingDir: "/home/testuser", + Shell: "/bin/bash", + Command: "", + } + + cmd, err := pd.CreateExecutorCommand(context.Background(), config) + require.NoError(t, err) + require.NotNil(t, cmd) + + // Verify no command mode (command is empty so no --cmd flag) + assert.NotContains(t, cmd.Args, "--cmd") + assert.NotContains(t, cmd.Args, "--interactive") +} + +// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping +// This test requires root privileges and will be skipped if not running as root +func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip("This test requires root privileges") + } + + // Find a non-root user to test with + testUser, err := findNonRootUser() + if err != nil { + t.Skip("No suitable non-root user found for testing") + } + + // Verify the user actually exists by looking it up again + _, err = user.LookupId(testUser.Uid) + if err != nil { + t.Skipf("Test user %s (UID %s) does not exist on this system: %v", testUser.Username, testUser.Uid, err) + } + + uid64, err := strconv.ParseUint(testUser.Uid, 10, 32) + require.NoError(t, err) + targetUID := uint32(uid64) + + gid64, err := strconv.ParseUint(testUser.Gid, 10, 32) + require.NoError(t, err) + targetGID := uint32(gid64) + + // Test in a child process to avoid affecting the test runner + if os.Getenv("TEST_PRIVILEGE_DROP") == "1" { + pd := NewPrivilegeDropper() + + // This should succeed + err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID}) + require.NoError(t, err) + + // Verify we are now running as the target user + currentUID := uint32(os.Geteuid()) + currentGID := uint32(os.Getegid()) + + assert.Equal(t, targetUID, currentUID, "UID should match target") + assert.Equal(t, targetGID, currentGID, "GID should match target") + assert.NotEqual(t, uint32(0), currentUID, "Should not be running as root") + assert.NotEqual(t, uint32(0), currentGID, "Should not be running as root group") + + return + } + + // Fork a child process to test privilege dropping + cmd := os.Args[0] + args := []string{"-test.run=TestPrivilegeDropper_ActualPrivilegeDrop"} + + env := append(os.Environ(), "TEST_PRIVILEGE_DROP=1") + + execCmd := exec.Command(cmd, args...) + execCmd.Env = env + + err = execCmd.Run() + require.NoError(t, err, "Child process should succeed") +} + +// findNonRootUser finds any non-root user on the system for testing +func findNonRootUser() (*user.User, error) { + // Try common non-root users, but avoid "nobody" on macOS due to negative UID issues + commonUsers := []string{"daemon", "bin", "sys", "sync", "games", "man", "lp", "mail", "news", "uucp", "proxy", "www-data", "backup", "list", "irc"} + + for _, username := range commonUsers { + if u, err := user.Lookup(username); err == nil { + // Parse as signed integer first to handle negative UIDs + uid64, err := strconv.ParseInt(u.Uid, 10, 32) + if err != nil { + continue + } + // Skip negative UIDs (like nobody=-2 on macOS) and root + if uid64 > 0 && uid64 != 0 { + return u, nil + } + } + } + + // If no common users found, try to find any regular user with UID > 100 + // This helps on macOS where regular users start at UID 501 + allUsers := []string{"vma", "user", "test", "admin"} + for _, username := range allUsers { + if u, err := user.Lookup(username); err == nil { + uid64, err := strconv.ParseInt(u.Uid, 10, 32) + if err != nil { + continue + } + if uid64 > 100 { // Regular user + return u, nil + } + } + } + + // If no common users found, return an error + return nil, fmt.Errorf("no suitable non-root user found on this system") +} + +func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) { + pd := NewPrivilegeDropper() + currentUID := uint32(os.Geteuid()) + + if currentUID == 0 { + // When running as root, test that root can create commands for any user + config := ExecutorConfig{ + UID: 1000, // Target non-root user + GID: 1000, + Groups: []uint32{1000}, + WorkingDir: "/tmp", + Shell: "/bin/sh", + Command: "echo test", + } + + cmd, err := pd.CreateExecutorCommand(context.Background(), config) + assert.NoError(t, err, "Root should be able to create commands for any user") + assert.NotNil(t, cmd) + } else { + // When running as non-root, test that we can't drop to a different user + config := ExecutorConfig{ + UID: 0, // Try to target root + GID: 0, + Groups: []uint32{0}, + WorkingDir: "/tmp", + Shell: "/bin/sh", + Command: "echo test", + } + + _, err := pd.CreateExecutorCommand(context.Background(), config) + assert.Error(t, err) + assert.Contains(t, err.Error(), "cannot drop privileges") + } +} diff --git a/client/ssh/server/executor_windows.go b/client/ssh/server/executor_windows.go new file mode 100644 index 000000000..d3504e056 --- /dev/null +++ b/client/ssh/server/executor_windows.go @@ -0,0 +1,570 @@ +//go:build windows + +package server + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "os/user" + "strings" + "syscall" + "unsafe" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + ExitCodeSuccess = 0 + ExitCodeLogonFail = 10 + ExitCodeCreateProcessFail = 11 + ExitCodeWorkingDirFail = 12 + ExitCodeShellExecFail = 13 + ExitCodeValidationFail = 14 +) + +type WindowsExecutorConfig struct { + Username string + Domain string + WorkingDir string + Shell string + Command string + Args []string + Interactive bool + Pty bool + PtyWidth int + PtyHeight int +} + +type PrivilegeDropper struct{} + +func NewPrivilegeDropper() *PrivilegeDropper { + return &PrivilegeDropper{} +} + +var ( + advapi32 = windows.NewLazyDLL("advapi32.dll") + procAllocateLocallyUniqueId = advapi32.NewProc("AllocateLocallyUniqueId") +) + +const ( + logon32LogonNetwork = 3 // Network logon - no password required for authenticated users + + // Common error messages + commandFlag = "-Command" + closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials + convertUsernameError = "convert username to UTF16: %w" + convertDomainError = "convert domain to UTF16: %w" +) + +// CreateWindowsExecutorCommand creates a Windows command with privilege dropping. +// The caller must close the returned token handle after starting the process. +func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, windows.Token, error) { + if config.Username == "" { + return nil, 0, errors.New("username cannot be empty") + } + if config.Shell == "" { + return nil, 0, errors.New("shell cannot be empty") + } + + shell := config.Shell + + var shellArgs []string + if config.Command != "" { + shellArgs = []string{shell, commandFlag, config.Command} + } else { + shellArgs = []string{shell} + } + + log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs) + + cmd, token, err := pd.CreateWindowsProcessAsUser( + ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir) + if err != nil { + return nil, 0, fmt.Errorf("create Windows process as user: %w", err) + } + + return cmd, token, nil +} + +const ( + // StatusSuccess represents successful LSA operation + StatusSuccess = 0 + + // KerbS4ULogonType message type for domain users with Kerberos + KerbS4ULogonType = 12 + // Msv10s4ulogontype message type for local users with MSV1_0 + Msv10s4ulogontype = 12 + + // MicrosoftKerberosNameA is the authentication package name for Kerberos + MicrosoftKerberosNameA = "Kerberos" + // Msv10packagename is the authentication package name for MSV1_0 + Msv10packagename = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0" + + NameSamCompatible = 2 + NameUserPrincipal = 8 + NameCanonical = 7 + + maxUPNLen = 1024 +) + +// kerbS4ULogon structure for S4U authentication (domain users) +type kerbS4ULogon struct { + MessageType uint32 + Flags uint32 + ClientUpn unicodeString + ClientRealm unicodeString +} + +// msv10s4ulogon structure for S4U authentication (local users) +type msv10s4ulogon struct { + MessageType uint32 + Flags uint32 + UserPrincipalName unicodeString + DomainName unicodeString +} + +// unicodeString structure +type unicodeString struct { + Length uint16 + MaximumLength uint16 + Buffer *uint16 +} + +// lsaString structure +type lsaString struct { + Length uint16 + MaximumLength uint16 + Buffer *byte +} + +// tokenSource structure +type tokenSource struct { + SourceName [8]byte + SourceIdentifier windows.LUID +} + +// quotaLimits structure +type quotaLimits struct { + PagedPoolLimit uint32 + NonPagedPoolLimit uint32 + MinimumWorkingSetSize uint32 + MaximumWorkingSetSize uint32 + PagefileLimit uint32 + TimeLimit int64 +} + +var ( + secur32 = windows.NewLazyDLL("secur32.dll") + procLsaRegisterLogonProcess = secur32.NewProc("LsaRegisterLogonProcess") + procLsaLookupAuthenticationPackage = secur32.NewProc("LsaLookupAuthenticationPackage") + procLsaLogonUser = secur32.NewProc("LsaLogonUser") + procLsaFreeReturnBuffer = secur32.NewProc("LsaFreeReturnBuffer") + procLsaDeregisterLogonProcess = secur32.NewProc("LsaDeregisterLogonProcess") + procTranslateNameW = secur32.NewProc("TranslateNameW") +) + +// newLsaString creates an LsaString from a Go string +func newLsaString(s string) lsaString { + b := append([]byte(s), 0) + return lsaString{ + Length: uint16(len(s)), + MaximumLength: uint16(len(b)), + Buffer: &b[0], + } +} + +// generateS4UUserToken creates a Windows token using S4U authentication +// This is the exact approach OpenSSH for Windows uses for public key authentication +func generateS4UUserToken(username, domain string) (windows.Handle, error) { + userCpn := buildUserCpn(username, domain) + + pd := NewPrivilegeDropper() + isDomainUser := !pd.isLocalUser(domain) + + lsaHandle, err := initializeLsaConnection() + if err != nil { + return 0, err + } + defer cleanupLsaConnection(lsaHandle) + + authPackageId, err := lookupAuthenticationPackage(lsaHandle, isDomainUser) + if err != nil { + return 0, err + } + + logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser) + if err != nil { + return 0, err + } + + return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser) +} + +// buildUserCpn constructs the user principal name +func buildUserCpn(username, domain string) string { + if domain != "" && domain != "." { + return fmt.Sprintf(`%s\%s`, domain, username) + } + return username +} + +// initializeLsaConnection establishes connection to LSA +func initializeLsaConnection() (windows.Handle, error) { + + processName := newLsaString("NetBird") + var mode uint32 + var lsaHandle windows.Handle + ret, _, _ := procLsaRegisterLogonProcess.Call( + uintptr(unsafe.Pointer(&processName)), + uintptr(unsafe.Pointer(&lsaHandle)), + uintptr(unsafe.Pointer(&mode)), + ) + if ret != StatusSuccess { + return 0, fmt.Errorf("LsaRegisterLogonProcess: 0x%x", ret) + } + + return lsaHandle, nil +} + +// cleanupLsaConnection closes the LSA connection +func cleanupLsaConnection(lsaHandle windows.Handle) { + if ret, _, _ := procLsaDeregisterLogonProcess.Call(uintptr(lsaHandle)); ret != StatusSuccess { + log.Debugf("LsaDeregisterLogonProcess failed: 0x%x", ret) + } +} + +// lookupAuthenticationPackage finds the correct authentication package +func lookupAuthenticationPackage(lsaHandle windows.Handle, isDomainUser bool) (uint32, error) { + var authPackageName lsaString + if isDomainUser { + authPackageName = newLsaString(MicrosoftKerberosNameA) + } else { + authPackageName = newLsaString(Msv10packagename) + } + + var authPackageId uint32 + ret, _, _ := procLsaLookupAuthenticationPackage.Call( + uintptr(lsaHandle), + uintptr(unsafe.Pointer(&authPackageName)), + uintptr(unsafe.Pointer(&authPackageId)), + ) + if ret != StatusSuccess { + return 0, fmt.Errorf("LsaLookupAuthenticationPackage: 0x%x", ret) + } + + return authPackageId, nil +} + +// lookupPrincipalName converts DOMAIN\username to username@domain.fqdn (UPN format) +func lookupPrincipalName(username, domain string) (string, error) { + samAccountName := fmt.Sprintf(`%s\%s`, domain, username) + samAccountNameUtf16, err := windows.UTF16PtrFromString(samAccountName) + if err != nil { + return "", fmt.Errorf("convert SAM account name to UTF-16: %w", err) + } + + upnBuf := make([]uint16, maxUPNLen+1) + upnSize := uint32(len(upnBuf)) + + ret, _, _ := procTranslateNameW.Call( + uintptr(unsafe.Pointer(samAccountNameUtf16)), + uintptr(NameSamCompatible), + uintptr(NameUserPrincipal), + uintptr(unsafe.Pointer(&upnBuf[0])), + uintptr(unsafe.Pointer(&upnSize)), + ) + + if ret != 0 { + upn := windows.UTF16ToString(upnBuf[:upnSize]) + log.Debugf("Translated %s to explicit UPN: %s", samAccountName, upn) + return upn, nil + } + + upnSize = uint32(len(upnBuf)) + ret, _, _ = procTranslateNameW.Call( + uintptr(unsafe.Pointer(samAccountNameUtf16)), + uintptr(NameSamCompatible), + uintptr(NameCanonical), + uintptr(unsafe.Pointer(&upnBuf[0])), + uintptr(unsafe.Pointer(&upnSize)), + ) + + if ret != 0 { + canonical := windows.UTF16ToString(upnBuf[:upnSize]) + slashIdx := strings.IndexByte(canonical, '/') + if slashIdx > 0 { + fqdn := canonical[:slashIdx] + upn := fmt.Sprintf("%s@%s", username, fqdn) + log.Debugf("Translated %s to implicit UPN: %s (from canonical: %s)", samAccountName, upn, canonical) + return upn, nil + } + } + + log.Debugf("Could not translate %s to UPN, using SAM format", samAccountName) + return samAccountName, nil +} + +// prepareS4ULogonStructure creates the appropriate S4U logon structure +func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) { + if isDomainUser { + return prepareDomainS4ULogon(username, domain) + } + return prepareLocalS4ULogon(username) +} + +// prepareDomainS4ULogon creates S4U logon structure for domain users +func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) { + upn, err := lookupPrincipalName(username, domain) + if err != nil { + return nil, 0, fmt.Errorf("lookup principal name: %w", err) + } + + log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn) + + upnUtf16, err := windows.UTF16FromString(upn) + if err != nil { + return nil, 0, fmt.Errorf(convertUsernameError, err) + } + + structSize := unsafe.Sizeof(kerbS4ULogon{}) + upnByteSize := len(upnUtf16) * 2 + logonInfoSize := structSize + uintptr(upnByteSize) + + buffer := make([]byte, logonInfoSize) + logonInfo := unsafe.Pointer(&buffer[0]) + + s4uLogon := (*kerbS4ULogon)(logonInfo) + s4uLogon.MessageType = KerbS4ULogonType + s4uLogon.Flags = 0 + + upnOffset := structSize + upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset)) + copy((*[1025]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16) + + s4uLogon.ClientUpn = unicodeString{ + Length: uint16((len(upnUtf16) - 1) * 2), + MaximumLength: uint16(len(upnUtf16) * 2), + Buffer: upnBuffer, + } + s4uLogon.ClientRealm = unicodeString{} + + return logonInfo, logonInfoSize, nil +} + +// prepareLocalS4ULogon creates S4U logon structure for local users +func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) { + log.Debugf("using Msv1_0S4ULogon for local user: %s", username) + + usernameUtf16, err := windows.UTF16FromString(username) + if err != nil { + return nil, 0, fmt.Errorf(convertUsernameError, err) + } + + domainUtf16, err := windows.UTF16FromString(".") + if err != nil { + return nil, 0, fmt.Errorf(convertDomainError, err) + } + + structSize := unsafe.Sizeof(msv10s4ulogon{}) + usernameByteSize := len(usernameUtf16) * 2 + domainByteSize := len(domainUtf16) * 2 + logonInfoSize := structSize + uintptr(usernameByteSize) + uintptr(domainByteSize) + + buffer := make([]byte, logonInfoSize) + logonInfo := unsafe.Pointer(&buffer[0]) + + s4uLogon := (*msv10s4ulogon)(logonInfo) + s4uLogon.MessageType = Msv10s4ulogontype + s4uLogon.Flags = 0x0 + + usernameOffset := structSize + usernameBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + usernameOffset)) + copy((*[256]uint16)(unsafe.Pointer(usernameBuffer))[:len(usernameUtf16)], usernameUtf16) + + s4uLogon.UserPrincipalName = unicodeString{ + Length: uint16((len(usernameUtf16) - 1) * 2), + MaximumLength: uint16(len(usernameUtf16) * 2), + Buffer: usernameBuffer, + } + + domainOffset := usernameOffset + uintptr(usernameByteSize) + domainBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + domainOffset)) + copy((*[16]uint16)(unsafe.Pointer(domainBuffer))[:len(domainUtf16)], domainUtf16) + + s4uLogon.DomainName = unicodeString{ + Length: uint16((len(domainUtf16) - 1) * 2), + MaximumLength: uint16(len(domainUtf16) * 2), + Buffer: domainBuffer, + } + + return logonInfo, logonInfoSize, nil +} + +// performS4ULogon executes the S4U logon operation +func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) { + var tokenSource tokenSource + copy(tokenSource.SourceName[:], "netbird") + if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 { + log.Debugf("AllocateLocallyUniqueId failed") + } + + originName := newLsaString("netbird") + + var profile uintptr + var profileSize uint32 + var logonId windows.LUID + var token windows.Handle + var quotas quotaLimits + var subStatus int32 + + ret, _, _ := procLsaLogonUser.Call( + uintptr(lsaHandle), + uintptr(unsafe.Pointer(&originName)), + logon32LogonNetwork, + uintptr(authPackageId), + uintptr(logonInfo), + logonInfoSize, + 0, + uintptr(unsafe.Pointer(&tokenSource)), + uintptr(unsafe.Pointer(&profile)), + uintptr(unsafe.Pointer(&profileSize)), + uintptr(unsafe.Pointer(&logonId)), + uintptr(unsafe.Pointer(&token)), + uintptr(unsafe.Pointer("as)), + uintptr(unsafe.Pointer(&subStatus)), + ) + + if profile != 0 { + if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess { + log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret) + } + } + + if ret != StatusSuccess { + return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus) + } + + log.Debugf("created S4U %s token for user %s", + map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn) + return token, nil +} + +// createToken implements NetBird trust-based authentication using S4U +func (pd *PrivilegeDropper) createToken(username, domain string) (windows.Handle, error) { + fullUsername := buildUserCpn(username, domain) + + if err := userExists(fullUsername, username, domain); err != nil { + return 0, err + } + + isLocalUser := pd.isLocalUser(domain) + + if isLocalUser { + return pd.authenticateLocalUser(username, fullUsername) + } + return pd.authenticateDomainUser(username, domain, fullUsername) +} + +// userExists checks if the target useVerifier exists on the system +func userExists(fullUsername, username, domain string) error { + if _, err := lookupUser(fullUsername); err != nil { + log.Debugf("User %s not found: %v", fullUsername, err) + if domain != "" && domain != "." { + _, err = lookupUser(username) + } + if err != nil { + return fmt.Errorf("target user %s not found: %w", fullUsername, err) + } + } + return nil +} + +// isLocalUser determines if this is a local user vs domain user +func (pd *PrivilegeDropper) isLocalUser(domain string) bool { + hostname, err := os.Hostname() + if err != nil { + hostname = "localhost" + } + + return domain == "" || domain == "." || + strings.EqualFold(domain, hostname) +} + +// authenticateLocalUser handles authentication for local users +func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) { + log.Debugf("using S4U authentication for local user %s", fullUsername) + token, err := generateS4UUserToken(username, ".") + if err != nil { + return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err) + } + return token, nil +} + +// authenticateDomainUser handles authentication for domain users +func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) { + log.Debugf("using S4U authentication for domain user %s", fullUsername) + token, err := generateS4UUserToken(username, domain) + if err != nil { + return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err) + } + log.Debugf("Successfully created S4U token for domain user %s", fullUsername) + return token, nil +} + +// CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables). +// The caller must close the returned token handle after starting the process. +func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, windows.Token, error) { + token, err := pd.createToken(username, domain) + if err != nil { + return nil, 0, fmt.Errorf("user authentication: %w", err) + } + + defer func() { + if err := windows.CloseHandle(token); err != nil { + log.Debugf("close impersonation token: %v", err) + } + }() + + cmd, primaryToken, err := pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir) + if err != nil { + return nil, 0, err + } + + return cmd, primaryToken, nil +} + +// createProcessWithToken creates process with the specified token and executable path. +// The caller must close the returned token handle after starting the process. +func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, windows.Token, error) { + cmd := exec.CommandContext(ctx, executablePath, args[1:]...) + cmd.Dir = workingDir + + var primaryToken windows.Token + err := windows.DuplicateTokenEx( + sourceToken, + windows.TOKEN_ALL_ACCESS, + nil, + windows.SecurityIdentification, + windows.TokenPrimary, + &primaryToken, + ) + if err != nil { + return nil, 0, fmt.Errorf("duplicate token to primary token: %w", err) + } + + cmd.SysProcAttr = &syscall.SysProcAttr{ + Token: syscall.Token(primaryToken), + } + + return cmd, primaryToken, nil +} + +// createSuCommand creates a command using su -l -c for privilege switching (Windows stub) +func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) { + return nil, fmt.Errorf("su command not available on Windows") +} diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go new file mode 100644 index 000000000..e22bdfb06 --- /dev/null +++ b/client/ssh/server/jwt_test.go @@ -0,0 +1,629 @@ +package server + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "io" + "math/big" + "net" + "net/http" + "net/http/httptest" + "runtime" + "strconv" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + cryptossh "golang.org/x/crypto/ssh" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/ssh/client" + "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/client/ssh/testutil" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" +) + +func TestJWTEnforcement(t *testing.T) { + if testing.Short() { + t.Skip("Skipping JWT enforcement tests in short mode") + } + + // Set up SSH server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + t.Run("blocks_without_jwt", func(t *testing.T) { + jwtConfig := &JWTConfig{ + Issuer: "test-issuer", + Audience: "test-audience", + KeysLocation: "test-keys", + } + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: jwtConfig, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer require.NoError(t, server.Stop()) + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + dialer := &net.Dialer{Timeout: detection.Timeout} + serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port) + if err != nil { + t.Logf("Detection failed: %v", err) + } + t.Logf("Detected server type: %s", serverType) + + config := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: []cryptossh.AuthMethod{}, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + _, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config) + assert.Error(t, err, "SSH connection should fail when JWT is required but not provided") + }) + + t.Run("allows_when_disabled", func(t *testing.T) { + serverConfigNoJWT := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + serverNoJWT := New(serverConfigNoJWT) + require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config") + serverNoJWT.SetAllowRootLogin(true) + + serverAddrNoJWT := StartTestServer(t, serverNoJWT) + defer require.NoError(t, serverNoJWT.Stop()) + + hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT) + require.NoError(t, err) + portNoJWT, err := strconv.Atoi(portStrNoJWT) + require.NoError(t, err) + + dialer := &net.Dialer{Timeout: detection.Timeout} + serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT) + require.NoError(t, err) + assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType) + assert.False(t, serverType.RequiresJWT()) + + client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT) + require.NoError(t, err) + defer client.Close() + }) + +} + +// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL +func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) { + privateKey, jwksJSON := generateTestJWKS(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write(jwksJSON); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + })) + + return server, privateKey, server.URL +} + +// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON +func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + publicKey := &privateKey.PublicKey + n := publicKey.N.Bytes() + e := publicKey.E + + jwk := nbjwt.JSONWebKey{ + Kty: "RSA", + Kid: "test-key-id", + Use: "sig", + N: base64RawURLEncode(n), + E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()), + } + + jwks := nbjwt.Jwks{ + Keys: []nbjwt.JSONWebKey{jwk}, + } + + jwksJSON, err := json.Marshal(jwks) + require.NoError(t, err) + + return privateKey, jwksJSON +} + +func base64RawURLEncode(data []byte) string { + return base64.RawURLEncoding.EncodeToString(data) +} + +// generateValidJWT creates a valid JWT token for testing +func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string { + claims := jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = "test-key-id" + + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + + return tokenString +} + +// connectWithNetBirdClient connects to SSH server using NetBird's SSH client +func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) { + t.Helper() + addr := net.JoinHostPort(host, strconv.Itoa(port)) + + ctx := context.Background() + return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{ + InsecureSkipVerify: true, + }) +} + +// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers +func TestJWTDetection(t *testing.T) { + if testing.Short() { + t.Skip("Skipping JWT detection test in short mode") + } + + jwksServer, _, jwksURL := setupJWKSServer(t) + defer jwksServer.Close() + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + jwtConfig := &JWTConfig{ + Issuer: issuer, + Audience: audience, + KeysLocation: jwksURL, + } + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: jwtConfig, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer require.NoError(t, server.Stop()) + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + dialer := &net.Dialer{Timeout: detection.Timeout} + serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port) + require.NoError(t, err) + assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType) + assert.True(t, serverType.RequiresJWT()) +} + +func TestJWTFailClose(t *testing.T) { + if testing.Short() { + t.Skip("Skipping JWT fail-close tests in short mode") + } + + jwksServer, privateKey, jwksURL := setupJWKSServer(t) + defer jwksServer.Close() + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + testCases := []struct { + name string + tokenClaims jwt.MapClaims + }{ + { + name: "blocks_token_missing_iat", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + }, + }, + { + name: "blocks_token_missing_sub", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_token_missing_iss", + tokenClaims: jwt.MapClaims{ + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_token_missing_aud", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_token_wrong_issuer", + tokenClaims: jwt.MapClaims{ + "iss": "wrong-issuer", + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_token_wrong_audience", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": "wrong-audience", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + }, + }, + { + name: "blocks_expired_token", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(-time.Hour).Unix(), + "iat": time.Now().Add(-2 * time.Hour).Unix(), + }, + }, + { + name: "blocks_token_exceeding_max_age", + tokenClaims: jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Add(-2 * time.Hour).Unix(), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + jwtConfig := &JWTConfig{ + Issuer: issuer, + Audience: audience, + KeysLocation: jwksURL, + MaxTokenAge: 3600, + } + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: jwtConfig, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer require.NoError(t, server.Stop()) + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims) + token.Header["kid"] = "test-key-id" + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + + config := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: []cryptossh.AuthMethod{ + cryptossh.Password(tokenString), + }, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config) + if conn != nil { + defer func() { + if err := conn.Close(); err != nil { + t.Logf("close connection: %v", err) + } + }() + } + + assert.Error(t, err, "Authentication should fail (fail-close)") + }) + } +} + +// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types +func TestJWTAuthentication(t *testing.T) { + if testing.Short() { + t.Skip("Skipping JWT authentication tests in short mode") + } + + jwksServer, privateKey, jwksURL := setupJWKSServer(t) + defer jwksServer.Close() + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + testCases := []struct { + name string + token string + wantAuthOK bool + setupServer func(*Server) + testOperation func(*testing.T, *cryptossh.Client, string) error + wantOpSuccess bool + }{ + { + name: "allows_shell_with_jwt", + token: "valid", + wantAuthOK: true, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + return session.Shell() + }, + wantOpSuccess: true, + }, + { + name: "rejects_invalid_token", + token: "invalid", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + output, err := session.CombinedOutput("echo test") + if err != nil { + t.Logf("Command output: %s", string(output)) + return err + } + return nil + }, + wantOpSuccess: false, + }, + { + name: "blocks_shell_without_jwt", + token: "", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + output, err := session.CombinedOutput("echo test") + if err != nil { + t.Logf("Command output: %s", string(output)) + return err + } + return nil + }, + wantOpSuccess: false, + }, + { + name: "blocks_command_without_jwt", + token: "", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + output, err := session.CombinedOutput("ls") + if err != nil { + t.Logf("Command output: %s", string(output)) + return err + } + return nil + }, + wantOpSuccess: false, + }, + { + name: "allows_sftp_with_jwt", + token: "valid", + wantAuthOK: true, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + s.SetAllowSFTP(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + session.Stdout = io.Discard + session.Stderr = io.Discard + return session.RequestSubsystem("sftp") + }, + wantOpSuccess: true, + }, + { + name: "blocks_sftp_without_jwt", + token: "", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + s.SetAllowSFTP(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + session, err := conn.NewSession() + require.NoError(t, err) + defer session.Close() + + session.Stdout = io.Discard + session.Stderr = io.Discard + err = session.RequestSubsystem("sftp") + if err == nil { + err = session.Wait() + } + return err + }, + wantOpSuccess: false, + }, + { + name: "allows_port_forward_with_jwt", + token: "valid", + wantAuthOK: true, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + s.SetAllowRemotePortForwarding(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + ln, err := conn.Listen("tcp", "127.0.0.1:0") + if ln != nil { + defer ln.Close() + } + return err + }, + wantOpSuccess: true, + }, + { + name: "blocks_port_forward_without_jwt", + token: "", + wantAuthOK: false, + setupServer: func(s *Server) { + s.SetAllowRootLogin(true) + s.SetAllowLocalPortForwarding(true) + }, + testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error { + ln, err := conn.Listen("tcp", "127.0.0.1:0") + if ln != nil { + defer ln.Close() + } + return err + }, + wantOpSuccess: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // TODO: Skip port forwarding tests on Windows - user switching not supported + // These features are tested on Linux/Unix platforms + if runtime.GOOS == "windows" && + (tc.name == "allows_port_forward_with_jwt" || + tc.name == "blocks_port_forward_without_jwt") { + t.Skip("Skipping port forwarding test on Windows - covered by Linux tests") + } + + jwtConfig := &JWTConfig{ + Issuer: issuer, + Audience: audience, + KeysLocation: jwksURL, + } + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: jwtConfig, + } + server := New(serverConfig) + if tc.setupServer != nil { + tc.setupServer(server) + } + + serverAddr := StartTestServer(t, server) + defer require.NoError(t, server.Stop()) + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + var authMethods []cryptossh.AuthMethod + if tc.token == "valid" { + token := generateValidJWT(t, privateKey, issuer, audience) + authMethods = []cryptossh.AuthMethod{ + cryptossh.Password(token), + } + } else if tc.token == "invalid" { + invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid" + authMethods = []cryptossh.AuthMethod{ + cryptossh.Password(invalidToken), + } + } + + config := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: authMethods, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config) + if tc.wantAuthOK { + require.NoError(t, err, "JWT authentication should succeed") + } else if err != nil { + t.Logf("Connection failed as expected: %v", err) + return + } + if conn != nil { + defer func() { + if err := conn.Close(); err != nil { + t.Logf("close connection: %v", err) + } + }() + } + + err = tc.testOperation(t, conn, serverAddr) + if tc.wantOpSuccess { + require.NoError(t, err, "Operation should succeed") + } else { + assert.Error(t, err, "Operation should fail") + } + }) + } +} diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go new file mode 100644 index 000000000..6138f9296 --- /dev/null +++ b/client/ssh/server/port_forwarding.go @@ -0,0 +1,386 @@ +package server + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "strconv" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + cryptossh "golang.org/x/crypto/ssh" +) + +// SessionKey uniquely identifies an SSH session +type SessionKey string + +// ConnectionKey uniquely identifies a port forwarding connection within a session +type ConnectionKey string + +// ForwardKey uniquely identifies a port forwarding listener +type ForwardKey string + +// tcpipForwardMsg represents the structure for tcpip-forward SSH requests +type tcpipForwardMsg struct { + Host string + Port uint32 +} + +// SetAllowLocalPortForwarding configures local port forwarding +func (s *Server) SetAllowLocalPortForwarding(allow bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.allowLocalPortForwarding = allow +} + +// SetAllowRemotePortForwarding configures remote port forwarding +func (s *Server) SetAllowRemotePortForwarding(allow bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.allowRemotePortForwarding = allow +} + +// configurePortForwarding sets up port forwarding callbacks +func (s *Server) configurePortForwarding(server *ssh.Server) { + allowLocal := s.allowLocalPortForwarding + allowRemote := s.allowRemotePortForwarding + + server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool { + if !allowLocal { + log.Warnf("local port forwarding denied for %s from %s: disabled by configuration", + net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr()) + return false + } + + if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil { + log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err) + return false + } + + log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort) + return true + } + + server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool { + if !allowRemote { + log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration", + net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr()) + return false + } + + if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil { + log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err) + return false + } + + log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort) + return true + } + + log.Debugf("SSH server configured with local_forwarding=%v, remote_forwarding=%v", allowLocal, allowRemote) +} + +// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations. +// Returns nil if allowed, error if denied. +func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error { + if ctx == nil { + return fmt.Errorf("%s port forwarding denied: no context", forwardType) + } + + username := ctx.User() + remoteAddr := "unknown" + if ctx.RemoteAddr() != nil { + remoteAddr = ctx.RemoteAddr().String() + } + + logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port}) + + result := s.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: username, + FeatureSupportsUserSwitch: false, + FeatureName: forwardType + " port forwarding", + }) + + if !result.Allowed { + return result.Error + } + + logger.Debugf("%s port forwarding allowed: user %s validated (port %d)", + forwardType, result.User.Username, port) + + return nil +} + +// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding. +func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { + logger := s.getRequestLogger(ctx) + + if !s.isRemotePortForwardingAllowed() { + logger.Warnf("tcpip-forward request denied: remote port forwarding disabled") + return false, nil + } + + payload, err := s.parseTcpipForwardRequest(req) + if err != nil { + logger.Errorf("tcpip-forward unmarshal error: %v", err) + return false, nil + } + + if err := s.checkPortForwardingPrivileges(ctx, "tcpip-forward", payload.Port); err != nil { + logger.Warnf("tcpip-forward denied: %v", err) + return false, nil + } + + logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port) + + sshConn, err := s.getSSHConnection(ctx) + if err != nil { + logger.Warnf("tcpip-forward request denied: %v", err) + return false, nil + } + + return s.setupDirectForward(ctx, logger, sshConn, payload) +} + +// cancelTcpipForwardHandler handles cancel-tcpip-forward requests. +func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) { + logger := s.getRequestLogger(ctx) + + var payload tcpipForwardMsg + if err := cryptossh.Unmarshal(req.Payload, &payload); err != nil { + logger.Errorf("cancel-tcpip-forward unmarshal error: %v", err) + return false, nil + } + + key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + if s.removeRemoteForwardListener(key) { + logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port) + return true, nil + } + + logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port) + return false, nil +} + +// handleRemoteForwardListener handles incoming connections for remote port forwarding. +func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) { + log.Debugf("starting remote forward listener handler for %s:%d", host, port) + + defer func() { + log.Debugf("cleaning up remote forward listener for %s:%d", host, port) + if err := ln.Close(); err != nil { + log.Debugf("remote forward listener close error: %v", err) + } else { + log.Debugf("remote forward listener closed successfully for %s:%d", host, port) + } + }() + + acceptChan := make(chan acceptResult, 1) + + go func() { + for { + conn, err := ln.Accept() + select { + case acceptChan <- acceptResult{conn: conn, err: err}: + if err != nil { + return + } + case <-ctx.Done(): + return + } + } + }() + + for { + select { + case result := <-acceptChan: + if result.err != nil { + log.Debugf("remote forward accept error: %v", result.err) + return + } + go s.handleRemoteForwardConnection(ctx, result.conn, host, port) + case <-ctx.Done(): + log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port) + return + } + } +} + +// getRequestLogger creates a logger with user and remote address context +func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry { + remoteAddr := "unknown" + username := "unknown" + if ctx != nil { + if ctx.RemoteAddr() != nil { + remoteAddr = ctx.RemoteAddr().String() + } + username = ctx.User() + } + return log.WithFields(log.Fields{"user": username, "remote": remoteAddr}) +} + +// isRemotePortForwardingAllowed checks if remote port forwarding is enabled +func (s *Server) isRemotePortForwardingAllowed() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.allowRemotePortForwarding +} + +// parseTcpipForwardRequest parses the SSH request payload +func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) { + var payload tcpipForwardMsg + err := cryptossh.Unmarshal(req.Payload, &payload) + return &payload, err +} + +// getSSHConnection extracts SSH connection from context +func (s *Server) getSSHConnection(ctx ssh.Context) (*cryptossh.ServerConn, error) { + if ctx == nil { + return nil, fmt.Errorf("no context") + } + sshConnValue := ctx.Value(ssh.ContextKeyConn) + if sshConnValue == nil { + return nil, fmt.Errorf("no SSH connection in context") + } + sshConn, ok := sshConnValue.(*cryptossh.ServerConn) + if !ok || sshConn == nil { + return nil, fmt.Errorf("invalid SSH connection in context") + } + return sshConn, nil +} + +// setupDirectForward sets up a direct port forward +func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn *cryptossh.ServerConn, payload *tcpipForwardMsg) (bool, []byte) { + bindAddr := net.JoinHostPort(payload.Host, strconv.FormatUint(uint64(payload.Port), 10)) + + ln, err := net.Listen("tcp", bindAddr) + if err != nil { + logger.Errorf("tcpip-forward listen failed on %s: %v", bindAddr, err) + return false, nil + } + + actualPort := payload.Port + if payload.Port == 0 { + tcpAddr := ln.Addr().(*net.TCPAddr) + actualPort = uint32(tcpAddr.Port) + logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host) + } + + key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + s.storeRemoteForwardListener(key, ln) + + s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String()) + go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort) + + response := make([]byte, 4) + binary.BigEndian.PutUint32(response, actualPort) + + logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort) + return true, response +} + +// acceptResult holds the result of a listener Accept() call +type acceptResult struct { + conn net.Conn + err error +} + +// handleRemoteForwardConnection handles a single remote port forwarding connection +func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) { + sessionKey := s.findSessionKeyByContext(ctx) + connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port) + logger := log.WithFields(log.Fields{ + "session": sessionKey, + "conn": connID, + }) + + defer func() { + if err := conn.Close(); err != nil { + logger.Debugf("connection close error: %v", err) + } + }() + + sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn) + if sshConn == nil { + logger.Debugf("remote forward: no SSH connection in context") + return + } + + remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr) + if !ok { + logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr()) + return + } + + channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger) + if err != nil { + logger.Debugf("open forward channel: %v", err) + return + } + + s.proxyForwardConnection(ctx, logger, conn, channel) +} + +// openForwardChannel creates an SSH forwarded-tcpip channel +func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) { + logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port) + + payload := struct { + ConnectedAddress string + ConnectedPort uint32 + OriginatorAddress string + OriginatorPort uint32 + }{ + ConnectedAddress: host, + ConnectedPort: port, + OriginatorAddress: remoteAddr.IP.String(), + OriginatorPort: uint32(remoteAddr.Port), + } + + channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", cryptossh.Marshal(&payload)) + if err != nil { + return nil, fmt.Errorf("open SSH channel: %w", err) + } + + go cryptossh.DiscardRequests(reqs) + return channel, nil +} + +// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel +func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) { + done := make(chan struct{}, 2) + + go func() { + if _, err := io.Copy(channel, conn); err != nil { + logger.Debugf("copy error (conn->channel): %v", err) + } + done <- struct{}{} + }() + + go func() { + if _, err := io.Copy(conn, channel); err != nil { + logger.Debugf("copy error (channel->conn): %v", err) + } + done <- struct{}{} + }() + + select { + case <-ctx.Done(): + logger.Debugf("session ended, closing connections") + case <-done: + // First copy finished, wait for second copy or context cancellation + select { + case <-ctx.Done(): + logger.Debugf("session ended, closing connections") + case <-done: + } + } + + if err := channel.Close(); err != nil { + logger.Debugf("channel close error: %v", err) + } + if err := conn.Close(); err != nil { + logger.Debugf("connection close error: %v", err) + } +} diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go new file mode 100644 index 000000000..44612532b --- /dev/null +++ b/client/ssh/server/server.go @@ -0,0 +1,712 @@ +package server + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/netip" + "strings" + "sync" + "time" + + "github.com/gliderlabs/ssh" + gojwt "github.com/golang-jwt/jwt/v5" + log "github.com/sirupsen/logrus" + cryptossh "golang.org/x/crypto/ssh" + "golang.org/x/exp/maps" + "golang.zx2c4.com/wireguard/tun/netstack" + + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/auth/jwt" + "github.com/netbirdio/netbird/version" +) + +// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server +const DefaultSSHPort = 22 + +// InternalSSHPort is the port SSH server listens on and is redirected to +const InternalSSHPort = 22022 + +const ( + errWriteSession = "write session error: %v" + errExitSession = "exit session error: %v" + + msgPrivilegedUserDisabled = "privileged user login is disabled" + + // DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server + DefaultJWTMaxTokenAge = 5 * 60 +) + +var ( + ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled) + ErrUserNotFound = errors.New("user not found") +) + +// PrivilegedUserError represents an error when privileged user login is disabled +type PrivilegedUserError struct { + Username string +} + +func (e *PrivilegedUserError) Error() string { + return fmt.Sprintf("%s for user: %s", msgPrivilegedUserDisabled, e.Username) +} + +func (e *PrivilegedUserError) Is(target error) bool { + return target == ErrPrivilegedUserDisabled +} + +// UserNotFoundError represents an error when a user cannot be found +type UserNotFoundError struct { + Username string + Cause error +} + +func (e *UserNotFoundError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("user %s not found: %v", e.Username, e.Cause) + } + return fmt.Sprintf("user %s not found", e.Username) +} + +func (e *UserNotFoundError) Is(target error) bool { + return target == ErrUserNotFound +} + +func (e *UserNotFoundError) Unwrap() error { + return e.Cause +} + +// logSessionExitError logs session exit errors, ignoring EOF (normal close) errors +func logSessionExitError(logger *log.Entry, err error) { + if err != nil && !errors.Is(err, io.EOF) { + logger.Warnf(errExitSession, err) + } +} + +// safeLogCommand returns a safe representation of the command for logging +func safeLogCommand(cmd []string) string { + if len(cmd) == 0 { + return "" + } + if len(cmd) == 1 { + return cmd[0] + } + return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1) +} + +type sshConnectionState struct { + hasActivePortForward bool + username string + remoteAddr string +} + +type authKey string + +func newAuthKey(username string, remoteAddr net.Addr) authKey { + return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String())) +} + +type Server struct { + sshServer *ssh.Server + mu sync.RWMutex + hostKeyPEM []byte + sessions map[SessionKey]ssh.Session + sessionCancels map[ConnectionKey]context.CancelFunc + sessionJWTUsers map[SessionKey]string + pendingAuthJWT map[authKey]string + + allowLocalPortForwarding bool + allowRemotePortForwarding bool + allowRootLogin bool + allowSFTP bool + jwtEnabled bool + + netstackNet *netstack.Net + + wgAddress wgaddr.Address + + remoteForwardListeners map[ForwardKey]net.Listener + sshConnections map[*cryptossh.ServerConn]*sshConnectionState + + jwtValidator *jwt.Validator + jwtExtractor *jwt.ClaimsExtractor + jwtConfig *JWTConfig + + suSupportsPty bool +} + +type JWTConfig struct { + Issuer string + Audience string + KeysLocation string + MaxTokenAge int64 +} + +// Config contains all SSH server configuration options +type Config struct { + // JWT authentication configuration. If nil, JWT authentication is disabled + JWT *JWTConfig + + // HostKey is the SSH server host key in PEM format + HostKeyPEM []byte +} + +// SessionInfo contains information about an active SSH session +type SessionInfo struct { + Username string + RemoteAddress string + Command string + JWTUsername string +} + +// New creates an SSH server instance with the provided host key and optional JWT configuration +// If jwtConfig is nil, JWT authentication is disabled +func New(config *Config) *Server { + s := &Server{ + mu: sync.RWMutex{}, + hostKeyPEM: config.HostKeyPEM, + sessions: make(map[SessionKey]ssh.Session), + sessionJWTUsers: make(map[SessionKey]string), + pendingAuthJWT: make(map[authKey]string), + remoteForwardListeners: make(map[ForwardKey]net.Listener), + sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState), + jwtEnabled: config.JWT != nil, + jwtConfig: config.JWT, + } + + return s +} + +// Start runs the SSH server +func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.sshServer != nil { + return errors.New("SSH server is already running") + } + + s.suSupportsPty = s.detectSuPtySupport(ctx) + + ln, addrDesc, err := s.createListener(ctx, addr) + if err != nil { + return fmt.Errorf("create listener: %w", err) + } + + sshServer, err := s.createSSHServer(ln.Addr()) + if err != nil { + s.closeListener(ln) + return fmt.Errorf("create SSH server: %w", err) + } + + s.sshServer = sshServer + log.Infof("SSH server started on %s", addrDesc) + + go func() { + if err := sshServer.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + log.Errorf("SSH server error: %v", err) + } + }() + return nil +} + +func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) { + if s.netstackNet != nil { + ln, err := s.netstackNet.ListenTCPAddrPort(addr) + if err != nil { + return nil, "", fmt.Errorf("listen on netstack: %w", err) + } + return ln, fmt.Sprintf("netstack %s", addr), nil + } + + tcpAddr := net.TCPAddrFromAddrPort(addr) + lc := net.ListenConfig{} + ln, err := lc.Listen(ctx, "tcp", tcpAddr.String()) + if err != nil { + return nil, "", fmt.Errorf("listen: %w", err) + } + return ln, addr.String(), nil +} + +func (s *Server) closeListener(ln net.Listener) { + if ln == nil { + return + } + if err := ln.Close(); err != nil { + log.Debugf("listener close error: %v", err) + } +} + +// Stop closes the SSH server +func (s *Server) Stop() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.sshServer == nil { + return nil + } + + if err := s.sshServer.Close(); err != nil { + log.Debugf("close SSH server: %v", err) + } + + s.sshServer = nil + + maps.Clear(s.sessions) + maps.Clear(s.sessionJWTUsers) + maps.Clear(s.pendingAuthJWT) + maps.Clear(s.sshConnections) + + for _, cancelFunc := range s.sessionCancels { + cancelFunc() + } + maps.Clear(s.sessionCancels) + + for _, listener := range s.remoteForwardListeners { + if err := listener.Close(); err != nil { + log.Debugf("close remote forward listener: %v", err) + } + } + maps.Clear(s.remoteForwardListeners) + + return nil +} + +// GetStatus returns the current status of the SSH server and active sessions +func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) { + s.mu.RLock() + defer s.mu.RUnlock() + + enabled = s.sshServer != nil + + for sessionKey, session := range s.sessions { + cmd := "" + if len(session.Command()) > 0 { + cmd = safeLogCommand(session.Command()) + } + + jwtUsername := s.sessionJWTUsers[sessionKey] + + sessions = append(sessions, SessionInfo{ + Username: session.User(), + RemoteAddress: session.RemoteAddr().String(), + Command: cmd, + JWTUsername: jwtUsername, + }) + } + + return enabled, sessions +} + +// SetNetstackNet sets the netstack network for userspace networking +func (s *Server) SetNetstackNet(net *netstack.Net) { + s.mu.Lock() + defer s.mu.Unlock() + s.netstackNet = net +} + +// SetNetworkValidation configures network-based connection filtering +func (s *Server) SetNetworkValidation(addr wgaddr.Address) { + s.mu.Lock() + defer s.mu.Unlock() + s.wgAddress = addr +} + +// ensureJWTValidator initializes the JWT validator and extractor if not already initialized +func (s *Server) ensureJWTValidator() error { + s.mu.RLock() + if s.jwtValidator != nil && s.jwtExtractor != nil { + s.mu.RUnlock() + return nil + } + config := s.jwtConfig + s.mu.RUnlock() + + if config == nil { + return fmt.Errorf("JWT config not set") + } + + log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience) + + validator := jwt.NewValidator( + config.Issuer, + []string{config.Audience}, + config.KeysLocation, + true, + ) + + extractor := jwt.NewClaimsExtractor( + jwt.WithAudience(config.Audience), + ) + + s.mu.Lock() + defer s.mu.Unlock() + + if s.jwtValidator != nil && s.jwtExtractor != nil { + return nil + } + + s.jwtValidator = validator + s.jwtExtractor = extractor + + log.Infof("JWT validator initialized successfully") + return nil +} + +func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) { + s.mu.RLock() + jwtValidator := s.jwtValidator + jwtConfig := s.jwtConfig + s.mu.RUnlock() + + if jwtValidator == nil { + return nil, fmt.Errorf("JWT validator not initialized") + } + + token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString) + if err != nil { + if jwtConfig != nil { + if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil { + return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w", + jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err) + } + } + return nil, fmt.Errorf("validate token: %w", err) + } + + if err := s.checkTokenAge(token, jwtConfig); err != nil { + return nil, err + } + + return token, nil +} + +func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error { + if jwtConfig == nil { + return nil + } + + maxTokenAge := jwtConfig.MaxTokenAge + if maxTokenAge <= 0 { + maxTokenAge = DefaultJWTMaxTokenAge + } + + claims, ok := token.Claims.(gojwt.MapClaims) + if !ok { + userID := extractUserID(token) + return fmt.Errorf("token has invalid claims format (user=%s)", userID) + } + + iat, ok := claims["iat"].(float64) + if !ok { + userID := extractUserID(token) + return fmt.Errorf("token missing iat claim (user=%s)", userID) + } + + issuedAt := time.Unix(int64(iat), 0) + tokenAge := time.Since(issuedAt) + maxAge := time.Duration(maxTokenAge) * time.Second + if tokenAge > maxAge { + userID := getUserIDFromClaims(claims) + return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge) + } + + return nil +} + +func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) { + s.mu.RLock() + jwtExtractor := s.jwtExtractor + s.mu.RUnlock() + + if jwtExtractor == nil { + userID := extractUserID(token) + return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID) + } + + userAuth, err := jwtExtractor.ToUserAuth(token) + if err != nil { + userID := extractUserID(token) + return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err) + } + + if !s.hasSSHAccess(&userAuth) { + return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId) + } + + return &userAuth, nil +} + +func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool { + return userAuth.UserId != "" +} + +func extractUserID(token *gojwt.Token) string { + if token == nil { + return "unknown" + } + claims, ok := token.Claims.(gojwt.MapClaims) + if !ok { + return "unknown" + } + return getUserIDFromClaims(claims) +} + +func getUserIDFromClaims(claims gojwt.MapClaims) string { + if sub, ok := claims["sub"].(string); ok && sub != "" { + return sub + } + if userID, ok := claims["user_id"].(string); ok && userID != "" { + return userID + } + if email, ok := claims["email"].(string); ok && email != "" { + return email + } + return "unknown" +} + +func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) { + parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid token format") + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("decode payload: %w", err) + } + + var claims map[string]interface{} + if err := json.Unmarshal(payload, &claims); err != nil { + return nil, fmt.Errorf("parse claims: %w", err) + } + + return claims, nil +} + +func (s *Server) passwordHandler(ctx ssh.Context, password string) bool { + if err := s.ensureJWTValidator(); err != nil { + log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) + return false + } + + token, err := s.validateJWTToken(password) + if err != nil { + log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) + return false + } + + userAuth, err := s.extractAndValidateUser(token) + if err != nil { + log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) + return false + } + + key := newAuthKey(ctx.User(), ctx.RemoteAddr()) + s.mu.Lock() + s.pendingAuthJWT[key] = userAuth.UserId + s.mu.Unlock() + + log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr()) + return true +} + +func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) { + s.mu.Lock() + defer s.mu.Unlock() + + if state, exists := s.sshConnections[sshConn]; exists { + state.hasActivePortForward = true + } else { + s.sshConnections[sshConn] = &sshConnectionState{ + hasActivePortForward: true, + username: username, + remoteAddr: remoteAddr, + } + } +} + +func (s *Server) connectionCloseHandler(conn net.Conn, err error) { + // We can't extract the SSH connection from net.Conn directly + // Connection cleanup will happen during session cleanup or via timeout + log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err) +} + +func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey { + if ctx == nil { + return "unknown" + } + + // Try to match by SSH connection + sshConn := ctx.Value(ssh.ContextKeyConn) + if sshConn == nil { + return "unknown" + } + + s.mu.RLock() + defer s.mu.RUnlock() + + // Look through sessions to find one with matching connection + for sessionKey, session := range s.sessions { + if session.Context().Value(ssh.ContextKeyConn) == sshConn { + return sessionKey + } + } + + // If no session found, this might be during early connection setup + // Return a temporary key that we'll fix up later + if ctx.User() != "" && ctx.RemoteAddr() != nil { + tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String())) + log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey) + return tempKey + } + + return "unknown" +} + +func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { + s.mu.RLock() + netbirdNetwork := s.wgAddress.Network + localIP := s.wgAddress.IP + s.mu.RUnlock() + + if !netbirdNetwork.IsValid() || !localIP.IsValid() { + return conn + } + + remoteAddr := conn.RemoteAddr() + tcpAddr, ok := remoteAddr.(*net.TCPAddr) + if !ok { + log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr) + return nil + } + + remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP) + if !ok { + log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP) + return nil + } + + // Block connections from our own IP (prevent local apps from connecting to ourselves) + if remoteIP == localIP { + log.Warnf("SSH connection rejected from own IP %s", remoteIP) + return nil + } + + if !netbirdNetwork.Contains(remoteIP) { + log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP) + return nil + } + + log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr) + return conn +} + +func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { + if err := enableUserSwitching(); err != nil { + log.Warnf("failed to enable user switching: %v", err) + } + + serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion()) + if s.jwtEnabled { + serverVersion += " " + detection.JWTRequiredMarker + } + + server := &ssh.Server{ + Addr: addr.String(), + Handler: s.sessionHandler, + SubsystemHandlers: map[string]ssh.SubsystemHandler{ + "sftp": s.sftpSubsystemHandler, + }, + HostSigners: []ssh.Signer{}, + ChannelHandlers: map[string]ssh.ChannelHandler{ + "session": ssh.DefaultSessionHandler, + "direct-tcpip": s.directTCPIPHandler, + }, + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": s.tcpipForwardHandler, + "cancel-tcpip-forward": s.cancelTcpipForwardHandler, + }, + ConnCallback: s.connectionValidator, + ConnectionFailedCallback: s.connectionCloseHandler, + Version: serverVersion, + } + + if s.jwtEnabled { + server.PasswordHandler = s.passwordHandler + } + + hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM) + if err := server.SetOption(hostKeyPEM); err != nil { + return nil, fmt.Errorf("set host key: %w", err) + } + + s.configurePortForwarding(server) + return server, nil +} + +func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) { + s.mu.Lock() + defer s.mu.Unlock() + s.remoteForwardListeners[key] = ln +} + +func (s *Server) removeRemoteForwardListener(key ForwardKey) bool { + s.mu.Lock() + defer s.mu.Unlock() + + ln, exists := s.remoteForwardListeners[key] + if !exists { + return false + } + + delete(s.remoteForwardListeners, key) + if err := ln.Close(); err != nil { + log.Debugf("remote forward listener close error: %v", err) + } + + return true +} + +func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) { + var payload struct { + Host string + Port uint32 + OriginatorAddr string + OriginatorPort uint32 + } + + if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil { + if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil { + log.Debugf("channel reject error: %v", err) + } + return + } + + s.mu.RLock() + allowLocal := s.allowLocalPortForwarding + s.mu.RUnlock() + + if !allowLocal { + log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port) + _ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled") + return + } + + // Check privilege requirements for the destination port + if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil { + log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err) + _ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges") + return + } + + log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port) + + ssh.DirectTCPIPHandler(srv, conn, newChan, ctx) +} diff --git a/client/ssh/server/server_config_test.go b/client/ssh/server/server_config_test.go new file mode 100644 index 000000000..24e455025 --- /dev/null +++ b/client/ssh/server/server_config_test.go @@ -0,0 +1,394 @@ +package server + +import ( + "context" + "fmt" + "net" + "os/user" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/ssh" + sshclient "github.com/netbirdio/netbird/client/ssh/client" +) + +func TestServer_RootLoginRestriction(t *testing.T) { + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + tests := []struct { + name string + allowRoot bool + username string + expectError bool + description string + }{ + { + name: "root login allowed", + allowRoot: true, + username: "root", + expectError: false, + description: "Root login should succeed when allowed", + }, + { + name: "root login denied", + allowRoot: false, + username: "root", + expectError: true, + description: "Root login should fail when disabled", + }, + { + name: "regular user login always allowed", + allowRoot: false, + username: "testuser", + expectError: false, + description: "Regular user login should work regardless of root setting", + }, + } + + // Add Windows Administrator tests if on Windows + if runtime.GOOS == "windows" { + tests = append(tests, []struct { + name string + allowRoot bool + username string + expectError bool + description string + }{ + { + name: "Administrator login allowed", + allowRoot: true, + username: "Administrator", + expectError: false, + description: "Administrator login should succeed when allowed", + }, + { + name: "Administrator login denied", + allowRoot: false, + username: "Administrator", + expectError: true, + description: "Administrator login should fail when disabled", + }, + { + name: "administrator login denied (lowercase)", + allowRoot: false, + username: "administrator", + expectError: true, + description: "administrator login should fail when disabled (case insensitive)", + }, + }...) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock privileged environment to test root access controls + // Set up mock users based on platform + mockUsers := map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + "testuser": createTestUser("testuser", "1000", "1000", "/home/testuser"), + } + + // Add Windows-specific users for Administrator tests + if runtime.GOOS == "windows" { + mockUsers["Administrator"] = createTestUser("Administrator", "500", "544", "C:\\Users\\Administrator") + mockUsers["administrator"] = createTestUser("administrator", "500", "544", "C:\\Users\\administrator") + } + + cleanup := setupTestDependencies( + createTestUser("root", "0", "0", "/root"), // Running as root + nil, + runtime.GOOS, + 0, // euid 0 (root) + mockUsers, + nil, + ) + defer cleanup() + + // Create server with specific configuration + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(tt.allowRoot) + + // Test the userNameLookup method directly + user, err := server.userNameLookup(tt.username) + + if tt.expectError { + assert.Error(t, err, tt.description) + if tt.username == "root" || strings.ToLower(tt.username) == "administrator" { + // Check for appropriate error message based on platform capabilities + errorMsg := err.Error() + // Either privileged user restriction OR user switching limitation + hasPrivilegedError := strings.Contains(errorMsg, "privileged user") + hasSwitchingError := strings.Contains(errorMsg, "cannot switch") || strings.Contains(errorMsg, "user switching not supported") + assert.True(t, hasPrivilegedError || hasSwitchingError, + "Expected privileged user or user switching error, got: %s", errorMsg) + } + } else { + if tt.username == "root" || strings.ToLower(tt.username) == "administrator" { + // For privileged users, we expect either success or a different error + // (like user not found), but not the "login disabled" error + if err != nil { + assert.NotContains(t, err.Error(), "privileged user login is disabled") + } + } else { + // For regular users, lookup should generally succeed or fall back gracefully + // Note: may return current user as fallback + assert.NotNil(t, user) + } + } + }) + } +} + +func TestServer_PortForwardingRestriction(t *testing.T) { + // Test that the port forwarding callbacks properly respect configuration flags + // This is a unit test of the callback logic, not a full integration test + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + tests := []struct { + name string + allowLocalForwarding bool + allowRemoteForwarding bool + description string + }{ + { + name: "all forwarding allowed", + allowLocalForwarding: true, + allowRemoteForwarding: true, + description: "Both local and remote forwarding should be allowed", + }, + { + name: "local forwarding disabled", + allowLocalForwarding: false, + allowRemoteForwarding: true, + description: "Local forwarding should be denied when disabled", + }, + { + name: "remote forwarding disabled", + allowLocalForwarding: true, + allowRemoteForwarding: false, + description: "Remote forwarding should be denied when disabled", + }, + { + name: "all forwarding disabled", + allowLocalForwarding: false, + allowRemoteForwarding: false, + description: "Both forwarding types should be denied when disabled", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create server with specific configuration + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowLocalPortForwarding(tt.allowLocalForwarding) + server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding) + + // We need to access the internal configuration to simulate the callback tests + // Since the callbacks are created inside the Start method, we'll test the logic directly + + // Test the configuration values are set correctly + server.mu.RLock() + allowLocal := server.allowLocalPortForwarding + allowRemote := server.allowRemotePortForwarding + server.mu.RUnlock() + + assert.Equal(t, tt.allowLocalForwarding, allowLocal, "Local forwarding configuration should be set correctly") + assert.Equal(t, tt.allowRemoteForwarding, allowRemote, "Remote forwarding configuration should be set correctly") + + // Simulate the callback logic + localResult := allowLocal // This would be the callback return value + remoteResult := allowRemote // This would be the callback return value + + assert.Equal(t, tt.allowLocalForwarding, localResult, + "Local port forwarding callback should return correct value") + assert.Equal(t, tt.allowRemoteForwarding, remoteResult, + "Remote port forwarding callback should return correct value") + }) + } +} + +func TestServer_PortConflictHandling(t *testing.T) { + // Test that multiple sessions requesting the same local port are handled naturally by the OS + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Create server + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Get a free port for testing + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + testPort := ln.Addr().(*net.TCPAddr).Port + err = ln.Close() + require.NoError(t, err) + + // Connect first client + ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel1() + + client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + err := client1.Close() + assert.NoError(t, err) + }() + + // Connect second client + ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + + client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + defer func() { + err := client2.Close() + assert.NoError(t, err) + }() + + // First client binds to the test port + localAddr1 := fmt.Sprintf("127.0.0.1:%d", testPort) + remoteAddr := "127.0.0.1:80" + + // Start first client's port forwarding + done1 := make(chan error, 1) + go func() { + // This should succeed and hold the port + err := client1.LocalPortForward(ctx1, localAddr1, remoteAddr) + done1 <- err + }() + + // Give first client time to bind + time.Sleep(200 * time.Millisecond) + + // Second client tries to bind to same port + localAddr2 := fmt.Sprintf("127.0.0.1:%d", testPort) + + shortCtx, shortCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer shortCancel() + + err = client2.LocalPortForward(shortCtx, localAddr2, remoteAddr) + // Second client should fail due to "address already in use" + assert.Error(t, err, "Second client should fail to bind to same port") + if err != nil { + // The error should indicate the address is already in use + errMsg := strings.ToLower(err.Error()) + if runtime.GOOS == "windows" { + assert.Contains(t, errMsg, "only one usage of each socket address", + "Error should indicate port conflict") + } else { + assert.Contains(t, errMsg, "address already in use", + "Error should indicate port conflict") + } + } + + // Cancel first client's context and wait for it to finish + cancel1() + select { + case err1 := <-done1: + // Should get context cancelled or deadline exceeded + assert.Error(t, err1, "First client should exit when context cancelled") + case <-time.After(2 * time.Second): + t.Error("First client did not exit within timeout") + } +} + +func TestServer_IsPrivilegedUser(t *testing.T) { + + tests := []struct { + username string + expected bool + description string + }{ + { + username: "root", + expected: true, + description: "root should be considered privileged", + }, + { + username: "regular", + expected: false, + description: "regular user should not be privileged", + }, + { + username: "", + expected: false, + description: "empty username should not be privileged", + }, + } + + // Add Windows-specific tests + if runtime.GOOS == "windows" { + tests = append(tests, []struct { + username string + expected bool + description string + }{ + { + username: "Administrator", + expected: true, + description: "Administrator should be considered privileged on Windows", + }, + { + username: "administrator", + expected: true, + description: "administrator should be considered privileged on Windows (case insensitive)", + }, + }...) + } else { + // On non-Windows systems, Administrator should not be privileged + tests = append(tests, []struct { + username string + expected bool + description string + }{ + { + username: "Administrator", + expected: false, + description: "Administrator should not be privileged on non-Windows systems", + }, + }...) + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + result := isPrivilegedUsername(tt.username) + assert.Equal(t, tt.expected, result, tt.description) + }) + } +} diff --git a/client/ssh/server/server_test.go b/client/ssh/server/server_test.go new file mode 100644 index 000000000..661068539 --- /dev/null +++ b/client/ssh/server/server_test.go @@ -0,0 +1,441 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/netip" + "os/user" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" + + nbssh "github.com/netbirdio/netbird/client/ssh" +) + +func TestServer_StartStop(t *testing.T) { + key, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: key, + JWT: nil, + } + server := New(serverConfig) + + err = server.Stop() + assert.NoError(t, err) +} + +func TestSSHServerIntegration(t *testing.T) { + // Generate host key for server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Generate client key pair + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Create server with random port + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + + // Start server in background + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + // Get a free port + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Parse client private key + signer, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key for verification + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user for test") + + // Create SSH client config + config := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(signer), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 3 * time.Second, + } + + // Connect to SSH server + client, err := cryptossh.Dial("tcp", serverAddr, config) + require.NoError(t, err) + defer func() { + if err := client.Close(); err != nil { + t.Logf("close client: %v", err) + } + }() + + // Test creating a session + session, err := client.NewSession() + require.NoError(t, err) + defer func() { + if err := session.Close(); err != nil { + t.Logf("close session: %v", err) + } + }() + + // Note: Since we don't have a real shell environment in tests, + // we can't test actual command execution, but we can verify + // the connection and authentication work + t.Log("SSH connection and authentication successful") +} + +func TestSSHServerMultipleConnections(t *testing.T) { + // Generate host key for server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Generate client key pair + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Create server + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + + // Start server + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Parse client private key + signer, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user for test") + + config := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(signer), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 3 * time.Second, + } + + // Test multiple concurrent connections + const numConnections = 5 + results := make(chan error, numConnections) + + for i := 0; i < numConnections; i++ { + go func(id int) { + client, err := cryptossh.Dial("tcp", serverAddr, config) + if err != nil { + results <- fmt.Errorf("connection %d failed: %w", id, err) + return + } + defer func() { + _ = client.Close() // Ignore error in test goroutine + }() + + session, err := client.NewSession() + if err != nil { + results <- fmt.Errorf("session %d failed: %w", id, err) + return + } + defer func() { + _ = session.Close() // Ignore error in test goroutine + }() + + results <- nil + }(i) + } + + // Wait for all connections to complete + for i := 0; i < numConnections; i++ { + select { + case err := <-results: + assert.NoError(t, err) + case <-time.After(10 * time.Second): + t.Fatalf("Connection %d timed out", i) + } + } +} + +func TestSSHServerNoAuthMode(t *testing.T) { + // Generate host key for server + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + // Create server + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + + // Start server + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Generate a client private key for SSH protocol (server doesn't check it) + clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user for test") + + // Try to connect with client key + config := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(clientSigner), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 3 * time.Second, + } + + // This should succeed in no-auth mode (server doesn't verify keys) + conn, err := cryptossh.Dial("tcp", serverAddr, config) + assert.NoError(t, err, "Connection should succeed in no-auth mode") + if conn != nil { + assert.NoError(t, conn.Close()) + } +} + +func TestSSHServerStartStopCycle(t *testing.T) { + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + serverAddr := "127.0.0.1:0" + + // Test multiple start/stop cycles + for i := 0; i < 3; i++ { + t.Logf("Start/stop cycle %d", i+1) + + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case <-started: + case err := <-errChan: + t.Fatalf("Cycle %d: Server failed to start: %v", i+1, err) + case <-time.After(5 * time.Second): + t.Fatalf("Cycle %d: Server start timeout", i+1) + } + + err = server.Stop() + require.NoError(t, err, "Cycle %d: Stop should succeed", i+1) + } +} + +func TestSSHServer_WindowsShellHandling(t *testing.T) { + if testing.Short() { + t.Skip("Skipping Windows shell test in short mode") + } + + server := &Server{} + + if runtime.GOOS == "windows" { + // Test Windows cmd.exe shell behavior + args := server.getShellCommandArgs("cmd.exe", "echo test") + assert.Equal(t, "cmd.exe", args[0]) + assert.Equal(t, "-Command", args[1]) + assert.Equal(t, "echo test", args[2]) + + // Test PowerShell behavior + args = server.getShellCommandArgs("powershell.exe", "echo test") + assert.Equal(t, "powershell.exe", args[0]) + assert.Equal(t, "-Command", args[1]) + assert.Equal(t, "echo test", args[2]) + } else { + // Test Unix shell behavior + args := server.getShellCommandArgs("/bin/sh", "echo test") + assert.Equal(t, "/bin/sh", args[0]) + assert.Equal(t, "-l", args[1]) + assert.Equal(t, "-c", args[2]) + assert.Equal(t, "echo test", args[3]) + } +} + +func TestSSHServer_PortForwardingConfiguration(t *testing.T) { + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + serverConfig1 := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server1 := New(serverConfig1) + + serverConfig2 := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server2 := New(serverConfig2) + + assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security") + assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security") + + server2.SetAllowLocalPortForwarding(true) + server2.SetAllowRemotePortForwarding(true) + + assert.True(t, server2.allowLocalPortForwarding, "Local port forwarding should be enabled when explicitly set") + assert.True(t, server2.allowRemotePortForwarding, "Remote port forwarding should be enabled when explicitly set") +} diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go new file mode 100644 index 000000000..4e6d72098 --- /dev/null +++ b/client/ssh/server/session_handlers.go @@ -0,0 +1,168 @@ +package server + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "strings" + "time" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + cryptossh "golang.org/x/crypto/ssh" +) + +// sessionHandler handles SSH sessions +func (s *Server) sessionHandler(session ssh.Session) { + sessionKey := s.registerSession(session) + + key := newAuthKey(session.User(), session.RemoteAddr()) + s.mu.Lock() + jwtUsername := s.pendingAuthJWT[key] + if jwtUsername != "" { + s.sessionJWTUsers[sessionKey] = jwtUsername + delete(s.pendingAuthJWT, key) + } + s.mu.Unlock() + + logger := log.WithField("session", sessionKey) + if jwtUsername != "" { + logger = logger.WithField("jwt_user", jwtUsername) + logger.Infof("SSH session started (JWT user: %s)", jwtUsername) + } else { + logger.Infof("SSH session started") + } + sessionStart := time.Now() + + defer s.unregisterSession(sessionKey, session) + defer func() { + duration := time.Since(sessionStart).Round(time.Millisecond) + if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { + logger.Warnf("close session after %v: %v", duration, err) + } + logger.Infof("SSH session closed after %v", duration) + }() + + privilegeResult, err := s.userPrivilegeCheck(session.User()) + if err != nil { + s.handlePrivError(logger, session, err) + return + } + + ptyReq, winCh, isPty := session.Pty() + hasCommand := len(session.Command()) > 0 + + switch { + case isPty && hasCommand: + // ssh -t - Pty command execution + s.handleCommand(logger, session, privilegeResult, winCh) + case isPty: + // ssh - Pty interactive session (login) + s.handlePty(logger, session, privilegeResult, ptyReq, winCh) + case hasCommand: + // ssh - non-Pty command execution + s.handleCommand(logger, session, privilegeResult, nil) + default: + s.rejectInvalidSession(logger, session) + } +} + +func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) { + if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil { + logger.Debugf(errWriteSession, err) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr()) +} + +func (s *Server) registerSession(session ssh.Session) SessionKey { + sessionID := session.Context().Value(ssh.ContextKeySessionID) + if sessionID == nil { + sessionID = fmt.Sprintf("%p", session) + } + + // Create a short 4-byte identifier from the full session ID + hasher := sha256.New() + hasher.Write([]byte(fmt.Sprintf("%v", sessionID))) + hash := hasher.Sum(nil) + shortID := hex.EncodeToString(hash[:4]) + + remoteAddr := session.RemoteAddr().String() + username := session.User() + sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID)) + + s.mu.Lock() + s.sessions[sessionKey] = session + s.mu.Unlock() + + return sessionKey +} + +func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) { + s.mu.Lock() + delete(s.sessions, sessionKey) + delete(s.sessionJWTUsers, sessionKey) + + // Cancel all port forwarding connections for this session + var connectionsToCancel []ConnectionKey + for key := range s.sessionCancels { + if strings.HasPrefix(string(key), string(sessionKey)+"-") { + connectionsToCancel = append(connectionsToCancel, key) + } + } + + for _, key := range connectionsToCancel { + if cancelFunc, exists := s.sessionCancels[key]; exists { + log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key) + cancelFunc() + delete(s.sessionCancels, key) + } + } + + if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil { + if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok { + delete(s.sshConnections, sshConn) + } + } + + s.mu.Unlock() +} + +func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) { + logger.Warnf("user privilege check failed: %v", err) + + errorMsg := s.buildUserLookupErrorMessage(err) + + if _, writeErr := fmt.Fprint(session, errorMsg); writeErr != nil { + logger.Debugf(errWriteSession, writeErr) + } + if exitErr := session.Exit(1); exitErr != nil { + logSessionExitError(logger, exitErr) + } +} + +// buildUserLookupErrorMessage creates appropriate user-facing error messages based on error type +func (s *Server) buildUserLookupErrorMessage(err error) string { + var privilegedErr *PrivilegedUserError + + switch { + case errors.As(err, &privilegedErr): + if privilegedErr.Username == "root" { + return "root login is disabled on this SSH server\n" + } + return "privileged user access is disabled on this SSH server\n" + + case errors.Is(err, ErrPrivilegeRequired): + return "Windows user switching failed - NetBird must run with elevated privileges for user switching\n" + + case errors.Is(err, ErrPrivilegedUserSwitch): + return "Cannot switch to privileged user - current user lacks required privileges\n" + + default: + return "User authentication failed\n" + } +} diff --git a/client/ssh/server/session_handlers_js.go b/client/ssh/server/session_handlers_js.go new file mode 100644 index 000000000..c35e4da0b --- /dev/null +++ b/client/ssh/server/session_handlers_js.go @@ -0,0 +1,22 @@ +//go:build js + +package server + +import ( + "fmt" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// handlePty is not supported on JS/WASM +func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool { + errorMsg := "PTY sessions are not supported on WASM/JS platform\n" + if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil { + logger.Debugf(errWriteSession, err) + } + if err := session.Exit(1); err != nil { + logSessionExitError(logger, err) + } + return false +} diff --git a/client/ssh/server/sftp.go b/client/ssh/server/sftp.go new file mode 100644 index 000000000..c2b9f552b --- /dev/null +++ b/client/ssh/server/sftp.go @@ -0,0 +1,81 @@ +package server + +import ( + "fmt" + "io" + + "github.com/gliderlabs/ssh" + "github.com/pkg/sftp" + log "github.com/sirupsen/logrus" +) + +// SetAllowSFTP enables or disables SFTP support +func (s *Server) SetAllowSFTP(allow bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.allowSFTP = allow +} + +// sftpSubsystemHandler handles SFTP subsystem requests +func (s *Server) sftpSubsystemHandler(sess ssh.Session) { + s.mu.RLock() + allowSFTP := s.allowSFTP + s.mu.RUnlock() + + if !allowSFTP { + log.Debugf("SFTP subsystem request denied: SFTP disabled") + if err := sess.Exit(1); err != nil { + log.Debugf("SFTP session exit failed: %v", err) + } + return + } + + result := s.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: sess.User(), + FeatureSupportsUserSwitch: true, + FeatureName: FeatureSFTP, + }) + + if !result.Allowed { + log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error) + if err := sess.Exit(1); err != nil { + log.Debugf("exit SFTP session: %v", err) + } + return + } + + log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username) + + if !result.RequiresUserSwitching { + if err := s.executeSftpDirect(sess); err != nil { + log.Errorf("SFTP direct execution: %v", err) + } + return + } + + if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil { + log.Errorf("SFTP privilege drop execution: %v", err) + } +} + +// executeSftpDirect executes SFTP directly without privilege dropping +func (s *Server) executeSftpDirect(sess ssh.Session) error { + log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User()) + + sftpServer, err := sftp.NewServer(sess) + if err != nil { + return fmt.Errorf("SFTP server creation: %w", err) + } + + defer func() { + if err := sftpServer.Close(); err != nil { + log.Debugf("failed to close sftp server: %v", err) + } + }() + + if err := sftpServer.Serve(); err != nil && err != io.EOF { + return fmt.Errorf("serve: %w", err) + } + + return nil +} diff --git a/client/ssh/server/sftp_js.go b/client/ssh/server/sftp_js.go new file mode 100644 index 000000000..3b27aeff4 --- /dev/null +++ b/client/ssh/server/sftp_js.go @@ -0,0 +1,12 @@ +//go:build js + +package server + +import ( + "os/user" +) + +// parseUserCredentials is not supported on JS/WASM +func (s *Server) parseUserCredentials(_ *user.User) (uint32, uint32, []uint32, error) { + return 0, 0, nil, errNotSupported +} diff --git a/client/ssh/server/sftp_test.go b/client/ssh/server/sftp_test.go new file mode 100644 index 000000000..32a3643e4 --- /dev/null +++ b/client/ssh/server/sftp_test.go @@ -0,0 +1,228 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/netip" + "os" + "os/user" + "testing" + "time" + + "github.com/pkg/sftp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" + + "github.com/netbirdio/netbird/client/ssh" +) + +func TestSSHServer_SFTPSubsystem(t *testing.T) { + // Skip SFTP test when running as root due to protocol issues in some environments + if os.Geteuid() == 0 { + t.Skip("Skipping SFTP test when running as root - may have protocol compatibility issues") + } + + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Generate client key pair + clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Create server with SFTP enabled + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowSFTP(true) + server.SetAllowRootLogin(true) + + // Start server + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Parse client private key + signer, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // (currentUser already obtained at function start) + + // Create SSH client connection + clientConfig := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(signer), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 5 * time.Second, + } + + conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig) + require.NoError(t, err, "SSH connection should succeed") + defer func() { + if err := conn.Close(); err != nil { + t.Logf("connection close error: %v", err) + } + }() + + // Create SFTP client + sftpClient, err := sftp.NewClient(conn) + require.NoError(t, err, "SFTP client creation should succeed") + defer func() { + if err := sftpClient.Close(); err != nil { + t.Logf("SFTP client close error: %v", err) + } + }() + + // Test basic SFTP operations + workingDir, err := sftpClient.Getwd() + assert.NoError(t, err, "Should be able to get working directory") + assert.NotEmpty(t, workingDir, "Working directory should not be empty") + + // Test directory listing + files, err := sftpClient.ReadDir(".") + assert.NoError(t, err, "Should be able to list current directory") + assert.NotNil(t, files, "File list should not be nil") +} + +func TestSSHServer_SFTPDisabled(t *testing.T) { + // Get current user for SSH connection + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + // Generate host key for server + hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Generate client key pair + clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519) + require.NoError(t, err) + + // Create server with SFTP disabled + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowSFTP(false) + + // Start server + serverAddr := "127.0.0.1:0" + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", serverAddr) + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort, _ := netip.ParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + serverAddr = actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + // Parse client private key + signer, err := cryptossh.ParsePrivateKey(clientPrivKey) + require.NoError(t, err) + + // Parse server host key + hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey) + require.NoError(t, err) + hostPubKey := hostPrivParsed.PublicKey() + + // (currentUser already obtained at function start) + + // Create SSH client connection + clientConfig := &cryptossh.ClientConfig{ + User: currentUser.Username, + Auth: []cryptossh.AuthMethod{ + cryptossh.PublicKeys(signer), + }, + HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), + Timeout: 5 * time.Second, + } + + conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig) + require.NoError(t, err, "SSH connection should succeed") + defer func() { + if err := conn.Close(); err != nil { + t.Logf("connection close error: %v", err) + } + }() + + // Try to create SFTP client - should fail when SFTP is disabled + _, err = sftp.NewClient(conn) + assert.Error(t, err, "SFTP client creation should fail when SFTP is disabled") +} diff --git a/client/ssh/server/sftp_unix.go b/client/ssh/server/sftp_unix.go new file mode 100644 index 000000000..44202bead --- /dev/null +++ b/client/ssh/server/sftp_unix.go @@ -0,0 +1,71 @@ +//go:build !windows + +package server + +import ( + "errors" + "fmt" + "os" + "os/exec" + "os/user" + "strconv" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// executeSftpWithPrivilegeDrop executes SFTP using Unix privilege dropping +func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error { + uid, gid, groups, err := s.parseUserCredentials(targetUser) + if err != nil { + return fmt.Errorf("parse user credentials: %w", err) + } + + sftpCmd, err := s.createSftpExecutorCommand(sess, uid, gid, groups, targetUser.HomeDir) + if err != nil { + return fmt.Errorf("create executor: %w", err) + } + + sftpCmd.Stdin = sess + sftpCmd.Stdout = sess + sftpCmd.Stderr = sess.Stderr() + + log.Tracef("starting SFTP with privilege dropping to user %s (UID=%d, GID=%d)", targetUser.Username, uid, gid) + + if err := sftpCmd.Start(); err != nil { + return fmt.Errorf("starting SFTP executor: %w", err) + } + + if err := sftpCmd.Wait(); err != nil { + var exitError *exec.ExitError + if errors.As(err, &exitError) { + log.Tracef("SFTP process exited with code %d", exitError.ExitCode()) + return nil + } + return fmt.Errorf("exec: %w", err) + } + + return nil +} + +// createSftpExecutorCommand creates a command that spawns netbird ssh sftp for privilege dropping +func (s *Server) createSftpExecutorCommand(sess ssh.Session, uid, gid uint32, groups []uint32, workingDir string) (*exec.Cmd, error) { + netbirdPath, err := os.Executable() + if err != nil { + return nil, err + } + + args := []string{ + "ssh", "sftp", + "--uid", strconv.FormatUint(uint64(uid), 10), + "--gid", strconv.FormatUint(uint64(gid), 10), + "--working-dir", workingDir, + } + + for _, group := range groups { + args = append(args, "--groups", strconv.FormatUint(uint64(group), 10)) + } + + log.Tracef("creating SFTP executor command: %s %v", netbirdPath, args) + return exec.CommandContext(sess.Context(), netbirdPath, args...), nil +} diff --git a/client/ssh/server/sftp_windows.go b/client/ssh/server/sftp_windows.go new file mode 100644 index 000000000..dc532b9e7 --- /dev/null +++ b/client/ssh/server/sftp_windows.go @@ -0,0 +1,91 @@ +//go:build windows + +package server + +import ( + "errors" + "fmt" + "os" + "os/exec" + "os/user" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +// createSftpCommand creates a Windows SFTP command with user switching. +// The caller must close the returned token handle after starting the process. +func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, windows.Token, error) { + username, domain := s.parseUsername(targetUser.Username) + + netbirdPath, err := os.Executable() + if err != nil { + return nil, 0, fmt.Errorf("get netbird executable path: %w", err) + } + + args := []string{ + "ssh", "sftp", + "--working-dir", targetUser.HomeDir, + "--windows-username", username, + "--windows-domain", domain, + } + + pd := NewPrivilegeDropper() + token, err := pd.createToken(username, domain) + if err != nil { + return nil, 0, fmt.Errorf("create token: %w", err) + } + + defer func() { + if err := windows.CloseHandle(token); err != nil { + log.Warnf("failed to close impersonation token: %v", err) + } + }() + + cmd, primaryToken, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir) + if err != nil { + return nil, 0, fmt.Errorf("create SFTP command: %w", err) + } + + log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username) + return cmd, primaryToken, nil +} + +// executeSftpCommand executes a Windows SFTP command with proper I/O handling +func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd, token windows.Token) error { + defer func() { + if err := windows.CloseHandle(windows.Handle(token)); err != nil { + log.Debugf("close primary token: %v", err) + } + }() + + sftpCmd.Stdin = sess + sftpCmd.Stdout = sess + sftpCmd.Stderr = sess.Stderr() + + if err := sftpCmd.Start(); err != nil { + return fmt.Errorf("starting sftp executor: %w", err) + } + + if err := sftpCmd.Wait(); err != nil { + var exitError *exec.ExitError + if errors.As(err, &exitError) { + log.Tracef("sftp process exited with code %d", exitError.ExitCode()) + return nil + } + + return fmt.Errorf("exec sftp: %w", err) + } + + return nil +} + +// executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping +func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error { + sftpCmd, token, err := s.createSftpCommand(targetUser, sess) + if err != nil { + return fmt.Errorf("create sftp: %w", err) + } + return s.executeSftpCommand(sess, sftpCmd, token) +} diff --git a/client/ssh/server/shell.go b/client/ssh/server/shell.go new file mode 100644 index 000000000..fea9d2910 --- /dev/null +++ b/client/ssh/server/shell.go @@ -0,0 +1,180 @@ +package server + +import ( + "bufio" + "fmt" + "net" + "os" + "os/exec" + "os/user" + "runtime" + "strconv" + "strings" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +const ( + defaultUnixShell = "/bin/sh" + + pwshExe = "pwsh.exe" // #nosec G101 - This is not a credential, just executable name + powershellExe = "powershell.exe" +) + +// getUserShell returns the appropriate shell for the given user ID +// Handles all platform-specific logic and fallbacks consistently +func getUserShell(userID string) string { + switch runtime.GOOS { + case "windows": + return getWindowsUserShell() + default: + return getUnixUserShell(userID) + } +} + +// getWindowsUserShell returns the best shell for Windows users. +// We intentionally do not support cmd.exe or COMSPEC fallbacks to avoid command injection +// vulnerabilities that arise from cmd.exe's complex command line parsing and special characters. +// PowerShell provides safer argument handling and is available on all modern Windows systems. +// Order: pwsh.exe -> powershell.exe +func getWindowsUserShell() string { + if path, err := exec.LookPath(pwshExe); err == nil { + return path + } + if path, err := exec.LookPath(powershellExe); err == nil { + return path + } + + return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe` +} + +// getUnixUserShell returns the shell for Unix-like systems +func getUnixUserShell(userID string) string { + shell := getShellFromPasswd(userID) + if shell != "" { + return shell + } + + if shell := os.Getenv("SHELL"); shell != "" { + return shell + } + + return defaultUnixShell +} + +// getShellFromPasswd reads the shell from /etc/passwd for the given user ID +func getShellFromPasswd(userID string) string { + file, err := os.Open("/etc/passwd") + if err != nil { + return "" + } + defer func() { + if err := file.Close(); err != nil { + log.Warnf("close /etc/passwd file: %v", err) + } + }() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Split(line, ":") + if len(fields) < 7 { + continue + } + + // field 2 is UID + if fields[2] == userID { + shell := strings.TrimSpace(fields[6]) + return shell + } + } + + if err := scanner.Err(); err != nil { + log.Warnf("error reading /etc/passwd: %v", err) + } + + return "" +} + +// prepareUserEnv prepares environment variables for user execution +func prepareUserEnv(user *user.User, shell string) []string { + pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games" + if runtime.GOOS == "windows" { + pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0` + } + + return []string{ + fmt.Sprint("SHELL=" + shell), + fmt.Sprint("USER=" + user.Username), + fmt.Sprint("LOGNAME=" + user.Username), + fmt.Sprint("HOME=" + user.HomeDir), + "PATH=" + pathValue, + } +} + +// acceptEnv checks if environment variable from SSH client should be accepted +// This is a whitelist of variables that SSH clients can send to the server +func acceptEnv(envVar string) bool { + varName := envVar + if idx := strings.Index(envVar, "="); idx != -1 { + varName = envVar[:idx] + } + + exactMatches := []string{ + "LANG", + "LANGUAGE", + "TERM", + "COLORTERM", + "EDITOR", + "VISUAL", + "PAGER", + "LESS", + "LESSCHARSET", + "TZ", + } + + prefixMatches := []string{ + "LC_", + } + + for _, exact := range exactMatches { + if varName == exact { + return true + } + } + + for _, prefix := range prefixMatches { + if strings.HasPrefix(varName, prefix) { + return true + } + } + + return false +} + +// prepareSSHEnv prepares SSH protocol-specific environment variables +// These variables provide information about the SSH connection itself +func prepareSSHEnv(session ssh.Session) []string { + remoteAddr := session.RemoteAddr() + localAddr := session.LocalAddr() + + remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String()) + if err != nil { + remoteHost = remoteAddr.String() + remotePort = "0" + } + + localHost, localPort, err := net.SplitHostPort(localAddr.String()) + if err != nil { + localHost = localAddr.String() + localPort = strconv.Itoa(InternalSSHPort) + } + + return []string{ + // SSH_CLIENT format: "client_ip client_port server_port" + fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort), + // SSH_CONNECTION format: "client_ip client_port server_ip server_port" + fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort), + } +} diff --git a/client/ssh/server/test.go b/client/ssh/server/test.go new file mode 100644 index 000000000..20930c721 --- /dev/null +++ b/client/ssh/server/test.go @@ -0,0 +1,45 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/netip" + "testing" + "time" +) + +func StartTestServer(t *testing.T, server *Server) string { + started := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + errChan <- err + return + } + actualAddr := ln.Addr().String() + if err := ln.Close(); err != nil { + errChan <- fmt.Errorf("close temp listener: %w", err) + return + } + + addrPort := netip.MustParseAddrPort(actualAddr) + if err := server.Start(context.Background(), addrPort); err != nil { + errChan <- err + return + } + started <- actualAddr + }() + + select { + case actualAddr := <-started: + return actualAddr + case err := <-errChan: + t.Fatalf("Server failed to start: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Server start timeout") + } + return "" +} diff --git a/client/ssh/server/user_utils.go b/client/ssh/server/user_utils.go new file mode 100644 index 000000000..799882cbb --- /dev/null +++ b/client/ssh/server/user_utils.go @@ -0,0 +1,411 @@ +package server + +import ( + "errors" + "fmt" + "os" + "os/user" + "runtime" + "strings" + + log "github.com/sirupsen/logrus" +) + +var ( + ErrPrivilegeRequired = errors.New("SeAssignPrimaryTokenPrivilege required for user switching - NetBird must run with elevated privileges") + ErrPrivilegedUserSwitch = errors.New("cannot switch to privileged user - current user lacks required privileges") +) + +// isPlatformUnix returns true for Unix-like platforms (Linux, macOS, etc.) +func isPlatformUnix() bool { + return getCurrentOS() != "windows" +} + +// Dependency injection variables for testing - allows mocking dynamic runtime checks +var ( + getCurrentUser = user.Current + lookupUser = user.Lookup + getCurrentOS = func() string { return runtime.GOOS } + getIsProcessPrivileged = isCurrentProcessPrivileged + + getEuid = os.Geteuid +) + +const ( + // FeatureSSHLogin represents SSH login operations for privilege checking + FeatureSSHLogin = "SSH login" + // FeatureSFTP represents SFTP operations for privilege checking + FeatureSFTP = "SFTP" +) + +// PrivilegeCheckRequest represents a privilege check request +type PrivilegeCheckRequest struct { + // Username being requested (empty = current user) + RequestedUsername string + FeatureSupportsUserSwitch bool // Does this feature/operation support user switching? + FeatureName string +} + +// PrivilegeCheckResult represents the result of a privilege check +type PrivilegeCheckResult struct { + // Allowed indicates whether the privilege check passed + Allowed bool + // User is the effective user to use for the operation (nil if not allowed) + User *user.User + // Error contains the reason for denial (nil if allowed) + Error error + // UsedFallback indicates we fell back to current user instead of requested user. + // This happens on Unix when running as an unprivileged user (e.g., in containers) + // where there's no point in user switching since we lack privileges anyway. + // When true, all privilege checks have already been performed and no additional + // privilege dropping or root checks are needed - the current user is the target. + UsedFallback bool + // RequiresUserSwitching indicates whether user switching will actually occur + // (false for fallback cases where no actual switching happens) + RequiresUserSwitching bool +} + +// CheckPrivileges performs comprehensive privilege checking for all SSH features. +// This is the single source of truth for privilege decisions across the SSH server. +func (s *Server) CheckPrivileges(req PrivilegeCheckRequest) PrivilegeCheckResult { + context, err := s.buildPrivilegeCheckContext(req.FeatureName) + if err != nil { + return PrivilegeCheckResult{Allowed: false, Error: err} + } + + // Handle empty username case - but still check root access controls + if req.RequestedUsername == "" { + if isPrivilegedUsername(context.currentUser.Username) && !context.allowRoot { + return PrivilegeCheckResult{ + Allowed: false, + Error: &PrivilegedUserError{Username: context.currentUser.Username}, + } + } + return PrivilegeCheckResult{ + Allowed: true, + User: context.currentUser, + RequiresUserSwitching: false, + } + } + + return s.checkUserRequest(context, req) +} + +// buildPrivilegeCheckContext gathers all the context needed for privilege checking +func (s *Server) buildPrivilegeCheckContext(featureName string) (*privilegeCheckContext, error) { + currentUser, err := getCurrentUser() + if err != nil { + return nil, fmt.Errorf("get current user for %s: %w", featureName, err) + } + + s.mu.RLock() + allowRoot := s.allowRootLogin + s.mu.RUnlock() + + return &privilegeCheckContext{ + currentUser: currentUser, + currentUserPrivileged: getIsProcessPrivileged(), + allowRoot: allowRoot, + }, nil +} + +// checkUserRequest handles normal privilege checking flow for specific usernames +func (s *Server) checkUserRequest(ctx *privilegeCheckContext, req PrivilegeCheckRequest) PrivilegeCheckResult { + if !ctx.currentUserPrivileged && isPlatformUnix() { + log.Debugf("Unix non-privileged shortcut: falling back to current user %s for %s (requested: %s)", + ctx.currentUser.Username, req.FeatureName, req.RequestedUsername) + return PrivilegeCheckResult{ + Allowed: true, + User: ctx.currentUser, + UsedFallback: true, + RequiresUserSwitching: false, + } + } + + resolvedUser, err := s.resolveRequestedUser(req.RequestedUsername) + if err != nil { + // Calculate if user switching would be required even if lookup failed + needsUserSwitching := !isSameUser(req.RequestedUsername, ctx.currentUser.Username) + return PrivilegeCheckResult{ + Allowed: false, + Error: err, + RequiresUserSwitching: needsUserSwitching, + } + } + + needsUserSwitching := !isSameResolvedUser(resolvedUser, ctx.currentUser) + + if isPrivilegedUsername(resolvedUser.Username) && !ctx.allowRoot { + return PrivilegeCheckResult{ + Allowed: false, + Error: &PrivilegedUserError{Username: resolvedUser.Username}, + RequiresUserSwitching: needsUserSwitching, + } + } + + if needsUserSwitching && !req.FeatureSupportsUserSwitch { + return PrivilegeCheckResult{ + Allowed: false, + Error: fmt.Errorf("%s: user switching not supported by this feature", req.FeatureName), + RequiresUserSwitching: needsUserSwitching, + } + } + + return PrivilegeCheckResult{ + Allowed: true, + User: resolvedUser, + RequiresUserSwitching: needsUserSwitching, + } +} + +// resolveRequestedUser resolves a username to its canonical user identity +func (s *Server) resolveRequestedUser(requestedUsername string) (*user.User, error) { + if requestedUsername == "" { + return getCurrentUser() + } + + if err := validateUsername(requestedUsername); err != nil { + return nil, fmt.Errorf("invalid username %q: %w", requestedUsername, err) + } + + u, err := lookupUser(requestedUsername) + if err != nil { + return nil, &UserNotFoundError{Username: requestedUsername, Cause: err} + } + return u, nil +} + +// isSameResolvedUser compares two resolved user identities +func isSameResolvedUser(user1, user2 *user.User) bool { + if user1 == nil || user2 == nil { + return user1 == user2 + } + return user1.Uid == user2.Uid +} + +// privilegeCheckContext holds all context needed for privilege checking +type privilegeCheckContext struct { + currentUser *user.User + currentUserPrivileged bool + allowRoot bool +} + +// isSameUser checks if two usernames refer to the same user +// SECURITY: This function must be conservative - it should only return true +// when we're certain both usernames refer to the exact same user identity +func isSameUser(requestedUsername, currentUsername string) bool { + // Empty requested username means current user + if requestedUsername == "" { + return true + } + + // Exact match (most common case) + if getCurrentOS() == "windows" { + if strings.EqualFold(requestedUsername, currentUsername) { + return true + } + } else { + if requestedUsername == currentUsername { + return true + } + } + + // Windows domain resolution: only allow domain stripping when comparing + // a bare username against the current user's domain-qualified name + if getCurrentOS() == "windows" { + return isWindowsSameUser(requestedUsername, currentUsername) + } + + return false +} + +// isWindowsSameUser handles Windows-specific user comparison with domain logic +func isWindowsSameUser(requestedUsername, currentUsername string) bool { + // Extract domain and username parts + extractParts := func(name string) (domain, user string) { + // Handle DOMAIN\username format + if idx := strings.LastIndex(name, `\`); idx != -1 { + return name[:idx], name[idx+1:] + } + // Handle user@domain.com format + if idx := strings.Index(name, "@"); idx != -1 { + return name[idx+1:], name[:idx] + } + // No domain specified - local machine + return "", name + } + + reqDomain, reqUser := extractParts(requestedUsername) + curDomain, curUser := extractParts(currentUsername) + + // Case-insensitive username comparison + if !strings.EqualFold(reqUser, curUser) { + return false + } + + // If requested username has no domain, it refers to local machine user + // Allow this to match the current user regardless of current user's domain + if reqDomain == "" { + return true + } + + // If both have domains, they must match exactly (case-insensitive) + return strings.EqualFold(reqDomain, curDomain) +} + +// SetAllowRootLogin configures root login access +func (s *Server) SetAllowRootLogin(allow bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.allowRootLogin = allow +} + +// userNameLookup performs user lookup with root login permission check +func (s *Server) userNameLookup(username string) (*user.User, error) { + result := s.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: username, + FeatureSupportsUserSwitch: true, + FeatureName: FeatureSSHLogin, + }) + + if !result.Allowed { + return nil, result.Error + } + + return result.User, nil +} + +// userPrivilegeCheck performs user lookup with full privilege check result +func (s *Server) userPrivilegeCheck(username string) (PrivilegeCheckResult, error) { + result := s.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: username, + FeatureSupportsUserSwitch: true, + FeatureName: FeatureSSHLogin, + }) + + if !result.Allowed { + return result, result.Error + } + + return result, nil +} + +// isPrivilegedUsername checks if the given username represents a privileged user across platforms. +// On Unix: root +// On Windows: Administrator, SYSTEM (case-insensitive) +// Handles domain-qualified usernames like "DOMAIN\Administrator" or "user@domain.com" +func isPrivilegedUsername(username string) bool { + if getCurrentOS() != "windows" { + return username == "root" + } + + bareUsername := username + // Handle Windows domain format: DOMAIN\username + if idx := strings.LastIndex(username, `\`); idx != -1 { + bareUsername = username[idx+1:] + } + // Handle email-style format: username@domain.com + if idx := strings.Index(bareUsername, "@"); idx != -1 { + bareUsername = bareUsername[:idx] + } + + return isWindowsPrivilegedUser(bareUsername) +} + +// isWindowsPrivilegedUser checks if a bare username (domain already stripped) represents a Windows privileged account +func isWindowsPrivilegedUser(bareUsername string) bool { + // common privileged usernames (case insensitive) + privilegedNames := []string{ + "administrator", + "admin", + "root", + "system", + "localsystem", + "networkservice", + "localservice", + } + + usernameLower := strings.ToLower(bareUsername) + for _, privilegedName := range privilegedNames { + if usernameLower == privilegedName { + return true + } + } + + // computer accounts (ending with $) are not privileged by themselves + // They only gain privileges through group membership or specific SIDs + + if targetUser, err := lookupUser(bareUsername); err == nil { + return isWindowsPrivilegedSID(targetUser.Uid) + } + + return false +} + +// isWindowsPrivilegedSID checks if a Windows SID represents a privileged account +func isWindowsPrivilegedSID(sid string) bool { + privilegedSIDs := []string{ + "S-1-5-18", // Local System (SYSTEM) + "S-1-5-19", // Local Service (NT AUTHORITY\LOCAL SERVICE) + "S-1-5-20", // Network Service (NT AUTHORITY\NETWORK SERVICE) + "S-1-5-32-544", // Administrators group (BUILTIN\Administrators) + "S-1-5-500", // Built-in Administrator account (local machine RID 500) + } + + for _, privilegedSID := range privilegedSIDs { + if sid == privilegedSID { + return true + } + } + + // Check for domain administrator accounts (RID 500 in any domain) + // Format: S-1-5-21-domain-domain-domain-500 + // This is reliable as RID 500 is reserved for the domain Administrator account + if strings.HasPrefix(sid, "S-1-5-21-") && strings.HasSuffix(sid, "-500") { + return true + } + + // Check for other well-known privileged RIDs in domain contexts + // RID 512 = Domain Admins group, RID 516 = Domain Controllers group + if strings.HasPrefix(sid, "S-1-5-21-") { + if strings.HasSuffix(sid, "-512") || // Domain Admins group + strings.HasSuffix(sid, "-516") || // Domain Controllers group + strings.HasSuffix(sid, "-519") { // Enterprise Admins group + return true + } + } + + return false +} + +// isCurrentProcessPrivileged checks if the current process is running with elevated privileges. +// On Unix systems, this means running as root (UID 0). +// On Windows, this means running as Administrator or SYSTEM. +func isCurrentProcessPrivileged() bool { + if getCurrentOS() == "windows" { + return isWindowsElevated() + } + return getEuid() == 0 +} + +// isWindowsElevated checks if the current process is running with elevated privileges on Windows +func isWindowsElevated() bool { + currentUser, err := getCurrentUser() + if err != nil { + log.Errorf("failed to get current user for privilege check, assuming non-privileged: %v", err) + return false + } + + if isWindowsPrivilegedSID(currentUser.Uid) { + log.Debugf("Windows user switching supported: running as privileged SID %s", currentUser.Uid) + return true + } + + if isPrivilegedUsername(currentUser.Username) { + log.Debugf("Windows user switching supported: running as privileged username %s", currentUser.Username) + return true + } + + log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid) + return false +} diff --git a/client/ssh/server/user_utils_js.go b/client/ssh/server/user_utils_js.go new file mode 100644 index 000000000..163b24c6c --- /dev/null +++ b/client/ssh/server/user_utils_js.go @@ -0,0 +1,8 @@ +//go:build js + +package server + +// validateUsername is not supported on JS/WASM +func validateUsername(_ string) error { + return errNotSupported +} diff --git a/client/ssh/server/user_utils_test.go b/client/ssh/server/user_utils_test.go new file mode 100644 index 000000000..637dc10d0 --- /dev/null +++ b/client/ssh/server/user_utils_test.go @@ -0,0 +1,908 @@ +package server + +import ( + "errors" + "os/user" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Test helper functions +func createTestUser(username, uid, gid, homeDir string) *user.User { + return &user.User{ + Uid: uid, + Gid: gid, + Username: username, + Name: username, + HomeDir: homeDir, + } +} + +// Test dependency injection setup - injects platform dependencies to test real logic +func setupTestDependencies(currentUser *user.User, currentUserErr error, os string, euid int, lookupUsers map[string]*user.User, lookupErrors map[string]error) func() { + // Store originals + originalGetCurrentUser := getCurrentUser + originalLookupUser := lookupUser + originalGetCurrentOS := getCurrentOS + originalGetEuid := getEuid + + // Reset caches to ensure clean test state + + // Set test values - inject platform dependencies + getCurrentUser = func() (*user.User, error) { + return currentUser, currentUserErr + } + + lookupUser = func(username string) (*user.User, error) { + if err, exists := lookupErrors[username]; exists { + return nil, err + } + if userObj, exists := lookupUsers[username]; exists { + return userObj, nil + } + return nil, errors.New("user: unknown user " + username) + } + + getCurrentOS = func() string { + return os + } + + getEuid = func() int { + return euid + } + + // Mock privilege detection based on the test user + getIsProcessPrivileged = func() bool { + if currentUser == nil { + return false + } + // Check both username and SID for Windows systems + if os == "windows" && isWindowsPrivilegedSID(currentUser.Uid) { + return true + } + return isPrivilegedUsername(currentUser.Username) + } + + // Return cleanup function + return func() { + getCurrentUser = originalGetCurrentUser + lookupUser = originalLookupUser + getCurrentOS = originalGetCurrentOS + getEuid = originalGetEuid + + getIsProcessPrivileged = isCurrentProcessPrivileged + + // Reset caches after test + } +} + +func TestCheckPrivileges_ComprehensiveMatrix(t *testing.T) { + tests := []struct { + name string + os string + euid int + currentUser *user.User + requestedUsername string + featureSupportsUserSwitch bool + allowRoot bool + lookupUsers map[string]*user.User + expectedAllowed bool + expectedRequiresSwitch bool + }{ + { + name: "linux_root_can_switch_to_alice", + os: "linux", + euid: 0, // Root process + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "alice", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "1000", "1000", "/home/alice"), + }, + expectedAllowed: true, + expectedRequiresSwitch: true, + }, + { + name: "linux_non_root_fallback_to_current_user", + os: "linux", + euid: 1000, // Non-root process + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + requestedUsername: "bob", + featureSupportsUserSwitch: true, + allowRoot: true, + expectedAllowed: true, // Should fallback to current user (alice) + expectedRequiresSwitch: false, // Fallback means no actual switching + }, + { + name: "windows_admin_can_switch_to_alice", + os: "windows", + euid: 1000, // Irrelevant on Windows + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + requestedUsername: "alice", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + }, + expectedAllowed: true, + expectedRequiresSwitch: true, + }, + { + name: "windows_non_admin_no_fallback_hard_failure", + os: "windows", + euid: 1000, // Irrelevant on Windows + currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"), + requestedUsername: "bob", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "bob": createTestUser("bob", "S-1-5-21-123456789-123456789-123456789-1002", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\bob"), + }, + expectedAllowed: true, // Let OS decide - deferred security check + expectedRequiresSwitch: true, // Different user was requested + }, + // Comprehensive test matrix: non-root linux with different allowRoot settings + { + name: "linux_non_root_request_root_allowRoot_false", + os: "linux", + euid: 1000, + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: true, // Fallback allows access regardless of root setting + expectedRequiresSwitch: false, // Fallback case, no switching + }, + { + name: "linux_non_root_request_root_allowRoot_true", + os: "linux", + euid: 1000, + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: true, + expectedAllowed: true, // Should fallback to alice (non-privileged process) + expectedRequiresSwitch: false, // Fallback means no actual switching + }, + // Windows admin test matrix + { + name: "windows_admin_request_root_allowRoot_false", + os: "windows", + euid: 1000, + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: false, // Root not allowed + expectedRequiresSwitch: true, + }, + { + name: "windows_admin_request_root_allowRoot_true", + os: "windows", + euid: 1000, + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + }, + expectedAllowed: true, // Windows user switching should work like Unix + expectedRequiresSwitch: true, + }, + // Windows non-admin test matrix + { + name: "windows_non_admin_request_root_allowRoot_false", + os: "windows", + euid: 1000, + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: false, // Root not allowed (allowRoot=false takes precedence) + expectedRequiresSwitch: true, + }, + { + name: "windows_system_account_allowRoot_false", + os: "windows", + euid: 1000, + currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: false, // Root not allowed + expectedRequiresSwitch: true, + }, + { + name: "windows_system_account_allowRoot_true", + os: "windows", + euid: 1000, + currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + }, + expectedAllowed: true, // SYSTEM can switch to root + expectedRequiresSwitch: true, + }, + { + name: "windows_non_admin_request_root_allowRoot_true", + os: "windows", + euid: 1000, + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + requestedUsername: "root", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + }, + expectedAllowed: true, // Let OS decide - deferred security check + expectedRequiresSwitch: true, + }, + + // Feature doesn't support user switching scenarios + { + name: "linux_root_feature_no_user_switching_same_user", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "root", // Same user + featureSupportsUserSwitch: false, + allowRoot: true, + lookupUsers: map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + }, + expectedAllowed: true, // Same user should work regardless of feature support + expectedRequiresSwitch: false, + }, + { + name: "linux_root_feature_no_user_switching_different_user", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "alice", + featureSupportsUserSwitch: false, // Feature doesn't support switching + allowRoot: true, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "1000", "1000", "/home/alice"), + }, + expectedAllowed: false, // Should deny because feature doesn't support switching + expectedRequiresSwitch: true, + }, + + // Empty username (current user) scenarios + { + name: "linux_non_root_current_user_empty_username", + os: "linux", + euid: 1000, + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + requestedUsername: "", // Empty = current user + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: true, // Current user should always work + expectedRequiresSwitch: false, + }, + { + name: "linux_root_current_user_empty_username_root_not_allowed", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "", // Empty = current user (root) + featureSupportsUserSwitch: true, + allowRoot: false, // Root not allowed + expectedAllowed: false, // Should deny root even when it's current user + expectedRequiresSwitch: false, + }, + + // User not found scenarios + { + name: "linux_root_user_not_found", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + requestedUsername: "nonexistent", + featureSupportsUserSwitch: true, + allowRoot: true, + lookupUsers: map[string]*user.User{}, // No users defined = user not found + expectedAllowed: false, // Should fail due to user not found + expectedRequiresSwitch: true, + }, + + // Windows feature doesn't support user switching + { + name: "windows_admin_feature_no_user_switching_different_user", + os: "windows", + euid: 1000, + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + requestedUsername: "alice", + featureSupportsUserSwitch: false, // Feature doesn't support switching + allowRoot: true, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + }, + expectedAllowed: false, // Should deny because feature doesn't support switching + expectedRequiresSwitch: true, + }, + + // Windows regular user scenarios (non-admin) + { + name: "windows_regular_user_same_user", + os: "windows", + euid: 1000, + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + requestedUsername: "alice", // Same user + featureSupportsUserSwitch: true, + allowRoot: false, + lookupUsers: map[string]*user.User{ + "alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + }, + expectedAllowed: true, // Regular user accessing themselves should work + expectedRequiresSwitch: false, // No switching for same user + }, + { + name: "windows_regular_user_empty_username", + os: "windows", + euid: 1000, + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"), + requestedUsername: "", // Empty = current user + featureSupportsUserSwitch: true, + allowRoot: false, + expectedAllowed: true, // Current user should always work + expectedRequiresSwitch: false, // No switching for current user + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Inject platform dependencies to test real logic + cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, tt.lookupUsers, nil) + defer cleanup() + + server := &Server{allowRootLogin: tt.allowRoot} + + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: tt.requestedUsername, + FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch, + FeatureName: "SSH login", + }) + + assert.Equal(t, tt.expectedAllowed, result.Allowed) + assert.Equal(t, tt.expectedRequiresSwitch, result.RequiresUserSwitching) + }) + } +} + +func TestUsedFallback_MeansNoPrivilegeDropping(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Fallback mechanism is Unix-specific") + } + + // Create test scenario where fallback should occur + server := &Server{allowRootLogin: true} + + // Mock dependencies to simulate non-privileged user + originalGetCurrentUser := getCurrentUser + originalGetIsProcessPrivileged := getIsProcessPrivileged + + defer func() { + getCurrentUser = originalGetCurrentUser + getIsProcessPrivileged = originalGetIsProcessPrivileged + + }() + + // Set up mocks for fallback scenario + getCurrentUser = func() (*user.User, error) { + return createTestUser("netbird", "1000", "1000", "/var/lib/netbird"), nil + } + getIsProcessPrivileged = func() bool { return false } // Non-privileged + + // Request different user - should fallback + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: "alice", + FeatureSupportsUserSwitch: true, + FeatureName: "SSH login", + }) + + // Verify fallback occurred + assert.True(t, result.Allowed, "Should allow with fallback") + assert.True(t, result.UsedFallback, "Should indicate fallback was used") + assert.Equal(t, "netbird", result.User.Username, "Should return current user") + assert.False(t, result.RequiresUserSwitching, "Should not require switching when fallback is used") + + // Key assertion: When UsedFallback is true, no privilege dropping should be needed + // because all privilege checks have already been performed and we're using current user + t.Logf("UsedFallback=true means: current user (%s) is the target, no privilege dropping needed", + result.User.Username) +} + +func TestPrivilegedUsernameDetection(t *testing.T) { + tests := []struct { + name string + username string + platform string + privileged bool + }{ + // Unix/Linux tests + {"unix_root", "root", "linux", true}, + {"unix_regular_user", "alice", "linux", false}, + {"unix_root_capital", "Root", "linux", false}, // Case-sensitive + + // Windows tests + {"windows_administrator", "Administrator", "windows", true}, + {"windows_system", "SYSTEM", "windows", true}, + {"windows_admin", "admin", "windows", true}, + {"windows_admin_lowercase", "administrator", "windows", true}, // Case-insensitive + {"windows_domain_admin", "DOMAIN\\Administrator", "windows", true}, + {"windows_email_admin", "admin@domain.com", "windows", true}, + {"windows_regular_user", "alice", "windows", false}, + {"windows_domain_user", "DOMAIN\\alice", "windows", false}, + {"windows_localsystem", "localsystem", "windows", true}, + {"windows_networkservice", "networkservice", "windows", true}, + {"windows_localservice", "localservice", "windows", true}, + + // Computer accounts (these depend on current user context in real implementation) + {"windows_computer_account", "WIN2K19-C2$", "windows", false}, // Computer account by itself not privileged + {"windows_domain_computer", "DOMAIN\\COMPUTER$", "windows", false}, // Domain computer account + + // Cross-platform + {"root_on_windows", "root", "windows", true}, // Root should be privileged everywhere + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock the platform for this test + cleanup := setupTestDependencies(nil, nil, tt.platform, 1000, nil, nil) + defer cleanup() + + result := isPrivilegedUsername(tt.username) + assert.Equal(t, tt.privileged, result) + }) + } +} + +func TestWindowsPrivilegedSIDDetection(t *testing.T) { + tests := []struct { + name string + sid string + privileged bool + description string + }{ + // Well-known system accounts + {"system_account", "S-1-5-18", true, "Local System (SYSTEM)"}, + {"local_service", "S-1-5-19", true, "Local Service"}, + {"network_service", "S-1-5-20", true, "Network Service"}, + {"administrators_group", "S-1-5-32-544", true, "Administrators group"}, + {"builtin_administrator", "S-1-5-500", true, "Built-in Administrator"}, + + // Domain accounts + {"domain_administrator", "S-1-5-21-1234567890-1234567890-1234567890-500", true, "Domain Administrator (RID 500)"}, + {"domain_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-512", true, "Domain Admins group"}, + {"domain_controllers_group", "S-1-5-21-1234567890-1234567890-1234567890-516", true, "Domain Controllers group"}, + {"enterprise_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-519", true, "Enterprise Admins group"}, + + // Regular users + {"regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1001", false, "Regular domain user"}, + {"another_regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1234", false, "Another regular user"}, + {"local_user", "S-1-5-21-1234567890-1234567890-1234567890-1000", false, "Local regular user"}, + + // Groups that are not privileged + {"domain_users", "S-1-5-21-1234567890-1234567890-1234567890-513", false, "Domain Users group"}, + {"power_users", "S-1-5-32-547", false, "Power Users group"}, + + // Invalid SIDs + {"malformed_sid", "S-1-5-invalid", false, "Malformed SID"}, + {"empty_sid", "", false, "Empty SID"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isWindowsPrivilegedSID(tt.sid) + assert.Equal(t, tt.privileged, result, "Failed for %s: %s", tt.description, tt.sid) + }) + } +} + +func TestIsSameUser(t *testing.T) { + tests := []struct { + name string + user1 string + user2 string + os string + expected bool + }{ + // Basic cases + {"same_username", "alice", "alice", "linux", true}, + {"different_username", "alice", "bob", "linux", false}, + + // Linux (no domain processing) + {"linux_domain_vs_bare", "DOMAIN\\alice", "alice", "linux", false}, + {"linux_email_vs_bare", "alice@domain.com", "alice", "linux", false}, + {"linux_same_literal", "DOMAIN\\alice", "DOMAIN\\alice", "linux", true}, + + // Windows (with domain processing) - Note: parameter order is (requested, current, os, expected) + {"windows_domain_vs_bare", "alice", "DOMAIN\\alice", "windows", true}, // bare username matches domain current user + {"windows_email_vs_bare", "alice", "alice@domain.com", "windows", true}, // bare username matches email current user + {"windows_different_domains_same_user", "DOMAIN1\\alice", "DOMAIN2\\alice", "windows", false}, // SECURITY: different domains = different users + {"windows_case_insensitive", "Alice", "alice", "windows", true}, + {"windows_different_users", "DOMAIN\\alice", "DOMAIN\\bob", "windows", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up OS mock + cleanup := setupTestDependencies(nil, nil, tt.os, 1000, nil, nil) + defer cleanup() + + result := isSameUser(tt.user1, tt.user2) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestUsernameValidation_Unix(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Unix-specific username validation tests") + } + + tests := []struct { + name string + username string + wantErr bool + errMsg string + }{ + // Valid usernames (Unix/POSIX) + {"valid_alphanumeric", "user123", false, ""}, + {"valid_with_dots", "user.name", false, ""}, + {"valid_with_hyphens", "user-name", false, ""}, + {"valid_with_underscores", "user_name", false, ""}, + {"valid_uppercase", "UserName", false, ""}, + {"valid_starting_with_digit", "123user", false, ""}, + {"valid_starting_with_dot", ".hidden", false, ""}, + + // Invalid usernames (Unix/POSIX) + {"empty_username", "", true, "username cannot be empty"}, + {"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"}, + {"username_starting_with_hyphen", "-user", true, "invalid characters"}, // POSIX restriction + {"username_with_spaces", "user name", true, "invalid characters"}, + {"username_with_shell_metacharacters", "user;rm", true, "invalid characters"}, + {"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"}, + {"username_with_pipe", "user|rm", true, "invalid characters"}, + {"username_with_ampersand", "user&rm", true, "invalid characters"}, + {"username_with_quotes", "user\"name", true, "invalid characters"}, + {"username_with_newline", "user\nname", true, "invalid characters"}, + {"reserved_dot", ".", true, "cannot be '.' or '..'"}, + {"reserved_dotdot", "..", true, "cannot be '.' or '..'"}, + {"username_with_at_symbol", "user@domain", true, "invalid characters"}, // Not allowed in bare Unix usernames + {"username_with_backslash", "user\\name", true, "invalid characters"}, // Not allowed in Unix usernames + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateUsername(tt.username) + if tt.wantErr { + assert.Error(t, err, "Should reject invalid username") + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text") + } + } else { + assert.NoError(t, err, "Should accept valid username") + } + }) + } +} + +func TestUsernameValidation_Windows(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Windows-specific username validation tests") + } + + tests := []struct { + name string + username string + wantErr bool + errMsg string + }{ + // Valid usernames (Windows) + {"valid_alphanumeric", "user123", false, ""}, + {"valid_with_dots", "user.name", false, ""}, + {"valid_with_hyphens", "user-name", false, ""}, + {"valid_with_underscores", "user_name", false, ""}, + {"valid_uppercase", "UserName", false, ""}, + {"valid_starting_with_digit", "123user", false, ""}, + {"valid_starting_with_dot", ".hidden", false, ""}, + {"valid_starting_with_hyphen", "-user", false, ""}, // Windows allows this + {"valid_domain_username", "DOMAIN\\user", false, ""}, // Windows domain format + {"valid_email_username", "user@domain.com", false, ""}, // Windows email format + {"valid_machine_username", "MACHINE\\user", false, ""}, // Windows machine format + + // Invalid usernames (Windows) + {"empty_username", "", true, "username cannot be empty"}, + {"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"}, + {"username_with_spaces", "user name", true, "invalid characters"}, + {"username_with_shell_metacharacters", "user;rm", true, "invalid characters"}, + {"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"}, + {"username_with_pipe", "user|rm", true, "invalid characters"}, + {"username_with_ampersand", "user&rm", true, "invalid characters"}, + {"username_with_quotes", "user\"name", true, "invalid characters"}, + {"username_with_newline", "user\nname", true, "invalid characters"}, + {"username_with_brackets", "user[name]", true, "invalid characters"}, + {"username_with_colon", "user:name", true, "invalid characters"}, + {"username_with_semicolon", "user;name", true, "invalid characters"}, + {"username_with_equals", "user=name", true, "invalid characters"}, + {"username_with_comma", "user,name", true, "invalid characters"}, + {"username_with_plus", "user+name", true, "invalid characters"}, + {"username_with_asterisk", "user*name", true, "invalid characters"}, + {"username_with_question", "user?name", true, "invalid characters"}, + {"username_with_angles", "user", true, "invalid characters"}, + {"reserved_dot", ".", true, "cannot be '.' or '..'"}, + {"reserved_dotdot", "..", true, "cannot be '.' or '..'"}, + {"username_ending_with_period", "user.", true, "cannot end with a period"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateUsername(tt.username) + if tt.wantErr { + assert.Error(t, err, "Should reject invalid username") + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text") + } + } else { + assert.NoError(t, err, "Should accept valid username") + } + }) + } +} + +// Test real-world integration scenarios with actual platform capabilities +func TestCheckPrivileges_RealWorldScenarios(t *testing.T) { + tests := []struct { + name string + feature string + featureSupportsUserSwitch bool + requestedUsername string + allowRoot bool + expectedBehaviorPattern string + }{ + {"SSH_login_current_user", "SSH login", true, "", true, "should_allow_current_user"}, + {"SFTP_current_user", "SFTP", true, "", true, "should_allow_current_user"}, + {"port_forwarding_current_user", "port forwarding", false, "", true, "should_allow_current_user"}, + {"SSH_login_root_not_allowed", "SSH login", true, "root", false, "should_deny_root"}, + {"port_forwarding_different_user", "port forwarding", false, "differentuser", true, "should_deny_switching"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock privileged environment to ensure consistent test behavior across environments + cleanup := setupTestDependencies( + createTestUser("root", "0", "0", "/root"), // Running as root + nil, + runtime.GOOS, + 0, // euid 0 (root) + map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + "differentuser": createTestUser("differentuser", "1000", "1000", "/home/differentuser"), + }, + nil, + ) + defer cleanup() + + server := &Server{allowRootLogin: tt.allowRoot} + + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: tt.requestedUsername, + FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch, + FeatureName: tt.feature, + }) + + switch tt.expectedBehaviorPattern { + case "should_allow_current_user": + assert.True(t, result.Allowed, "Should allow current user access") + assert.False(t, result.RequiresUserSwitching, "Current user should not require switching") + case "should_deny_root": + assert.False(t, result.Allowed, "Should deny root when not allowed") + assert.Contains(t, result.Error.Error(), "root", "Should mention root in error") + case "should_deny_switching": + assert.False(t, result.Allowed, "Should deny when feature doesn't support switching") + assert.Contains(t, result.Error.Error(), "user switching not supported", "Should mention switching in error") + } + }) + } +} + +// Test with actual platform capabilities - no mocking +func TestCheckPrivileges_ActualPlatform(t *testing.T) { + // This test uses the REAL platform capabilities + server := &Server{allowRootLogin: true} + + // Test current user access - should always work + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: "", // Current user + FeatureSupportsUserSwitch: true, + FeatureName: "SSH login", + }) + + assert.True(t, result.Allowed, "Current user should always be allowed") + assert.False(t, result.RequiresUserSwitching, "Current user should not require switching") + assert.NotNil(t, result.User, "Should return current user") + + // Test user switching capability based on actual platform + actualIsPrivileged := isCurrentProcessPrivileged() // REAL check + actualOS := runtime.GOOS // REAL check + + t.Logf("Platform capabilities: OS=%s, isPrivileged=%v, supportsUserSwitching=%v", + actualOS, actualIsPrivileged, actualIsPrivileged) + + // Test requesting different user + result = server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: "nonexistentuser", + FeatureSupportsUserSwitch: true, + FeatureName: "SSH login", + }) + + switch { + case actualOS == "windows": + // Windows supports user switching but should fail on nonexistent user + assert.False(t, result.Allowed, "Windows should deny nonexistent user") + assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed") + assert.Contains(t, result.Error.Error(), "not found", + "Should indicate user not found") + case !actualIsPrivileged: + // Non-privileged Unix processes should fallback to current user + assert.True(t, result.Allowed, "Non-privileged Unix process should fallback to current user") + assert.False(t, result.RequiresUserSwitching, "Fallback means no switching actually happens") + assert.True(t, result.UsedFallback, "Should indicate fallback was used") + assert.NotNil(t, result.User, "Should return current user") + default: + // Privileged Unix processes should attempt user lookup + assert.False(t, result.Allowed, "Should fail due to nonexistent user") + assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed") + assert.Contains(t, result.Error.Error(), "nonexistentuser", + "Should indicate user not found") + } +} + +// Test platform detection logic with dependency injection +func TestPlatformLogic_DependencyInjection(t *testing.T) { + tests := []struct { + name string + os string + euid int + currentUser *user.User + expectedIsProcessPrivileged bool + expectedSupportsUserSwitching bool + }{ + { + name: "linux_root_process", + os: "linux", + euid: 0, + currentUser: createTestUser("root", "0", "0", "/root"), + expectedIsProcessPrivileged: true, + expectedSupportsUserSwitching: true, + }, + { + name: "linux_non_root_process", + os: "linux", + euid: 1000, + currentUser: createTestUser("alice", "1000", "1000", "/home/alice"), + expectedIsProcessPrivileged: false, + expectedSupportsUserSwitching: false, + }, + { + name: "windows_admin_process", + os: "windows", + euid: 1000, // euid ignored on Windows + currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"), + expectedIsProcessPrivileged: true, + expectedSupportsUserSwitching: true, // Windows supports user switching when privileged + }, + { + name: "windows_regular_process", + os: "windows", + euid: 1000, // euid ignored on Windows + currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"), + expectedIsProcessPrivileged: false, + expectedSupportsUserSwitching: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Inject platform dependencies and test REAL logic + cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, nil, nil) + defer cleanup() + + // Test the actual functions with injected dependencies + actualIsPrivileged := isCurrentProcessPrivileged() + actualSupportsUserSwitching := actualIsPrivileged + + assert.Equal(t, tt.expectedIsProcessPrivileged, actualIsPrivileged, + "isCurrentProcessPrivileged() result mismatch") + assert.Equal(t, tt.expectedSupportsUserSwitching, actualSupportsUserSwitching, + "supportsUserSwitching() result mismatch") + + t.Logf("Platform: %s, EUID: %d, User: %s", tt.os, tt.euid, tt.currentUser.Username) + t.Logf("Results: isPrivileged=%v, supportsUserSwitching=%v", + actualIsPrivileged, actualSupportsUserSwitching) + }) + } +} + +func TestCheckPrivileges_WindowsElevatedUserSwitching(t *testing.T) { + // Test Windows elevated user switching scenarios with simplified privilege logic + tests := []struct { + name string + currentUser *user.User + requestedUsername string + allowRoot bool + expectedAllowed bool + expectedErrorContains string + }{ + { + name: "windows_admin_can_switch_to_alice", + currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"), + requestedUsername: "alice", + allowRoot: true, + expectedAllowed: true, + }, + { + name: "windows_non_admin_can_try_switch", + currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\\\Users\\\\alice"), + requestedUsername: "bob", + allowRoot: true, + expectedAllowed: true, // Privilege check allows it, OS will reject during execution + }, + { + name: "windows_system_can_switch_to_alice", + currentUser: createTestUser("SYSTEM", "S-1-5-18", "S-1-5-18", "C:\\\\Windows\\\\system32\\\\config\\\\systemprofile"), + requestedUsername: "alice", + allowRoot: true, + expectedAllowed: true, + }, + { + name: "windows_admin_root_not_allowed", + currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"), + requestedUsername: "root", + allowRoot: false, + expectedAllowed: false, + expectedErrorContains: "privileged user login is disabled", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup test dependencies with Windows OS and specified privileges + lookupUsers := map[string]*user.User{ + tt.requestedUsername: createTestUser(tt.requestedUsername, "1002", "1002", "C:\\\\Users\\\\"+tt.requestedUsername), + } + cleanup := setupTestDependencies(tt.currentUser, nil, "windows", 1000, lookupUsers, nil) + defer cleanup() + + server := &Server{allowRootLogin: tt.allowRoot} + + result := server.CheckPrivileges(PrivilegeCheckRequest{ + RequestedUsername: tt.requestedUsername, + FeatureSupportsUserSwitch: true, + FeatureName: "SSH login", + }) + + assert.Equal(t, tt.expectedAllowed, result.Allowed, + "Privilege check result should match expected for %s", tt.name) + + if !tt.expectedAllowed && tt.expectedErrorContains != "" { + assert.NotNil(t, result.Error, "Should have error when not allowed") + assert.Contains(t, result.Error.Error(), tt.expectedErrorContains, + "Error should contain expected message") + } + + if tt.expectedAllowed && tt.requestedUsername != "" && tt.currentUser.Username != tt.requestedUsername { + assert.True(t, result.RequiresUserSwitching, "Should require user switching for different user") + } + }) + } +} diff --git a/client/ssh/server/userswitching_js.go b/client/ssh/server/userswitching_js.go new file mode 100644 index 000000000..333c19259 --- /dev/null +++ b/client/ssh/server/userswitching_js.go @@ -0,0 +1,8 @@ +//go:build js + +package server + +// enableUserSwitching is not supported on JS/WASM +func enableUserSwitching() error { + return errNotSupported +} diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go new file mode 100644 index 000000000..06fefabd7 --- /dev/null +++ b/client/ssh/server/userswitching_unix.go @@ -0,0 +1,233 @@ +//go:build unix + +package server + +import ( + "errors" + "fmt" + "net" + "net/netip" + "os" + "os/exec" + "os/user" + "regexp" + "runtime" + "strconv" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" +) + +// POSIX portable filename character set regex: [a-zA-Z0-9._-] +// First character cannot be hyphen (POSIX requirement) +var posixUsernameRegex = regexp.MustCompile(`^[a-zA-Z0-9._][a-zA-Z0-9._-]*$`) + +// validateUsername validates that a username conforms to POSIX standards with security considerations +func validateUsername(username string) error { + if username == "" { + return errors.New("username cannot be empty") + } + + // POSIX allows up to 256 characters, but practical limit is 32 for compatibility + if len(username) > 32 { + return errors.New("username too long (max 32 characters)") + } + + if !posixUsernameRegex.MatchString(username) { + return errors.New("username contains invalid characters (must match POSIX portable filename character set)") + } + + if username == "." || username == ".." { + return fmt.Errorf("username cannot be '.' or '..'") + } + + // Warn if username is fully numeric (can cause issues with UID/username ambiguity) + if isFullyNumeric(username) { + log.Warnf("fully numeric username '%s' may cause issues with some commands", username) + } + + return nil +} + +// isFullyNumeric checks if username contains only digits +func isFullyNumeric(username string) bool { + for _, char := range username { + if char < '0' || char > '9' { + return false + } + } + return true +} + +// createPtyLoginCommand creates a Pty command using login for privileged processes +func (s *Server) createPtyLoginCommand(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) { + loginPath, args, err := s.getLoginCmd(localUser.Username, session.RemoteAddr()) + if err != nil { + return nil, fmt.Errorf("get login command: %w", err) + } + + execCmd := exec.CommandContext(session.Context(), loginPath, args...) + execCmd.Dir = localUser.HomeDir + execCmd.Env = s.preparePtyEnv(localUser, ptyReq, session) + + return execCmd, nil +} + +// getLoginCmd returns the login command and args for privileged Pty user switching +func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []string, error) { + loginPath, err := exec.LookPath("login") + if err != nil { + return "", nil, fmt.Errorf("login command not available: %w", err) + } + + addrPort, err := netip.ParseAddrPort(remoteAddr.String()) + if err != nil { + return "", nil, fmt.Errorf("parse remote address: %w", err) + } + + switch runtime.GOOS { + case "linux": + // Special handling for Arch Linux without /etc/pam.d/remote + if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") { + return loginPath, []string{"-f", username, "-p"}, nil + } + return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil + case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly": + return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil + default: + return "", nil, fmt.Errorf("unsupported Unix platform for login command: %s", runtime.GOOS) + } +} + +// fileExists checks if a file exists (helper for login command logic) +func (s *Server) fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +// parseUserCredentials extracts numeric UID, GID, and supplementary groups +func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []uint32, error) { + uid64, err := strconv.ParseUint(localUser.Uid, 10, 32) + if err != nil { + return 0, 0, nil, fmt.Errorf("invalid UID %s: %w", localUser.Uid, err) + } + uid := uint32(uid64) + + gid64, err := strconv.ParseUint(localUser.Gid, 10, 32) + if err != nil { + return 0, 0, nil, fmt.Errorf("invalid GID %s: %w", localUser.Gid, err) + } + gid := uint32(gid64) + + groups, err := s.getSupplementaryGroups(localUser.Username) + if err != nil { + log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err) + groups = []uint32{gid} + } + + return uid, gid, groups, nil +} + +// getSupplementaryGroups retrieves supplementary group IDs for a user +func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) { + u, err := user.Lookup(username) + if err != nil { + return nil, fmt.Errorf("lookup user %s: %w", username, err) + } + + groupIDStrings, err := u.GroupIds() + if err != nil { + return nil, fmt.Errorf("get group IDs for user %s: %w", username, err) + } + + groups := make([]uint32, len(groupIDStrings)) + for i, gidStr := range groupIDStrings { + gid64, err := strconv.ParseUint(gidStr, 10, 32) + if err != nil { + return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, username, err) + } + groups[i] = uint32(gid64) + } + + return groups, nil +} + +// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping. +// Returns the command and a cleanup function (no-op on Unix). +func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { + log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty) + + if err := validateUsername(localUser.Username); err != nil { + return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err) + } + + uid, gid, groups, err := s.parseUserCredentials(localUser) + if err != nil { + return nil, nil, fmt.Errorf("parse user credentials: %w", err) + } + privilegeDropper := NewPrivilegeDropper() + config := ExecutorConfig{ + UID: uid, + GID: gid, + Groups: groups, + WorkingDir: localUser.HomeDir, + Shell: getUserShell(localUser.Uid), + Command: session.RawCommand(), + PTY: hasPty, + } + + cmd, err := privilegeDropper.CreateExecutorCommand(session.Context(), config) + return cmd, func() {}, err +} + +// enableUserSwitching is a no-op on Unix systems +func enableUserSwitching() error { + return nil +} + +// createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results +func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) { + localUser := privilegeResult.User + if localUser == nil { + return nil, errors.New("no user in privilege result") + } + + if privilegeResult.UsedFallback { + return s.createDirectPtyCommand(session, localUser, ptyReq), nil + } + + return s.createPtyLoginCommand(localUser, ptyReq, session) +} + +// createDirectPtyCommand creates a direct Pty command without privilege dropping +func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd { + log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username) + + shell := getUserShell(localUser.Uid) + args := s.getShellCommandArgs(shell, session.RawCommand()) + + cmd := exec.CommandContext(session.Context(), args[0], args[1:]...) + cmd.Dir = localUser.HomeDir + cmd.Env = s.preparePtyEnv(localUser, ptyReq, session) + + return cmd +} + +// preparePtyEnv prepares environment variables for Pty execution +func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) []string { + termType := ptyReq.Term + if termType == "" { + termType = "xterm-256color" + } + + env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) + env = append(env, prepareSSHEnv(session)...) + env = append(env, fmt.Sprintf("TERM=%s", termType)) + + for _, v := range session.Environ() { + if acceptEnv(v) { + env = append(env, v) + } + } + return env +} diff --git a/client/ssh/server/userswitching_windows.go b/client/ssh/server/userswitching_windows.go new file mode 100644 index 000000000..5a5f75fa4 --- /dev/null +++ b/client/ssh/server/userswitching_windows.go @@ -0,0 +1,274 @@ +//go:build windows + +package server + +import ( + "errors" + "fmt" + "os/exec" + "os/user" + "strings" + "unsafe" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +// validateUsername validates Windows usernames according to SAM Account Name rules +func validateUsername(username string) error { + if username == "" { + return fmt.Errorf("username cannot be empty") + } + + usernameToValidate := extractUsernameFromDomain(username) + + if err := validateUsernameLength(usernameToValidate); err != nil { + return err + } + + if err := validateUsernameCharacters(usernameToValidate); err != nil { + return err + } + + if err := validateUsernameFormat(usernameToValidate); err != nil { + return err + } + + return nil +} + +// extractUsernameFromDomain extracts the username part from domain\username or username@domain format +func extractUsernameFromDomain(username string) string { + if idx := strings.LastIndex(username, `\`); idx != -1 { + return username[idx+1:] + } + if idx := strings.Index(username, "@"); idx != -1 { + return username[:idx] + } + return username +} + +// validateUsernameLength checks if username length is within Windows limits +func validateUsernameLength(username string) error { + if len(username) > 20 { + return fmt.Errorf("username too long (max 20 characters for Windows)") + } + return nil +} + +// validateUsernameCharacters checks for invalid characters in Windows usernames +func validateUsernameCharacters(username string) error { + invalidChars := []rune{'"', '/', '[', ']', ':', ';', '|', '=', ',', '+', '*', '?', '<', '>', ' ', '`', '&', '\n'} + for _, char := range username { + for _, invalid := range invalidChars { + if char == invalid { + return fmt.Errorf("username contains invalid characters") + } + } + if char < 32 || char == 127 { + return fmt.Errorf("username contains control characters") + } + } + return nil +} + +// validateUsernameFormat checks for invalid username formats and patterns +func validateUsernameFormat(username string) error { + if username == "." || username == ".." { + return fmt.Errorf("username cannot be '.' or '..'") + } + + if strings.HasSuffix(username, ".") { + return fmt.Errorf("username cannot end with a period") + } + + return nil +} + +// createExecutorCommand creates a command using Windows executor for privilege dropping. +// Returns the command and a cleanup function that must be called after starting the process. +func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { + log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty) + + username, _ := s.parseUsername(localUser.Username) + if err := validateUsername(username); err != nil { + return nil, nil, fmt.Errorf("invalid username %q: %w", username, err) + } + + return s.createUserSwitchCommand(localUser, session, hasPty) +} + +// createUserSwitchCommand creates a command with Windows user switching. +// Returns the command and a cleanup function that must be called after starting the process. +func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) { + username, domain := s.parseUsername(localUser.Username) + + shell := getUserShell(localUser.Uid) + + rawCmd := session.RawCommand() + var command string + if rawCmd != "" { + command = rawCmd + } + + config := WindowsExecutorConfig{ + Username: username, + Domain: domain, + WorkingDir: localUser.HomeDir, + Shell: shell, + Command: command, + Interactive: interactive || (rawCmd == ""), + } + + dropper := NewPrivilegeDropper() + cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config) + if err != nil { + return nil, nil, err + } + + cleanup := func() { + if token != 0 { + if err := windows.CloseHandle(windows.Handle(token)); err != nil { + log.Debugf("close primary token: %v", err) + } + } + } + + return cmd, cleanup, nil +} + +// parseUsername extracts username and domain from a Windows username +func (s *Server) parseUsername(fullUsername string) (username, domain string) { + // Handle DOMAIN\username format + if idx := strings.LastIndex(fullUsername, `\`); idx != -1 { + domain = fullUsername[:idx] + username = fullUsername[idx+1:] + return username, domain + } + + // Handle username@domain format + if username, domain, ok := strings.Cut(fullUsername, "@"); ok { + return username, domain + } + + // Local user (no domain) + return fullUsername, "." +} + +// hasPrivilege checks if the current process has a specific privilege +func hasPrivilege(token windows.Handle, privilegeName string) (bool, error) { + var luid windows.LUID + if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil { + return false, fmt.Errorf("lookup privilege value: %w", err) + } + + var returnLength uint32 + err := windows.GetTokenInformation( + windows.Token(token), + windows.TokenPrivileges, + nil, // null buffer to get size + 0, + &returnLength, + ) + + if err != nil && !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) { + return false, fmt.Errorf("get token information size: %w", err) + } + + buffer := make([]byte, returnLength) + err = windows.GetTokenInformation( + windows.Token(token), + windows.TokenPrivileges, + &buffer[0], + returnLength, + &returnLength, + ) + if err != nil { + return false, fmt.Errorf("get token information: %w", err) + } + + privileges := (*windows.Tokenprivileges)(unsafe.Pointer(&buffer[0])) + + // Check if the privilege is present and enabled + for i := uint32(0); i < privileges.PrivilegeCount; i++ { + privilege := (*windows.LUIDAndAttributes)(unsafe.Pointer( + uintptr(unsafe.Pointer(&privileges.Privileges[0])) + + uintptr(i)*unsafe.Sizeof(windows.LUIDAndAttributes{}), + )) + if privilege.Luid == luid { + return (privilege.Attributes & windows.SE_PRIVILEGE_ENABLED) != 0, nil + } + } + + return false, nil +} + +// enablePrivilege enables a specific privilege for the current process token +// This is required because privileges like SeAssignPrimaryTokenPrivilege are present +// but disabled by default, even for the SYSTEM account +func enablePrivilege(token windows.Handle, privilegeName string) error { + var luid windows.LUID + if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil { + return fmt.Errorf("lookup privilege value for %s: %w", privilegeName, err) + } + + privileges := windows.Tokenprivileges{ + PrivilegeCount: 1, + Privileges: [1]windows.LUIDAndAttributes{ + { + Luid: luid, + Attributes: windows.SE_PRIVILEGE_ENABLED, + }, + }, + } + + err := windows.AdjustTokenPrivileges( + windows.Token(token), + false, + &privileges, + 0, + nil, + nil, + ) + if err != nil { + return fmt.Errorf("adjust token privileges for %s: %w", privilegeName, err) + } + + hasPriv, err := hasPrivilege(token, privilegeName) + if err != nil { + return fmt.Errorf("verify privilege %s after enabling: %w", privilegeName, err) + } + if !hasPriv { + return fmt.Errorf("privilege %s could not be enabled (may not be granted to account)", privilegeName) + } + + log.Debugf("Successfully enabled privilege %s for current process", privilegeName) + return nil +} + +// enableUserSwitching enables required privileges for Windows user switching +func enableUserSwitching() error { + process := windows.CurrentProcess() + + var token windows.Token + err := windows.OpenProcessToken( + process, + windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, + &token, + ) + if err != nil { + return fmt.Errorf("open process token: %w", err) + } + defer func() { + if err := windows.CloseHandle(windows.Handle(token)); err != nil { + log.Debugf("Failed to close process token: %v", err) + } + }() + + if err := enablePrivilege(windows.Handle(token), "SeAssignPrimaryTokenPrivilege"); err != nil { + return fmt.Errorf("enable SeAssignPrimaryTokenPrivilege: %w", err) + } + log.Infof("Windows user switching privileges enabled successfully") + return nil +} diff --git a/client/ssh/server/winpty/conpty.go b/client/ssh/server/winpty/conpty.go new file mode 100644 index 000000000..0f3659ffe --- /dev/null +++ b/client/ssh/server/winpty/conpty.go @@ -0,0 +1,487 @@ +//go:build windows + +package winpty + +import ( + "context" + "errors" + "fmt" + "io" + "strings" + "sync" + "syscall" + "unsafe" + + "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +var ( + ErrEmptyEnvironment = errors.New("empty environment") +) + +const ( + extendedStartupInfoPresent = 0x00080000 + createUnicodeEnvironment = 0x00000400 + procThreadAttributePseudoConsole = 0x00020016 + + PowerShellCommandFlag = "-Command" + + errCloseInputRead = "close input read handle: %v" + errCloseConPtyCleanup = "close ConPty handle during cleanup" +) + +// PtyConfig holds configuration for Pty execution. +type PtyConfig struct { + Shell string + Command string + Width int + Height int + WorkingDir string +} + +// UserConfig holds user execution configuration. +type UserConfig struct { + Token windows.Handle + Environment []string +} + +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole") + procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList") + procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute") + procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList") +) + +// ExecutePtyWithUserToken executes a command with ConPty using user token. +func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error { + args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command) + commandLine := buildCommandLine(args) + + config := ExecutionConfig{ + Pty: ptyConfig, + User: userConfig, + Session: session, + Context: ctx, + } + + return executeConPtyWithConfig(commandLine, config) +} + +// ExecutionConfig holds all execution configuration. +type ExecutionConfig struct { + Pty PtyConfig + User UserConfig + Session ssh.Session + Context context.Context +} + +// executeConPtyWithConfig creates ConPty and executes process with configuration. +func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error { + ctx := config.Context + session := config.Session + width := config.Pty.Width + height := config.Pty.Height + userToken := config.User.Token + userEnv := config.User.Environment + workingDir := config.Pty.WorkingDir + + inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes() + if err != nil { + return fmt.Errorf("create ConPty pipes: %w", err) + } + + hPty, err := createConPty(width, height, inputRead, outputWrite) + if err != nil { + return fmt.Errorf("create ConPty: %w", err) + } + + primaryToken, err := duplicateToPrimaryToken(userToken) + if err != nil { + if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 { + log.Debugf(errCloseConPtyCleanup) + } + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + return fmt.Errorf("duplicate to primary token: %w", err) + } + defer func() { + if err := windows.CloseHandle(primaryToken); err != nil { + log.Debugf("close primary token: %v", err) + } + }() + + siEx, err := setupConPtyStartupInfo(hPty) + if err != nil { + if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 { + log.Debugf(errCloseConPtyCleanup) + } + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + return fmt.Errorf("setup startup info: %w", err) + } + defer func() { + _, _, _ = procDeleteProcThreadAttributeList.Call(uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList))) + }() + + pi, err := createConPtyProcess(commandLine, primaryToken, userEnv, workingDir, siEx) + if err != nil { + if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 { + log.Debugf(errCloseConPtyCleanup) + } + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + return fmt.Errorf("create process as user with ConPty: %w", err) + } + defer closeProcessInfo(pi) + + if err := windows.CloseHandle(inputRead); err != nil { + log.Debugf(errCloseInputRead, err) + } + if err := windows.CloseHandle(outputWrite); err != nil { + log.Debugf("close output write handle: %v", err) + } + + return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, session, pi.Process) +} + +// createConPtyPipes creates input/output pipes for ConPty. +func createConPtyPipes() (inputRead, inputWrite, outputRead, outputWrite windows.Handle, err error) { + if err := windows.CreatePipe(&inputRead, &inputWrite, nil, 0); err != nil { + return 0, 0, 0, 0, fmt.Errorf("create input pipe: %w", err) + } + + if err := windows.CreatePipe(&outputRead, &outputWrite, nil, 0); err != nil { + if closeErr := windows.CloseHandle(inputRead); closeErr != nil { + log.Debugf(errCloseInputRead, closeErr) + } + if closeErr := windows.CloseHandle(inputWrite); closeErr != nil { + log.Debugf("close input write handle: %v", closeErr) + } + return 0, 0, 0, 0, fmt.Errorf("create output pipe: %w", err) + } + + return inputRead, inputWrite, outputRead, outputWrite, nil +} + +// createConPty creates a Windows ConPty with the specified size and pipe handles. +func createConPty(width, height int, inputRead, outputWrite windows.Handle) (windows.Handle, error) { + size := windows.Coord{X: int16(width), Y: int16(height)} + + var hPty windows.Handle + if err := windows.CreatePseudoConsole(size, inputRead, outputWrite, 0, &hPty); err != nil { + return 0, fmt.Errorf("CreatePseudoConsole: %w", err) + } + + return hPty, nil +} + +// setupConPtyStartupInfo prepares the STARTUPINFOEX with ConPty attributes. +func setupConPtyStartupInfo(hPty windows.Handle) (*windows.StartupInfoEx, error) { + var siEx windows.StartupInfoEx + siEx.StartupInfo.Cb = uint32(unsafe.Sizeof(siEx)) + + var attrListSize uintptr + ret, _, _ := procInitializeProcThreadAttributeList.Call(0, 1, 0, uintptr(unsafe.Pointer(&attrListSize))) + if ret == 0 && attrListSize == 0 { + return nil, fmt.Errorf("get attribute list size") + } + + attrListBytes := make([]byte, attrListSize) + siEx.ProcThreadAttributeList = (*windows.ProcThreadAttributeList)(unsafe.Pointer(&attrListBytes[0])) + + ret, _, err := procInitializeProcThreadAttributeList.Call( + uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)), + 1, + 0, + uintptr(unsafe.Pointer(&attrListSize)), + ) + if ret == 0 { + return nil, fmt.Errorf("initialize attribute list: %w", err) + } + + ret, _, err = procUpdateProcThreadAttribute.Call( + uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)), + 0, + procThreadAttributePseudoConsole, + uintptr(hPty), + unsafe.Sizeof(hPty), + 0, + 0, + ) + if ret == 0 { + return nil, fmt.Errorf("update thread attribute: %w", err) + } + + return &siEx, nil +} + +// createConPtyProcess creates the actual process with ConPty. +func createConPtyProcess(commandLine string, userToken windows.Handle, userEnv []string, workingDir string, siEx *windows.StartupInfoEx) (*windows.ProcessInformation, error) { + var pi windows.ProcessInformation + creationFlags := uint32(extendedStartupInfoPresent | createUnicodeEnvironment) + + commandLinePtr, err := windows.UTF16PtrFromString(commandLine) + if err != nil { + return nil, fmt.Errorf("convert command line to UTF16: %w", err) + } + + envPtr, err := convertEnvironmentToUTF16(userEnv) + if err != nil { + return nil, err + } + + var workingDirPtr *uint16 + if workingDir != "" { + workingDirPtr, err = windows.UTF16PtrFromString(workingDir) + if err != nil { + return nil, fmt.Errorf("convert working directory to UTF16: %w", err) + } + } + + siEx.StartupInfo.Flags |= windows.STARTF_USESTDHANDLES + siEx.StartupInfo.StdInput = windows.Handle(0) + siEx.StartupInfo.StdOutput = windows.Handle(0) + siEx.StartupInfo.StdErr = siEx.StartupInfo.StdOutput + + if userToken != windows.InvalidHandle { + err = windows.CreateProcessAsUser( + windows.Token(userToken), + nil, + commandLinePtr, + nil, + nil, + true, + creationFlags, + envPtr, + workingDirPtr, + &siEx.StartupInfo, + &pi, + ) + } else { + err = windows.CreateProcess( + nil, + commandLinePtr, + nil, + nil, + true, + creationFlags, + envPtr, + workingDirPtr, + &siEx.StartupInfo, + &pi, + ) + } + + if err != nil { + return nil, fmt.Errorf("create process: %w", err) + } + + return &pi, nil +} + +// convertEnvironmentToUTF16 converts environment variables to Windows UTF16 format. +func convertEnvironmentToUTF16(userEnv []string) (*uint16, error) { + if len(userEnv) == 0 { + // Return nil pointer for empty environment - Windows API will inherit parent environment + return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment + } + + var envUTF16 []uint16 + for _, envVar := range userEnv { + if envVar != "" { + utf16Str, err := windows.UTF16FromString(envVar) + if err != nil { + log.Debugf("skipping invalid environment variable: %s (error: %v)", envVar, err) + continue + } + envUTF16 = append(envUTF16, utf16Str[:len(utf16Str)-1]...) + envUTF16 = append(envUTF16, 0) + } + } + envUTF16 = append(envUTF16, 0) + + if len(envUTF16) > 0 { + return &envUTF16[0], nil + } + // Return nil pointer when no valid environment variables found + return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment +} + +// duplicateToPrimaryToken converts an impersonation token to a primary token. +func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) { + var primaryToken windows.Handle + if err := windows.DuplicateTokenEx( + windows.Token(token), + windows.TOKEN_ALL_ACCESS, + nil, + windows.SecurityImpersonation, + windows.TokenPrimary, + (*windows.Token)(&primaryToken), + ); err != nil { + return 0, fmt.Errorf("duplicate token: %w", err) + } + return primaryToken, nil +} + +// SessionExiter provides the Exit method for reporting process exit status. +type SessionExiter interface { + Exit(code int) error +} + +// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers. +func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, session SessionExiter, process windows.Handle) error { + if err := ctx.Err(); err != nil { + return err + } + + var wg sync.WaitGroup + startIOBridging(ctx, &wg, inputWrite, outputRead, reader, writer) + + processErr := waitForProcess(ctx, process) + if processErr != nil { + return processErr + } + + var exitCode uint32 + if err := windows.GetExitCodeProcess(process, &exitCode); err != nil { + log.Debugf("get exit code: %v", err) + } else { + if err := session.Exit(int(exitCode)); err != nil { + log.Debugf("report exit code: %v", err) + } + } + + // Clean up in the original order after process completes + if err := reader.Close(); err != nil { + log.Debugf("close reader: %v", err) + } + + ret, _, err := procClosePseudoConsole.Call(uintptr(hPty)) + if ret == 0 { + log.Debugf("close ConPty handle: %v", err) + } + + wg.Wait() + + if err := windows.CloseHandle(outputRead); err != nil { + log.Debugf("close output read handle: %v", err) + } + + return nil +} + +// startIOBridging starts the I/O bridging goroutines. +func startIOBridging(ctx context.Context, wg *sync.WaitGroup, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer) { + wg.Add(2) + + // Input: reader (SSH session) -> inputWrite (ConPty) + go func() { + defer wg.Done() + defer func() { + if err := windows.CloseHandle(inputWrite); err != nil { + log.Debugf("close input write handle in goroutine: %v", err) + } + }() + + if _, err := io.Copy(&windowsHandleWriter{handle: inputWrite}, reader); err != nil { + log.Debugf("input copy ended with error: %v", err) + } + }() + + // Output: outputRead (ConPty) -> writer (SSH session) + go func() { + defer wg.Done() + if _, err := io.Copy(writer, &windowsHandleReader{handle: outputRead}); err != nil { + log.Debugf("output copy ended with error: %v", err) + } + }() +} + +// waitForProcess waits for process completion with context cancellation. +func waitForProcess(ctx context.Context, process windows.Handle) error { + if _, err := windows.WaitForSingleObject(process, windows.INFINITE); err != nil { + return fmt.Errorf("wait for process %d: %w", process, err) + } + return nil +} + +// buildShellArgs builds shell arguments for ConPty execution. +func buildShellArgs(shell, command string) []string { + if command != "" { + return []string{shell, PowerShellCommandFlag, command} + } + return []string{shell} +} + +// buildCommandLine builds a Windows command line from arguments using proper escaping. +func buildCommandLine(args []string) string { + if len(args) == 0 { + return "" + } + + var result strings.Builder + for i, arg := range args { + if i > 0 { + result.WriteString(" ") + } + result.WriteString(syscall.EscapeArg(arg)) + } + return result.String() +} + +// closeHandles closes multiple Windows handles. +func closeHandles(handles ...windows.Handle) { + for _, handle := range handles { + if handle != windows.InvalidHandle { + if err := windows.CloseHandle(handle); err != nil { + log.Debugf("close handle: %v", err) + } + } + } +} + +// closeProcessInfo closes process and thread handles. +func closeProcessInfo(pi *windows.ProcessInformation) { + if pi != nil { + if err := windows.CloseHandle(pi.Process); err != nil { + log.Debugf("close process handle: %v", err) + } + if err := windows.CloseHandle(pi.Thread); err != nil { + log.Debugf("close thread handle: %v", err) + } + } +} + +// windowsHandleReader wraps a Windows handle for reading. +type windowsHandleReader struct { + handle windows.Handle +} + +func (r *windowsHandleReader) Read(p []byte) (n int, err error) { + var bytesRead uint32 + if err := windows.ReadFile(r.handle, p, &bytesRead, nil); err != nil { + return 0, err + } + return int(bytesRead), nil +} + +func (r *windowsHandleReader) Close() error { + return windows.CloseHandle(r.handle) +} + +// windowsHandleWriter wraps a Windows handle for writing. +type windowsHandleWriter struct { + handle windows.Handle +} + +func (w *windowsHandleWriter) Write(p []byte) (n int, err error) { + var bytesWritten uint32 + if err := windows.WriteFile(w.handle, p, &bytesWritten, nil); err != nil { + return 0, err + } + return int(bytesWritten), nil +} + +func (w *windowsHandleWriter) Close() error { + return windows.CloseHandle(w.handle) +} diff --git a/client/ssh/server/winpty/conpty_test.go b/client/ssh/server/winpty/conpty_test.go new file mode 100644 index 000000000..4f04e1fad --- /dev/null +++ b/client/ssh/server/winpty/conpty_test.go @@ -0,0 +1,290 @@ +//go:build windows + +package winpty + +import ( + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows" +) + +func TestBuildShellArgs(t *testing.T) { + tests := []struct { + name string + shell string + command string + expected []string + }{ + { + name: "Shell with command", + shell: "powershell.exe", + command: "Get-Process", + expected: []string{"powershell.exe", "-Command", "Get-Process"}, + }, + { + name: "CMD with command", + shell: "cmd.exe", + command: "dir", + expected: []string{"cmd.exe", "-Command", "dir"}, + }, + { + name: "Shell interactive", + shell: "powershell.exe", + command: "", + expected: []string{"powershell.exe"}, + }, + { + name: "CMD interactive", + shell: "cmd.exe", + command: "", + expected: []string{"cmd.exe"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildShellArgs(tt.shell, tt.command) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBuildCommandLine(t *testing.T) { + tests := []struct { + name string + args []string + expected string + }{ + { + name: "Simple args", + args: []string{"cmd.exe", "/c", "echo"}, + expected: "cmd.exe /c echo", + }, + { + name: "Args with spaces", + args: []string{"Program Files\\app.exe", "arg with spaces"}, + expected: `"Program Files\app.exe" "arg with spaces"`, + }, + { + name: "Args with quotes", + args: []string{"cmd.exe", "/c", `echo "hello world"`}, + expected: `cmd.exe /c "echo \"hello world\""`, + }, + { + name: "PowerShell calling PowerShell", + args: []string{"powershell.exe", "-Command", `powershell.exe -Command "Get-Process | Where-Object {$_.Name -eq 'notepad'}"`}, + expected: `powershell.exe -Command "powershell.exe -Command \"Get-Process | Where-Object {$_.Name -eq 'notepad'}\""`, + }, + { + name: "Complex nested quotes", + args: []string{"cmd.exe", "/c", `echo "He said \"Hello\" to me"`}, + expected: `cmd.exe /c "echo \"He said \\\"Hello\\\" to me\""`, + }, + { + name: "Path with spaces and args", + args: []string{`C:\Program Files\MyApp\app.exe`, "--config", `C:\My Config\settings.json`}, + expected: `"C:\Program Files\MyApp\app.exe" --config "C:\My Config\settings.json"`, + }, + { + name: "Empty argument", + args: []string{"cmd.exe", "/c", "echo", ""}, + expected: `cmd.exe /c echo ""`, + }, + { + name: "Argument with backslashes", + args: []string{"robocopy", `C:\Source\`, `C:\Dest\`, "/E"}, + expected: `robocopy C:\Source\ C:\Dest\ /E`, + }, + { + name: "Empty args", + args: []string{}, + expected: "", + }, + { + name: "Single arg with space", + args: []string{"path with spaces"}, + expected: `"path with spaces"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildCommandLine(tt.args) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCreateConPtyPipes(t *testing.T) { + inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes() + require.NoError(t, err, "Should create ConPty pipes successfully") + + // Verify all handles are valid + assert.NotEqual(t, windows.InvalidHandle, inputRead, "Input read handle should be valid") + assert.NotEqual(t, windows.InvalidHandle, inputWrite, "Input write handle should be valid") + assert.NotEqual(t, windows.InvalidHandle, outputRead, "Output read handle should be valid") + assert.NotEqual(t, windows.InvalidHandle, outputWrite, "Output write handle should be valid") + + // Clean up handles + closeHandles(inputRead, inputWrite, outputRead, outputWrite) +} + +func TestCreateConPty(t *testing.T) { + inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes() + require.NoError(t, err, "Should create ConPty pipes successfully") + defer closeHandles(inputRead, inputWrite, outputRead, outputWrite) + + hPty, err := createConPty(80, 24, inputRead, outputWrite) + require.NoError(t, err, "Should create ConPty successfully") + assert.NotEqual(t, windows.InvalidHandle, hPty, "ConPty handle should be valid") + + // Clean up ConPty + ret, _, _ := procClosePseudoConsole.Call(uintptr(hPty)) + assert.NotEqual(t, uintptr(0), ret, "Should close ConPty successfully") +} + +func TestConvertEnvironmentToUTF16(t *testing.T) { + tests := []struct { + name string + userEnv []string + hasError bool + }{ + { + name: "Valid environment variables", + userEnv: []string{"PATH=C:\\Windows", "USER=testuser", "HOME=C:\\Users\\testuser"}, + hasError: false, + }, + { + name: "Empty environment", + userEnv: []string{}, + hasError: false, + }, + { + name: "Environment with empty strings", + userEnv: []string{"PATH=C:\\Windows", "", "USER=testuser"}, + hasError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := convertEnvironmentToUTF16(tt.userEnv) + if tt.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if len(tt.userEnv) == 0 { + assert.Nil(t, result, "Empty environment should return nil") + } else { + assert.NotNil(t, result, "Non-empty environment should return valid pointer") + } + } + }) + } +} + +func TestDuplicateToPrimaryToken(t *testing.T) { + if testing.Short() { + t.Skip("Skipping token tests in short mode") + } + + // Get current process token for testing + var token windows.Token + err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_ALL_ACCESS, &token) + require.NoError(t, err, "Should open current process token") + defer func() { + if err := windows.CloseHandle(windows.Handle(token)); err != nil { + t.Logf("Failed to close token: %v", err) + } + }() + + primaryToken, err := duplicateToPrimaryToken(windows.Handle(token)) + require.NoError(t, err, "Should duplicate token to primary") + assert.NotEqual(t, windows.InvalidHandle, primaryToken, "Primary token should be valid") + + // Clean up + err = windows.CloseHandle(primaryToken) + assert.NoError(t, err, "Should close primary token") +} + +func TestWindowsHandleReader(t *testing.T) { + // Create a pipe for testing + var readHandle, writeHandle windows.Handle + err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0) + require.NoError(t, err, "Should create pipe for testing") + defer closeHandles(readHandle, writeHandle) + + // Write test data + testData := []byte("Hello, Windows Handle Reader!") + var bytesWritten uint32 + err = windows.WriteFile(writeHandle, testData, &bytesWritten, nil) + require.NoError(t, err, "Should write test data") + require.Equal(t, uint32(len(testData)), bytesWritten, "Should write all test data") + + // Close write handle to signal EOF + if err := windows.CloseHandle(writeHandle); err != nil { + t.Fatalf("Should close write handle: %v", err) + } + writeHandle = windows.InvalidHandle + + // Test reading + reader := &windowsHandleReader{handle: readHandle} + buffer := make([]byte, len(testData)) + n, err := reader.Read(buffer) + require.NoError(t, err, "Should read from handle") + assert.Equal(t, len(testData), n, "Should read expected number of bytes") + assert.Equal(t, testData, buffer, "Should read expected data") +} + +func TestWindowsHandleWriter(t *testing.T) { + // Create a pipe for testing + var readHandle, writeHandle windows.Handle + err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0) + require.NoError(t, err, "Should create pipe for testing") + defer closeHandles(readHandle, writeHandle) + + // Test writing + testData := []byte("Hello, Windows Handle Writer!") + writer := &windowsHandleWriter{handle: writeHandle} + n, err := writer.Write(testData) + require.NoError(t, err, "Should write to handle") + assert.Equal(t, len(testData), n, "Should write expected number of bytes") + + // Close write handle + if err := windows.CloseHandle(writeHandle); err != nil { + t.Fatalf("Should close write handle: %v", err) + } + + // Verify data was written by reading it back + buffer := make([]byte, len(testData)) + var bytesRead uint32 + err = windows.ReadFile(readHandle, buffer, &bytesRead, nil) + require.NoError(t, err, "Should read back written data") + assert.Equal(t, uint32(len(testData)), bytesRead, "Should read back expected number of bytes") + assert.Equal(t, testData, buffer, "Should read back expected data") +} + +// BenchmarkConPtyCreation benchmarks ConPty creation performance +func BenchmarkConPtyCreation(b *testing.B) { + for i := 0; i < b.N; i++ { + inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes() + if err != nil { + b.Fatal(err) + } + + hPty, err := createConPty(80, 24, inputRead, outputWrite) + if err != nil { + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + b.Fatal(err) + } + + // Clean up + if ret, _, err := procClosePseudoConsole.Call(uintptr(hPty)); ret == 0 { + log.Debugf("ClosePseudoConsole failed: %v", err) + } + closeHandles(inputRead, inputWrite, outputRead, outputWrite) + } +} diff --git a/client/ssh/server_mock.go b/client/ssh/server_mock.go deleted file mode 100644 index 76f43fd4e..000000000 --- a/client/ssh/server_mock.go +++ /dev/null @@ -1,46 +0,0 @@ -//go:build !js - -package ssh - -import "context" - -// MockServer mocks ssh.Server -type MockServer struct { - Ctx context.Context - StopFunc func() error - StartFunc func() error - AddAuthorizedKeyFunc func(peer, newKey string) error - RemoveAuthorizedKeyFunc func(peer string) -} - -// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys -func (srv *MockServer) RemoveAuthorizedKey(peer string) { - if srv.RemoveAuthorizedKeyFunc == nil { - return - } - srv.RemoveAuthorizedKeyFunc(peer) -} - -// AddAuthorizedKey add a given peer key to server authorized keys -func (srv *MockServer) AddAuthorizedKey(peer, newKey string) error { - if srv.AddAuthorizedKeyFunc == nil { - return nil - } - return srv.AddAuthorizedKeyFunc(peer, newKey) -} - -// Stop stops SSH server. -func (srv *MockServer) Stop() error { - if srv.StopFunc == nil { - return nil - } - return srv.StopFunc() -} - -// Start starts SSH server. Blocking -func (srv *MockServer) Start() error { - if srv.StartFunc == nil { - return nil - } - return srv.StartFunc() -} diff --git a/client/ssh/server_test.go b/client/ssh/server_test.go deleted file mode 100644 index 1f310c2bb..000000000 --- a/client/ssh/server_test.go +++ /dev/null @@ -1,123 +0,0 @@ -//go:build !js - -package ssh - -import ( - "fmt" - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/ssh" - "strings" - "testing" -) - -func TestServer_AddAuthorizedKey(t *testing.T) { - key, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - server, err := newDefaultServer(key, "localhost:") - if err != nil { - t.Fatal(err) - } - - // add multiple keys - keys := map[string][]byte{} - for i := 0; i < 10; i++ { - peer := fmt.Sprintf("%s-%d", "remotePeer", i) - remotePrivKey, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - remotePubKey, err := GeneratePublicKey(remotePrivKey) - if err != nil { - t.Fatal(err) - } - - err = server.AddAuthorizedKey(peer, string(remotePubKey)) - if err != nil { - t.Error(err) - } - keys[peer] = remotePubKey - } - - // make sure that all keys have been added - for peer, remotePubKey := range keys { - k, ok := server.authorizedKeys[peer] - assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys") - - assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(ssh.MarshalAuthorizedKey(k)))) - } - -} - -func TestServer_RemoveAuthorizedKey(t *testing.T) { - key, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - server, err := newDefaultServer(key, "localhost:") - if err != nil { - t.Fatal(err) - } - - remotePrivKey, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - remotePubKey, err := GeneratePublicKey(remotePrivKey) - if err != nil { - t.Fatal(err) - } - - err = server.AddAuthorizedKey("remotePeer", string(remotePubKey)) - if err != nil { - t.Error(err) - } - - server.RemoveAuthorizedKey("remotePeer") - - _, ok := server.authorizedKeys["remotePeer"] - assert.False(t, ok, "expecting remotePeer's SSH key to be removed") -} - -func TestServer_PubKeyHandler(t *testing.T) { - key, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - server, err := newDefaultServer(key, "localhost:") - if err != nil { - t.Fatal(err) - } - - var keys []ssh.PublicKey - for i := 0; i < 10; i++ { - peer := fmt.Sprintf("%s-%d", "remotePeer", i) - remotePrivKey, err := GeneratePrivateKey(ED25519) - if err != nil { - t.Fatal(err) - } - remotePubKey, err := GeneratePublicKey(remotePrivKey) - if err != nil { - t.Fatal(err) - } - - remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey) - if err != nil { - t.Fatal(err) - } - - err = server.AddAuthorizedKey(peer, string(remotePubKey)) - if err != nil { - t.Error(err) - } - keys = append(keys, remoteParsedPubKey) - } - - for _, key := range keys { - accepted := server.publicKeyHandler(nil, key) - - assert.Truef(t, accepted, "expecting SSH connection to be accepted for a given SSH key %s", string(ssh.MarshalAuthorizedKey(key))) - } - -} diff --git a/client/ssh/util.go b/client/ssh/ssh.go similarity index 86% rename from client/ssh/util.go rename to client/ssh/ssh.go index a54a609bc..c0024c599 100644 --- a/client/ssh/util.go +++ b/client/ssh/ssh.go @@ -32,9 +32,8 @@ const RSA KeyType = "rsa" // RSAKeySize is a size of newly generated RSA key const RSAKeySize = 2048 -// GeneratePrivateKey creates RSA Private Key of specified byte size +// GeneratePrivateKey creates a private key of the specified type. func GeneratePrivateKey(keyType KeyType) ([]byte, error) { - var key crypto.Signer var err error switch keyType { @@ -59,7 +58,7 @@ func GeneratePrivateKey(keyType KeyType) ([]byte, error) { return pemBytes, nil } -// GeneratePublicKey returns the public part of the private key +// GeneratePublicKey returns the public part of the private key. func GeneratePublicKey(key []byte) ([]byte, error) { signer, err := gossh.ParsePrivateKey(key) if err != nil { @@ -70,20 +69,17 @@ func GeneratePublicKey(key []byte) ([]byte, error) { return []byte(strKey), nil } -// EncodePrivateKeyToPEM encodes Private Key from RSA to PEM format +// EncodePrivateKeyToPEM encodes a private key to PEM format. func EncodePrivateKeyToPEM(privateKey crypto.Signer) ([]byte, error) { mk, err := x509.MarshalPKCS8PrivateKey(privateKey) if err != nil { return nil, err } - // pem.Block privBlock := pem.Block{ Type: "PRIVATE KEY", Bytes: mk, } - - // Private key in PEM format privatePEM := pem.EncodeToMemory(&privBlock) return privatePEM, nil } diff --git a/client/ssh/testutil/user_helpers.go b/client/ssh/testutil/user_helpers.go new file mode 100644 index 000000000..0c1222078 --- /dev/null +++ b/client/ssh/testutil/user_helpers.go @@ -0,0 +1,172 @@ +package testutil + +import ( + "fmt" + "log" + "os" + "os/exec" + "os/user" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +var testCreatedUsers = make(map[string]bool) +var testUsersToCleanup []string + +// GetTestUsername returns an appropriate username for testing +func GetTestUsername(t *testing.T) string { + if runtime.GOOS == "windows" { + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + if IsSystemAccount(currentUser.Username) { + if IsCI() { + if testUser := GetOrCreateTestUser(t); testUser != "" { + return testUser + } + } else { + if _, err := user.Lookup("Administrator"); err == nil { + return "Administrator" + } + if testUser := GetOrCreateTestUser(t); testUser != "" { + return testUser + } + } + } + return currentUser.Username + } + + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + return currentUser.Username +} + +// IsCI checks if we're running in a CI environment +func IsCI() bool { + if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" { + return true + } + + hostname, err := os.Hostname() + if err == nil && strings.HasPrefix(hostname, "runner") { + return true + } + + return false +} + +// IsSystemAccount checks if the user is a system account that can't authenticate +func IsSystemAccount(username string) bool { + systemAccounts := []string{ + "system", + "NT AUTHORITY\\SYSTEM", + "NT AUTHORITY\\LOCAL SERVICE", + "NT AUTHORITY\\NETWORK SERVICE", + } + + for _, sysAccount := range systemAccounts { + if strings.EqualFold(username, sysAccount) { + return true + } + } + return false +} + +// RegisterTestUserCleanup registers a test user for cleanup +func RegisterTestUserCleanup(username string) { + if !testCreatedUsers[username] { + testCreatedUsers[username] = true + testUsersToCleanup = append(testUsersToCleanup, username) + } +} + +// CleanupTestUsers removes all created test users +func CleanupTestUsers() { + for _, username := range testUsersToCleanup { + RemoveWindowsTestUser(username) + } + testUsersToCleanup = nil + testCreatedUsers = make(map[string]bool) +} + +// GetOrCreateTestUser creates a test user on Windows if needed +func GetOrCreateTestUser(t *testing.T) string { + testUsername := "netbird-test-user" + + if _, err := user.Lookup(testUsername); err == nil { + return testUsername + } + + if CreateWindowsTestUser(t, testUsername) { + RegisterTestUserCleanup(testUsername) + return testUsername + } + + return "" +} + +// RemoveWindowsTestUser removes a local user on Windows using PowerShell +func RemoveWindowsTestUser(username string) { + if runtime.GOOS != "windows" { + return + } + + psCmd := fmt.Sprintf(` + try { + Remove-LocalUser -Name "%s" -ErrorAction Stop + Write-Output "User removed successfully" + } catch { + if ($_.Exception.Message -like "*cannot be found*") { + Write-Output "User not found (already removed)" + } else { + Write-Error $_.Exception.Message + } + } + `, username) + + cmd := exec.Command("powershell", "-Command", psCmd) + output, err := cmd.CombinedOutput() + + if err != nil { + log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output)) + } else { + log.Printf("Test user %s cleanup result: %s", username, string(output)) + } +} + +// CreateWindowsTestUser creates a local user on Windows using PowerShell +func CreateWindowsTestUser(t *testing.T, username string) bool { + if runtime.GOOS != "windows" { + return false + } + + psCmd := fmt.Sprintf(` + try { + $password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force + New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires + Add-LocalGroupMember -Group "Users" -Member "%s" + Write-Output "User created successfully" + } catch { + if ($_.Exception.Message -like "*already exists*") { + Write-Output "User already exists" + } else { + Write-Error $_.Exception.Message + exit 1 + } + } + `, username, username) + + cmd := exec.Command("powershell", "-Command", psCmd) + output, err := cmd.CombinedOutput() + + if err != nil { + t.Logf("Failed to create test user: %v, output: %s", err, string(output)) + return false + } + + t.Logf("Test user creation result: %s", string(output)) + return true +} diff --git a/client/ssh/window_freebsd.go b/client/ssh/window_freebsd.go deleted file mode 100644 index ef4848341..000000000 --- a/client/ssh/window_freebsd.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build freebsd - -package ssh - -import ( - "os" -) - -func setWinSize(file *os.File, width, height int) { -} diff --git a/client/ssh/window_unix.go b/client/ssh/window_unix.go deleted file mode 100644 index 2891eb70e..000000000 --- a/client/ssh/window_unix.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build linux || darwin - -package ssh - -import ( - "os" - "syscall" - "unsafe" -) - -func setWinSize(file *os.File, width, height int) { - syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(syscall.TIOCSWINSZ), //nolint - uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(height), uint16(width), 0, 0}))) -} diff --git a/client/ssh/window_windows.go b/client/ssh/window_windows.go deleted file mode 100644 index 5abd41f27..000000000 --- a/client/ssh/window_windows.go +++ /dev/null @@ -1,9 +0,0 @@ -package ssh - -import ( - "os" -) - -func setWinSize(file *os.File, width, height int) { - -} diff --git a/client/status/status.go b/client/status/status.go index 8a0b7bae0..d975f0e29 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -81,6 +81,18 @@ type NsServerGroupStateOutput struct { Error string `json:"error" yaml:"error"` } +type SSHSessionOutput struct { + Username string `json:"username" yaml:"username"` + RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"` + Command string `json:"command" yaml:"command"` + JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"` +} + +type SSHServerStateOutput struct { + Enabled bool `json:"enabled" yaml:"enabled"` + Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"` +} + type OutputOverview struct { Peers PeersStateOutput `json:"peers" yaml:"peers"` CliVersion string `json:"cliVersion" yaml:"cliVersion"` @@ -100,6 +112,7 @@ type OutputOverview struct { Events []SystemEventOutput `json:"events" yaml:"events"` LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"` ProfileName string `json:"profileName" yaml:"profileName"` + SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"` } func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview { @@ -121,6 +134,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status relayOverview := mapRelays(pbFullStatus.GetRelays()) peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) + sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState()) overview := OutputOverview{ Peers: peersOverview, @@ -141,6 +155,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status Events: mapEvents(pbFullStatus.GetEvents()), LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(), ProfileName: profName, + SSHServerState: sshServerOverview, } if anon { @@ -190,6 +205,30 @@ func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput { return mappedNSGroups } +func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput { + if sshServerState == nil { + return SSHServerStateOutput{ + Enabled: false, + Sessions: []SSHSessionOutput{}, + } + } + + sessions := make([]SSHSessionOutput, 0, len(sshServerState.GetSessions())) + for _, session := range sshServerState.GetSessions() { + sessions = append(sessions, SSHSessionOutput{ + Username: session.GetUsername(), + RemoteAddress: session.GetRemoteAddress(), + Command: session.GetCommand(), + JWTUsername: session.GetJwtUsername(), + }) + } + + return SSHServerStateOutput{ + Enabled: sshServerState.GetEnabled(), + Sessions: sessions, + } +} + func mapPeers( peers []*proto.PeerState, statusFilter string, @@ -300,7 +339,7 @@ func ParseToYAML(overview OutputOverview) (string, error) { return string(yamlBytes), nil } -func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool) string { +func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string { var managementConnString string if overview.ManagementState.Connected { managementConnString = "Connected" @@ -405,6 +444,41 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, lazyConnectionEnabledStatus = "true" } + sshServerStatus := "Disabled" + if overview.SSHServerState.Enabled { + sessionCount := len(overview.SSHServerState.Sessions) + if sessionCount > 0 { + sessionWord := "session" + if sessionCount > 1 { + sessionWord = "sessions" + } + sshServerStatus = fmt.Sprintf("Enabled (%d active %s)", sessionCount, sessionWord) + } else { + sshServerStatus = "Enabled" + } + + if showSSHSessions && sessionCount > 0 { + for _, session := range overview.SSHServerState.Sessions { + var sessionDisplay string + if session.JWTUsername != "" { + sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s", + session.JWTUsername, + session.RemoteAddress, + session.Username, + session.Command, + ) + } else { + sessionDisplay = fmt.Sprintf("[%s@%s] %s", + session.Username, + session.RemoteAddress, + session.Command, + ) + } + sshServerStatus += "\n " + sessionDisplay + } + } + } + peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) goos := runtime.GOOS @@ -428,6 +502,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, "Interface type: %s\n"+ "Quantum resistance: %s\n"+ "Lazy connection: %s\n"+ + "SSH Server: %s\n"+ "Networks: %s\n"+ "Forwarding rules: %d\n"+ "Peers count: %s\n", @@ -444,6 +519,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, interfaceTypeString, rosenpassEnabledStatus, lazyConnectionEnabledStatus, + sshServerStatus, networks, overview.NumberOfForwardingRules, peersCountString, @@ -454,7 +530,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, func ParseToFullDetailSummary(overview OutputOverview) string { parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive) parsedEventsString := parseEvents(overview.Events) - summary := ParseGeneralSummary(overview, true, true, true) + summary := ParseGeneralSummary(overview, true, true, true, true) return fmt.Sprintf( "Peers detail:"+ @@ -746,4 +822,13 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) { event.Metadata[k] = a.AnonymizeString(v) } } + + for i, session := range overview.SSHServerState.Sessions { + if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil { + overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port) + } else { + overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress) + } + overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command) + } } diff --git a/client/status/status_test.go b/client/status/status_test.go index 660efd9ef..1dca1e5b1 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -231,6 +231,10 @@ var overview = OutputOverview{ Networks: []string{ "10.10.0.0/24", }, + SSHServerState: SSHServerStateOutput{ + Enabled: false, + Sessions: []SSHSessionOutput{}, + }, } func TestConversionFromFullStatusToOutputOverview(t *testing.T) { @@ -385,7 +389,11 @@ func TestParsingToJSON(t *testing.T) { ], "events": [], "lazyConnectionEnabled": false, - "profileName":"" + "profileName":"", + "sshServer":{ + "enabled":false, + "sessions":[] + } }` // @formatter:on @@ -488,6 +496,9 @@ dnsServers: events: [] lazyConnectionEnabled: false profileName: "" +sshServer: + enabled: false + sessions: [] ` assert.Equal(t, expectedYAML, yaml) @@ -554,6 +565,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Lazy connection: false +SSH Server: Disabled Networks: 10.10.0.0/24 Forwarding rules: 0 Peers count: 2/2 Connected @@ -563,7 +575,7 @@ Peers count: 2/2 Connected } func TestParsingToShortVersion(t *testing.T) { - shortVersion := ParseGeneralSummary(overview, false, false, false) + shortVersion := ParseGeneralSummary(overview, false, false, false, false) expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + ` Daemon version: 0.14.1 @@ -578,6 +590,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Lazy connection: false +SSH Server: Disabled Networks: 10.10.0.0/24 Forwarding rules: 0 Peers count: 2/2 Connected diff --git a/client/system/info.go b/client/system/info.go index a180be4c0..01176e765 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -72,6 +72,12 @@ type Info struct { BlockInbound bool LazyConnectionEnabled bool + + EnableSSHRoot bool + EnableSSHSFTP bool + EnableSSHLocalPortForwarding bool + EnableSSHRemotePortForwarding bool + DisableSSHAuth bool } func (i *Info) SetFlags( @@ -79,6 +85,8 @@ func (i *Info) SetFlags( serverSSHAllowed *bool, disableClientRoutes, disableServerRoutes, disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool, + enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool, + disableSSHAuth *bool, ) { i.RosenpassEnabled = rosenpassEnabled i.RosenpassPermissive = rosenpassPermissive @@ -94,6 +102,22 @@ func (i *Info) SetFlags( i.BlockInbound = blockInbound i.LazyConnectionEnabled = lazyConnectionEnabled + + if enableSSHRoot != nil { + i.EnableSSHRoot = *enableSSHRoot + } + if enableSSHSFTP != nil { + i.EnableSSHSFTP = *enableSSHSFTP + } + if enableSSHLocalPortForwarding != nil { + i.EnableSSHLocalPortForwarding = *enableSSHLocalPortForwarding + } + if enableSSHRemotePortForwarding != nil { + i.EnableSSHRemotePortForwarding = *enableSSHRemotePortForwarding + } + if disableSSHAuth != nil { + i.DisableSSHAuth = *disableSSHAuth + } } // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index d3ab42423..44643616d 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -55,6 +55,7 @@ const ( const ( censoredPreSharedKey = "**********" + maxSSHJWTCacheTTL = 86_400 // 24 hours in seconds ) func main() { @@ -265,25 +266,38 @@ type serviceClient struct { iMTU *widget.Entry // switch elements for settings form - sRosenpassPermissive *widget.Check - sNetworkMonitor *widget.Check - sDisableDNS *widget.Check - sDisableClientRoutes *widget.Check - sDisableServerRoutes *widget.Check - sBlockLANAccess *widget.Check + sRosenpassPermissive *widget.Check + sNetworkMonitor *widget.Check + sDisableDNS *widget.Check + sDisableClientRoutes *widget.Check + sDisableServerRoutes *widget.Check + sBlockLANAccess *widget.Check + sEnableSSHRoot *widget.Check + sEnableSSHSFTP *widget.Check + sEnableSSHLocalPortForward *widget.Check + sEnableSSHRemotePortForward *widget.Check + sDisableSSHAuth *widget.Check + iSSHJWTCacheTTL *widget.Entry // observable settings over corresponding iMngURL and iPreSharedKey values. - managementURL string - preSharedKey string - RosenpassPermissive bool - interfaceName string - interfacePort int - mtu uint16 - networkMonitor bool - disableDNS bool - disableClientRoutes bool - disableServerRoutes bool - blockLANAccess bool + managementURL string + preSharedKey string + + RosenpassPermissive bool + interfaceName string + interfacePort int + mtu uint16 + networkMonitor bool + disableDNS bool + disableClientRoutes bool + disableServerRoutes bool + blockLANAccess bool + enableSSHRoot bool + enableSSHSFTP bool + enableSSHLocalPortForward bool + enableSSHRemotePortForward bool + disableSSHAuth bool + sshJWTCacheTTL int connected bool update *version.Update @@ -435,18 +449,22 @@ func (s *serviceClient) showSettingsUI() { s.sDisableClientRoutes = widget.NewCheck("This peer won't route traffic to other peers", nil) s.sDisableServerRoutes = widget.NewCheck("This peer won't act as router for others", nil) s.sBlockLANAccess = widget.NewCheck("Blocks local network access when used as exit node", nil) + s.sEnableSSHRoot = widget.NewCheck("Enable SSH Root Login", nil) + s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil) + s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil) + s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil) + s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil) + s.iSSHJWTCacheTTL = widget.NewEntry() s.wSettings.SetContent(s.getSettingsForm()) - s.wSettings.Resize(fyne.NewSize(600, 500)) + s.wSettings.Resize(fyne.NewSize(600, 400)) s.wSettings.SetFixedSize(true) s.getSrvConfig() s.wSettings.Show() } -// getSettingsForm to embed it into settings window. -func (s *serviceClient) getSettingsForm() *widget.Form { - +func (s *serviceClient) getConnectionForm() *widget.Form { var activeProfName string activeProf, err := s.profileManager.GetActiveProfile() if err != nil { @@ -457,153 +475,277 @@ func (s *serviceClient) getSettingsForm() *widget.Form { return &widget.Form{ Items: []*widget.FormItem{ {Text: "Profile", Widget: widget.NewLabel(activeProfName)}, + {Text: "Management URL", Widget: s.iMngURL}, + {Text: "Pre-shared Key", Widget: s.iPreSharedKey}, {Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive}, {Text: "Interface Name", Widget: s.iInterfaceName}, {Text: "Interface Port", Widget: s.iInterfacePort}, {Text: "MTU", Widget: s.iMTU}, - {Text: "Management URL", Widget: s.iMngURL}, - {Text: "Pre-shared Key", Widget: s.iPreSharedKey}, {Text: "Log File", Widget: s.iLogFile}, + }, + } +} + +func (s *serviceClient) saveSettings() { + // Check if update settings are disabled by daemon + features, err := s.getFeatures() + if err != nil { + log.Errorf("failed to get features from daemon: %v", err) + // Continue with default behavior if features can't be retrieved + } else if features != nil && features.DisableUpdateSettings { + log.Warn("Configuration updates are disabled by daemon") + dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings) + return + } + + if err := s.validateSettings(); err != nil { + dialog.ShowError(err, s.wSettings) + return + } + + port, mtu, err := s.parseNumericSettings() + if err != nil { + dialog.ShowError(err, s.wSettings) + return + } + + iMngURL := strings.TrimSpace(s.iMngURL.Text) + + if s.hasSettingsChanged(iMngURL, port, mtu) { + if err := s.applySettingsChanges(iMngURL, port, mtu); err != nil { + dialog.ShowError(err, s.wSettings) + return + } + } + + s.wSettings.Close() +} + +func (s *serviceClient) validateSettings() error { + if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey { + if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil { + return fmt.Errorf("Invalid Pre-shared Key Value") + } + } + return nil +} + +func (s *serviceClient) parseNumericSettings() (int64, int64, error) { + port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64) + if err != nil { + return 0, 0, errors.New("Invalid interface port") + } + if port < 1 || port > 65535 { + return 0, 0, errors.New("Invalid interface port: out of range 1-65535") + } + + var mtu int64 + mtuText := strings.TrimSpace(s.iMTU.Text) + if mtuText != "" { + mtu, err = strconv.ParseInt(mtuText, 10, 64) + if err != nil { + return 0, 0, errors.New("Invalid MTU value") + } + if mtu < iface.MinMTU || mtu > iface.MaxMTU { + return 0, 0, fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU) + } + } + + return port, mtu, nil +} + +func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool { + return s.managementURL != iMngURL || + s.preSharedKey != s.iPreSharedKey.Text || + s.RosenpassPermissive != s.sRosenpassPermissive.Checked || + s.interfaceName != s.iInterfaceName.Text || + s.interfacePort != int(port) || + s.mtu != uint16(mtu) || + s.networkMonitor != s.sNetworkMonitor.Checked || + s.disableDNS != s.sDisableDNS.Checked || + s.disableClientRoutes != s.sDisableClientRoutes.Checked || + s.disableServerRoutes != s.sDisableServerRoutes.Checked || + s.blockLANAccess != s.sBlockLANAccess.Checked || + s.hasSSHChanges() +} + +func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error { + s.managementURL = iMngURL + s.preSharedKey = s.iPreSharedKey.Text + s.mtu = uint16(mtu) + + req, err := s.buildSetConfigRequest(iMngURL, port, mtu) + if err != nil { + return fmt.Errorf("build config request: %w", err) + } + + if err := s.sendConfigUpdate(req); err != nil { + return fmt.Errorf("set configuration: %w", err) + } + + return nil +} + +func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (*proto.SetConfigRequest, error) { + currUser, err := user.Current() + if err != nil { + return nil, fmt.Errorf("get current user: %w", err) + } + + activeProf, err := s.profileManager.GetActiveProfile() + if err != nil { + return nil, fmt.Errorf("get active profile: %w", err) + } + + req := &proto.SetConfigRequest{ + ProfileName: activeProf.Name, + Username: currUser.Username, + } + + if iMngURL != "" { + req.ManagementUrl = iMngURL + } + + req.RosenpassPermissive = &s.sRosenpassPermissive.Checked + req.InterfaceName = &s.iInterfaceName.Text + req.WireguardPort = &port + if mtu > 0 { + req.Mtu = &mtu + } + + req.NetworkMonitor = &s.sNetworkMonitor.Checked + req.DisableDns = &s.sDisableDNS.Checked + req.DisableClientRoutes = &s.sDisableClientRoutes.Checked + req.DisableServerRoutes = &s.sDisableServerRoutes.Checked + req.BlockLanAccess = &s.sBlockLANAccess.Checked + + req.EnableSSHRoot = &s.sEnableSSHRoot.Checked + req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked + req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked + req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked + req.DisableSSHAuth = &s.sDisableSSHAuth.Checked + + sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text) + if sshJWTCacheTTLText != "" { + sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32) + if err != nil { + return nil, errors.New("Invalid SSH JWT Cache TTL value") + } + if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL { + return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL) + } + sshJWTCacheTTL32 := int32(sshJWTCacheTTL) + req.SshJWTCacheTTL = &sshJWTCacheTTL32 + } + + if s.iPreSharedKey.Text != censoredPreSharedKey { + req.OptionalPreSharedKey = &s.iPreSharedKey.Text + } + + return req, nil +} + +func (s *serviceClient) sendConfigUpdate(req *proto.SetConfigRequest) error { + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + return fmt.Errorf("get client: %w", err) + } + + _, err = conn.SetConfig(s.ctx, req) + if err != nil { + return fmt.Errorf("set config: %w", err) + } + + // Reconnect if connected to apply the new settings + go func() { + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + if err != nil { + log.Errorf("get service status: %v", err) + return + } + if status.Status == string(internal.StatusConnected) { + // run down & up + _, err = conn.Down(s.ctx, &proto.DownRequest{}) + if err != nil { + log.Errorf("down service: %v", err) + } + + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + log.Errorf("up service: %v", err) + return + } + } + }() + + return nil +} + +func (s *serviceClient) getSettingsForm() fyne.CanvasObject { + connectionForm := s.getConnectionForm() + networkForm := s.getNetworkForm() + sshForm := s.getSSHForm() + tabs := container.NewAppTabs( + container.NewTabItem("Connection", connectionForm), + container.NewTabItem("Network", networkForm), + container.NewTabItem("SSH", sshForm), + ) + saveButton := widget.NewButtonWithIcon("Save", theme.ConfirmIcon(), s.saveSettings) + saveButton.Importance = widget.HighImportance + cancelButton := widget.NewButtonWithIcon("Cancel", theme.CancelIcon(), func() { + s.wSettings.Close() + }) + buttonContainer := container.NewHBox( + layout.NewSpacer(), + cancelButton, + saveButton, + ) + return container.NewBorder(nil, buttonContainer, nil, nil, tabs) +} + +func (s *serviceClient) getNetworkForm() *widget.Form { + return &widget.Form{ + Items: []*widget.FormItem{ {Text: "Network Monitor", Widget: s.sNetworkMonitor}, {Text: "Disable DNS", Widget: s.sDisableDNS}, {Text: "Disable Client Routes", Widget: s.sDisableClientRoutes}, {Text: "Disable Server Routes", Widget: s.sDisableServerRoutes}, {Text: "Disable LAN Access", Widget: s.sBlockLANAccess}, }, - SubmitText: "Save", - OnSubmit: func() { - // Check if update settings are disabled by daemon - features, err := s.getFeatures() - if err != nil { - log.Errorf("failed to get features from daemon: %v", err) - // Continue with default behavior if features can't be retrieved - } else if features != nil && features.DisableUpdateSettings { - log.Warn("Configuration updates are disabled by daemon") - dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings) - return - } + } +} - if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey { - // validate preSharedKey if it added - if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil { - dialog.ShowError(fmt.Errorf("Invalid Pre-shared Key Value"), s.wSettings) - return - } - } - - port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64) - if err != nil { - dialog.ShowError(errors.New("Invalid interface port"), s.wSettings) - return - } - - var mtu int64 - mtuText := strings.TrimSpace(s.iMTU.Text) - if mtuText != "" { - var err error - mtu, err = strconv.ParseInt(mtuText, 10, 64) - if err != nil { - dialog.ShowError(errors.New("Invalid MTU value"), s.wSettings) - return - } - if mtu < iface.MinMTU || mtu > iface.MaxMTU { - dialog.ShowError(fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU), s.wSettings) - return - } - } - - iMngURL := strings.TrimSpace(s.iMngURL.Text) - - defer s.wSettings.Close() - - // Check if any settings have changed - if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text || - s.RosenpassPermissive != s.sRosenpassPermissive.Checked || - s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) || - s.mtu != uint16(mtu) || - s.networkMonitor != s.sNetworkMonitor.Checked || - s.disableDNS != s.sDisableDNS.Checked || - s.disableClientRoutes != s.sDisableClientRoutes.Checked || - s.disableServerRoutes != s.sDisableServerRoutes.Checked || - s.blockLANAccess != s.sBlockLANAccess.Checked { - - s.managementURL = iMngURL - s.preSharedKey = s.iPreSharedKey.Text - s.mtu = uint16(mtu) - - currUser, err := user.Current() - if err != nil { - log.Errorf("get current user: %v", err) - return - } - - var req proto.SetConfigRequest - req.ProfileName = activeProf.Name - req.Username = currUser.Username - - if iMngURL != "" { - req.ManagementUrl = iMngURL - } - - req.RosenpassPermissive = &s.sRosenpassPermissive.Checked - req.InterfaceName = &s.iInterfaceName.Text - req.WireguardPort = &port - if mtu > 0 { - req.Mtu = &mtu - } - req.NetworkMonitor = &s.sNetworkMonitor.Checked - req.DisableDns = &s.sDisableDNS.Checked - req.DisableClientRoutes = &s.sDisableClientRoutes.Checked - req.DisableServerRoutes = &s.sDisableServerRoutes.Checked - req.BlockLanAccess = &s.sBlockLANAccess.Checked - - if s.iPreSharedKey.Text != censoredPreSharedKey { - req.OptionalPreSharedKey = &s.iPreSharedKey.Text - } - - conn, err := s.getSrvClient(failFastTimeout) - if err != nil { - log.Errorf("get client: %v", err) - dialog.ShowError(fmt.Errorf("Failed to connect to the service: %v", err), s.wSettings) - return - } - _, err = conn.SetConfig(s.ctx, &req) - if err != nil { - log.Errorf("set config: %v", err) - dialog.ShowError(fmt.Errorf("Failed to set configuration: %v", err), s.wSettings) - return - } - - go func() { - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) - if err != nil { - log.Errorf("get service status: %v", err) - dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) - return - } - if status.Status == string(internal.StatusConnected) { - // run down & up - _, err = conn.Down(s.ctx, &proto.DownRequest{}) - if err != nil { - log.Errorf("down service: %v", err) - } - - _, err = conn.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("up service: %v", err) - dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) - return - } - } - }() - } - }, - OnCancel: func() { - s.wSettings.Close() +func (s *serviceClient) getSSHForm() *widget.Form { + return &widget.Form{ + Items: []*widget.FormItem{ + {Text: "Enable SSH Root Login", Widget: s.sEnableSSHRoot}, + {Text: "Enable SSH SFTP", Widget: s.sEnableSSHSFTP}, + {Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward}, + {Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward}, + {Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth}, + {Text: "JWT Cache TTL (seconds, 0=disabled)", Widget: s.iSSHJWTCacheTTL}, }, } } +func (s *serviceClient) hasSSHChanges() bool { + currentSSHJWTCacheTTL := s.sshJWTCacheTTL + if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" { + val, err := strconv.Atoi(text) + if err != nil { + return true + } + currentSSHJWTCacheTTL = val + } + + return s.enableSSHRoot != s.sEnableSSHRoot.Checked || + s.enableSSHSFTP != s.sEnableSSHSFTP.Checked || + s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked || + s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked || + s.disableSSHAuth != s.sDisableSSHAuth.Checked || + s.sshJWTCacheTTL != currentSSHJWTCacheTTL +} + func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { @@ -1123,6 +1265,25 @@ func (s *serviceClient) getSrvConfig() { s.disableServerRoutes = cfg.DisableServerRoutes s.blockLANAccess = cfg.BlockLANAccess + if cfg.EnableSSHRoot != nil { + s.enableSSHRoot = *cfg.EnableSSHRoot + } + if cfg.EnableSSHSFTP != nil { + s.enableSSHSFTP = *cfg.EnableSSHSFTP + } + if cfg.EnableSSHLocalPortForwarding != nil { + s.enableSSHLocalPortForward = *cfg.EnableSSHLocalPortForwarding + } + if cfg.EnableSSHRemotePortForwarding != nil { + s.enableSSHRemotePortForward = *cfg.EnableSSHRemotePortForwarding + } + if cfg.DisableSSHAuth != nil { + s.disableSSHAuth = *cfg.DisableSSHAuth + } + if cfg.SSHJWTCacheTTL != nil { + s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL + } + if s.showAdvancedSettings { s.iMngURL.SetText(s.managementURL) s.iPreSharedKey.SetText(cfg.PreSharedKey) @@ -1143,6 +1304,24 @@ func (s *serviceClient) getSrvConfig() { s.sDisableClientRoutes.SetChecked(cfg.DisableClientRoutes) s.sDisableServerRoutes.SetChecked(cfg.DisableServerRoutes) s.sBlockLANAccess.SetChecked(cfg.BlockLANAccess) + if cfg.EnableSSHRoot != nil { + s.sEnableSSHRoot.SetChecked(*cfg.EnableSSHRoot) + } + if cfg.EnableSSHSFTP != nil { + s.sEnableSSHSFTP.SetChecked(*cfg.EnableSSHSFTP) + } + if cfg.EnableSSHLocalPortForwarding != nil { + s.sEnableSSHLocalPortForward.SetChecked(*cfg.EnableSSHLocalPortForwarding) + } + if cfg.EnableSSHRemotePortForwarding != nil { + s.sEnableSSHRemotePortForward.SetChecked(*cfg.EnableSSHRemotePortForwarding) + } + if cfg.DisableSSHAuth != nil { + s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth) + } + if cfg.SSHJWTCacheTTL != nil { + s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL)) + } } if s.mNotifications == nil { @@ -1213,6 +1392,15 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config { config.DisableServerRoutes = cfg.DisableServerRoutes config.BlockLANAccess = cfg.BlockLanAccess + config.EnableSSHRoot = &cfg.EnableSSHRoot + config.EnableSSHSFTP = &cfg.EnableSSHSFTP + config.EnableSSHLocalPortForwarding = &cfg.EnableSSHLocalPortForwarding + config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding + config.DisableSSHAuth = &cfg.DisableSSHAuth + + ttl := int(cfg.SshJWTCacheTTL) + config.SSHJWTCacheTTL = &ttl + return &config } diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index d542e2739..4dc14a1ca 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" netbird "github.com/netbirdio/netbird/client/embed" + sshdetection "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/ssh" @@ -125,10 +126,15 @@ func createSSHMethod(client *netbird.Client) js.Func { username = args[2].String() } + var jwtToken string + if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { + jwtToken = args[3].String() + } + return createPromise(func(resolve, reject js.Value) { sshClient := ssh.NewClient(client) - if err := sshClient.Connect(host, port, username); err != nil { + if err := sshClient.Connect(host, port, username, jwtToken); err != nil { reject.Invoke(err.Error()) return } @@ -191,12 +197,43 @@ func createPromise(handler func(resolve, reject js.Value)) js.Value { })) } +// createDetectSSHServerMethod creates the SSH server detection method +func createDetectSSHServerMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf("error: requires host and port") + } + + host := args[0].String() + port := args[1].Int() + + return createPromise(func(resolve, reject js.Value) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + serverType, err := detectSSHServerType(ctx, client, host, port) + if err != nil { + reject.Invoke(err.Error()) + return + } + + resolve.Invoke(js.ValueOf(serverType.RequiresJWT())) + }) + }) +} + +// detectSSHServerType detects SSH server type using NetBird network connection +func detectSSHServerType(ctx context.Context, client *netbird.Client, host string, port int) (sshdetection.ServerType, error) { + return sshdetection.DetectSSHServerType(ctx, client, host, port) +} + // createClientObject wraps the NetBird client in a JavaScript object func createClientObject(client *netbird.Client) js.Value { obj := make(map[string]interface{}) obj["start"] = createStartMethod(client) obj["stop"] = createStopMethod(client) + obj["detectSSHServerType"] = createDetectSSHServerMethod(client) obj["createSSHConnection"] = createSSHMethod(client) obj["proxyRequest"] = createProxyRequestMethod(client) obj["createRDPProxy"] = createRDPProxyMethod(client) diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go index ca35525eb..568437e56 100644 --- a/client/wasm/internal/ssh/client.go +++ b/client/wasm/internal/ssh/client.go @@ -13,6 +13,7 @@ import ( "golang.org/x/crypto/ssh" netbird "github.com/netbirdio/netbird/client/embed" + nbssh "github.com/netbirdio/netbird/client/ssh" ) const ( @@ -45,34 +46,19 @@ func NewClient(nbClient *netbird.Client) *Client { } // Connect establishes an SSH connection through NetBird network -func (c *Client) Connect(host string, port int, username string) error { +func (c *Client) Connect(host string, port int, username, jwtToken string) error { addr := fmt.Sprintf("%s:%d", host, port) logrus.Infof("SSH: Connecting to %s as %s", addr, username) - var authMethods []ssh.AuthMethod - - nbConfig, err := c.nbClient.GetConfig() + authMethods, err := c.getAuthMethods(jwtToken) if err != nil { - return fmt.Errorf("get NetBird config: %w", err) + return err } - if nbConfig.SSHKey == "" { - return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization") - } - - signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey)) - if err != nil { - return fmt.Errorf("parse NetBird SSH private key: %w", err) - } - - pubKey := signer.PublicKey() - logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type()) - - authMethods = append(authMethods, ssh.PublicKeys(signer)) config := &ssh.ClientConfig{ User: username, Auth: authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyCallback: nbssh.CreateHostKeyCallback(c.nbClient), Timeout: sshDialTimeout, } @@ -96,6 +82,33 @@ func (c *Client) Connect(host string, port int, username string) error { return nil } +// getAuthMethods returns SSH authentication methods, preferring JWT if available +func (c *Client) getAuthMethods(jwtToken string) ([]ssh.AuthMethod, error) { + if jwtToken != "" { + logrus.Debugf("SSH: Using JWT password authentication") + return []ssh.AuthMethod{ssh.Password(jwtToken)}, nil + } + + logrus.Debugf("SSH: No JWT token, using public key authentication") + + nbConfig, err := c.nbClient.GetConfig() + if err != nil { + return nil, fmt.Errorf("get NetBird config: %w", err) + } + + if nbConfig.SSHKey == "" { + return nil, fmt.Errorf("no NetBird SSH key available") + } + + signer, err := ssh.ParsePrivateKey([]byte(nbConfig.SSHKey)) + if err != nil { + return nil, fmt.Errorf("parse NetBird SSH private key: %w", err) + } + + logrus.Debugf("SSH: Added public key auth") + return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil +} + // StartSession starts an SSH session with PTY func (c *Client) StartSession(cols, rows int) error { if c.sshClient == nil { diff --git a/client/wasm/internal/ssh/key.go b/client/wasm/internal/ssh/key.go deleted file mode 100644 index 4868ba30a..000000000 --- a/client/wasm/internal/ssh/key.go +++ /dev/null @@ -1,50 +0,0 @@ -//go:build js - -package ssh - -import ( - "crypto/x509" - "encoding/pem" - "fmt" - "strings" - - "github.com/sirupsen/logrus" - "golang.org/x/crypto/ssh" -) - -// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format -func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) { - keyStr := string(keyPEM) - if !strings.Contains(keyStr, "-----BEGIN") { - keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----") - } - - signer, err := ssh.ParsePrivateKey(keyPEM) - if err == nil { - return signer, nil - } - logrus.Debugf("SSH: Failed to parse as SSH format: %v", err) - - block, _ := pem.Decode(keyPEM) - if block == nil { - keyPreview := string(keyPEM) - if len(keyPreview) > 100 { - keyPreview = keyPreview[:100] - } - return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview) - } - - key, err := x509.ParsePKCS8PrivateKey(block.Bytes) - if err != nil { - logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err) - if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { - return ssh.NewSignerFromKey(rsaKey) - } - if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil { - return ssh.NewSignerFromKey(ecKey) - } - return nil, fmt.Errorf("parse private key: %w", err) - } - - return ssh.NewSignerFromKey(key) -} diff --git a/go.mod b/go.mod index 2d7e0d31c..45a36190d 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/netbirdio/netbird -go 1.23.0 +go 1.23.1 require ( cunicu.li/go-rosenpass v0.4.0 @@ -17,8 +17,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.40.0 - golang.org/x/sys v0.34.0 + golang.org/x/crypto v0.41.0 + golang.org/x/sys v0.35.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -31,6 +31,7 @@ require ( fyne.io/fyne/v2 v2.7.0 fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58 github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible + github.com/awnumar/memguard v0.23.0 github.com/aws/aws-sdk-go-v2 v1.36.3 github.com/aws/aws-sdk-go-v2/config v1.29.14 github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 @@ -76,6 +77,7 @@ require ( github.com/pion/stun/v3 v3.0.0 github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v3 v3.0.1 + github.com/pkg/sftp v1.13.9 github.com/prometheus/client_golang v1.22.0 github.com/quic-go/quic-go v0.49.1 github.com/redis/go-redis/v9 v9.7.3 @@ -108,7 +110,7 @@ require ( golang.org/x/net v0.42.0 golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.16.0 - golang.org/x/term v0.33.0 + golang.org/x/term v0.34.0 golang.org/x/time v0.12.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 @@ -130,6 +132,7 @@ require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/hcsshim v0.12.3 // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect + github.com/awnumar/memcall v0.4.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect @@ -197,6 +200,7 @@ require ( github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/kr/fs v0.1.0 // indirect github.com/libdns/libdns v0.2.2 // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/magiconair/properties v1.8.7 // indirect @@ -248,8 +252,8 @@ require ( go.opentelemetry.io/otel/trace v1.35.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.24.0 // indirect - golang.org/x/text v0.27.0 // indirect - golang.org/x/tools v0.34.0 // indirect + golang.org/x/text v0.28.0 // indirect + golang.org/x/tools v0.35.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect diff --git a/go.sum b/go.sum index f4b62dff0..ec68a8f59 100644 --- a/go.sum +++ b/go.sum @@ -31,6 +31,10 @@ github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJT github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g= +github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w= +github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A= +github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M= github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs= @@ -303,6 +307,8 @@ github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYW github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -432,6 +438,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo= +github.com/pkg/sftp v1.13.9 h1:4NGkvGudBL7GteO3m6qnaQ4pC0Kvf0onSVc9gR3EWBw= +github.com/pkg/sftp v1.13.9/go.mod h1:OBN7bVXdstkFFN/gdnHPUb5TE8eb8G1Rp9wCItqjkkA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= @@ -587,9 +595,13 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= -golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= @@ -607,6 +619,9 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -628,7 +643,10 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -643,6 +661,10 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -675,19 +697,28 @@ golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= -golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= +golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -695,9 +726,12 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -713,8 +747,10 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= -golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 409bdaaba..18a8427be 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -66,7 +66,7 @@ func (s *BaseServer) PeersManager() peers.Manager { func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) if err != nil { log.Fatalf("failed to create account manager: %v", err) } diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 9a4681eae..7f64034df 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -3,6 +3,8 @@ package grpc import ( "context" "fmt" + "net/url" + "strings" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" nbdns "github.com/netbirdio/netbird/dns" @@ -81,12 +83,21 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken return nbConfig } -func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig { +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, config *nbconfig.Config) *proto.PeerConfig { netmask, _ := network.Net.Mask.Size() fqdn := peer.FQDN(dnsName) + + sshConfig := &proto.SSHConfig{ + SshEnabled: peer.SSHEnabled, + } + + if peer.SSHEnabled { + sshConfig.JwtConfig = buildJWTConfig(config) + } + return &proto.PeerConfig{ - Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network - SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, + Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), + SshConfig: sshConfig, Fqdn: fqdn, RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, LazyConnectionEnabled: settings.LazyConnectionEnabled, @@ -95,7 +106,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set func ToSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { response := &proto.SyncResponse{ - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, config), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), Routes: toProtocolRoutes(networkMap.Routes), @@ -350,3 +361,51 @@ func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameSe } return protoGroup } + +// buildJWTConfig constructs JWT configuration for SSH servers from management server config +func buildJWTConfig(config *nbconfig.Config) *proto.JWTConfig { + if config == nil { + return nil + } + + if config.HttpConfig == nil || config.HttpConfig.AuthAudience == "" { + return nil + } + + issuer := strings.TrimSpace(config.HttpConfig.AuthIssuer) + if issuer == "" { + if config.DeviceAuthorizationFlow != nil { + if d := deriveIssuerFromTokenEndpoint(config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint); d != "" { + issuer = d + } + } + } + if issuer == "" { + return nil + } + + keysLocation := strings.TrimSpace(config.HttpConfig.AuthKeysLocation) + if keysLocation == "" { + keysLocation = strings.TrimSuffix(issuer, "/") + "/.well-known/jwks.json" + } + + return &proto.JWTConfig{ + Issuer: issuer, + Audience: config.HttpConfig.AuthAudience, + KeysLocation: keysLocation, + } +} + +// deriveIssuerFromTokenEndpoint extracts the issuer URL from a token endpoint +func deriveIssuerFromTokenEndpoint(tokenEndpoint string) string { + if tokenEndpoint == "" { + return "" + } + + u, err := url.Parse(tokenEndpoint) + if err != nil { + return "" + } + + return fmt.Sprintf("%s://%s/", u.Scheme, u.Host) +} diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 08a840316..4364272a0 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -646,7 +646,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings), + PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config), Checks: toProtocolChecks(ctx, postureChecks), } diff --git a/management/server/account.go b/management/server/account.go index a4b2a752b..3e498536c 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -15,6 +15,8 @@ import ( "sync" "time" + "github.com/netbirdio/netbird/shared/auth" + cacheStore "github.com/eko/gocache/lib/v4/store" "github.com/eko/gocache/store/redis/v4" "github.com/rs/xid" @@ -25,6 +27,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/controllers/network_map" + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" @@ -81,6 +84,9 @@ type DefaultAccountManager struct { proxyController port_forwarding.Controller settingsManager settings.Manager + // config contains the management server configuration + config *nbconfig.Config + // singleAccountMode indicates whether the instance has a single account. // If true, then every new user will end up under the same account. // This value will be set to false if management service has more than one account. @@ -171,6 +177,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups [] // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager( ctx context.Context, + config *nbconfig.Config, store store.Store, networkMapController network_map.Controller, idpManager idp.Manager, @@ -192,6 +199,7 @@ func BuildManager( am := &DefaultAccountManager{ Store: store, + config: config, geo: geo, networkMapController: networkMapController, idpManager: idpManager, @@ -1006,7 +1014,7 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun } // updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes -func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth, +func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth auth.UserAuth, primaryDomain bool, ) error { if userAuth.Domain == "" { @@ -1055,7 +1063,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, userAccountID string, domainAccountID string, - userAuth nbcontext.UserAuth, + userAuth auth.UserAuth, ) error { primaryDomain := domainAccountID == "" || userAccountID == domainAccountID err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain) @@ -1074,7 +1082,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount( // addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) { if userAuth.UserId == "" { return "", fmt.Errorf("user ID is empty") } @@ -1105,7 +1113,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai return newAccount.Id, nil } -func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) { newUser := types.NewRegularUser(userAuth.UserId) newUser.AccountID = domainAccountID @@ -1217,7 +1225,7 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { - log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err) + log.Errorf("failed to get account onboarding for account %s: %v", accountID, err) return nil, err } @@ -1269,7 +1277,7 @@ func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, ac return newOnboarding, nil } -func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { +func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) { if userAuth.UserId == "" { return "", "", errors.New(emptyUserID) } @@ -1313,7 +1321,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. // requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager -func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error { +func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error { if userAuth.IsChild || userAuth.IsPAT { return nil } @@ -1471,7 +1479,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) // // UserAuth IsChild -> checks that account exists -func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth auth.UserAuth) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory) @@ -1550,7 +1558,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont return domainAccountID, cancel, nil } -func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth auth.UserAuth) (string, error) { userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) @@ -1598,7 +1606,7 @@ func handleNotFound(err error) error { return nil } -func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool { +func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAuth) bool { return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 7c174a481..9b3902d87 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -6,10 +6,11 @@ import ( "net/netip" "time" + "github.com/netbirdio/netbird/shared/auth" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/peers/ephemeral" @@ -45,10 +46,10 @@ type Manager interface { GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) - GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) + GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) DeleteAccount(ctx context.Context, accountID, userID string) error GetUserByID(ctx context.Context, id string) (*types.User, error) - GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error @@ -117,11 +118,11 @@ type Manager interface { UpdateAccountPeers(ctx context.Context, accountID string) BufferUpdateAccountPeers(ctx context.Context, accountID string) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) - SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error + SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error GetStore() store.Store GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) - GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) SetEphemeralManager(em ephemeral.Manager) } diff --git a/management/server/account_test.go b/management/server/account_test.go index ee9950796..10d718bbf 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -28,7 +28,6 @@ import ( nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" @@ -45,6 +44,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/auth" ) func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account *types.Account, userID string) { @@ -445,7 +445,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { } func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { - type initUserParams nbcontext.UserAuth + type initUserParams auth.UserAuth var ( publicDomain = "public.com" @@ -468,7 +468,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { testCases := []struct { name string - inputClaims nbcontext.UserAuth + inputClaims auth.UserAuth inputInitUserParams initUserParams inputUpdateAttrs bool inputUpdateClaimAccount bool @@ -483,7 +483,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }{ { name: "New User With Public Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: publicDomain, UserId: "pub-domain-user", DomainCategory: types.PublicCategory, @@ -500,7 +500,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New User With Unknown Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: unknownDomain, UserId: "unknown-domain-user", DomainCategory: types.UnknownCategory, @@ -517,7 +517,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New User With Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: privateDomain, UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -534,7 +534,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New Regular User With Existing Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: privateDomain, UserId: "new-pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -552,7 +552,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "Existing User With Existing Reclassified Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, DomainCategory: types.PrivateCategory, @@ -569,7 +569,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "Existing Account Id With Existing Reclassified Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, DomainCategory: types.PrivateCategory, @@ -587,7 +587,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "User With Private Category And Empty Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: "", UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -616,7 +616,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { require.NoError(t, err, "get init account failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, auth.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } @@ -656,7 +656,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount Domain: domain, UserId: userId, @@ -915,13 +915,13 @@ func TestAccountManager_DeleteAccount(t *testing.T) { } func BenchmarkTest_GetAccountWithclaims(b *testing.B) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ Domain: "example.com", UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, } - publicClaims := nbcontext.UserAuth{ + publicClaims := auth.UserAuth{ Domain: "test.com", UserId: "public-domain-user", DomainCategory: types.PublicCategory, @@ -2709,7 +2709,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") t.Run("skip sync for token auth type", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group3"}, @@ -2724,7 +2724,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("empty jwt groups", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{}, @@ -2738,7 +2738,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("jwt match existing api group", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group1"}, @@ -2759,7 +2759,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"])) - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group1"}, @@ -2777,7 +2777,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("add jwt group", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group1", "group2"}, @@ -2791,7 +2791,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("existed group not update", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group2"}, @@ -2805,7 +2805,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("add new group", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user2", AccountId: "accountID", Groups: []string{"group1", "group3"}, @@ -2823,7 +2823,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("remove all JWT groups when list is empty", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{}, @@ -2838,7 +2838,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user2", AccountId: "accountID", Groups: []string{}, @@ -2959,7 +2959,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) - manager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + manager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, nil, err } @@ -3692,7 +3692,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { // Test adding new user to existing account with approval required newUserID := "new-user-id" - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ UserId: newUserID, Domain: "example.com", DomainCategory: types.PrivateCategory, @@ -3722,7 +3722,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { } // Create a domain-based account without user approval - ownerUserAuth := nbcontext.UserAuth{ + ownerUserAuth := auth.UserAuth{ UserId: "owner-user", Domain: "example.com", DomainCategory: types.PrivateCategory, @@ -3741,7 +3741,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { // Test adding new user to existing account without approval required newUserID := "new-user-id" - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ UserId: newUserID, Domain: "example.com", DomainCategory: types.PrivateCategory, diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go index ece9dc321..0c62357dc 100644 --- a/management/server/auth/manager.go +++ b/management/server/auth/manager.go @@ -9,18 +9,19 @@ import ( "github.com/golang-jwt/jwt/v5" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/base62" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) var _ Manager = (*manager)(nil) type Manager interface { - ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) - EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) + ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) + EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) MarkPATUsed(ctx context.Context, tokenID string) error GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) } @@ -55,20 +56,20 @@ func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim s } } -func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) { +func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) { token, err := m.validator.ValidateAndParse(ctx, value) if err != nil { - return nbcontext.UserAuth{}, nil, err + return auth.UserAuth{}, nil, err } userAuth, err := m.extractor.ToUserAuth(token) if err != nil { - return nbcontext.UserAuth{}, nil, err + return auth.UserAuth{}, nil, err } return userAuth, token, err } -func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { +func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) { if userAuth.IsChild || userAuth.IsPAT { return userAuth, nil } diff --git a/management/server/auth/manager_mock.go b/management/server/auth/manager_mock.go index 30a7a7161..edf158a49 100644 --- a/management/server/auth/manager_mock.go +++ b/management/server/auth/manager_mock.go @@ -3,9 +3,10 @@ package auth import ( "context" + "github.com/netbirdio/netbird/shared/auth" + "github.com/golang-jwt/jwt/v5" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/types" ) @@ -15,18 +16,18 @@ var ( // @note really dislike this mocking approach but rather than have to do additional test refactoring. type MockManager struct { - ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) - EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) + ValidateAndParseTokenFunc func(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) + EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) MarkPATUsedFunc func(ctx context.Context, tokenID string) error GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) } // EnsureUserAccessByJWTGroups implements Manager. -func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { +func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) { if m.EnsureUserAccessByJWTGroupsFunc != nil { return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token) } - return nbcontext.UserAuth{}, nil + return auth.UserAuth{}, nil } // GetPATInfo implements Manager. @@ -46,9 +47,9 @@ func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error { } // ValidateAndParseToken implements Manager. -func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) { +func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) { if m.ValidateAndParseTokenFunc != nil { return m.ValidateAndParseTokenFunc(ctx, value) } - return nbcontext.UserAuth{}, &jwt.Token{}, nil + return auth.UserAuth{}, &jwt.Token{}, nil } diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go index c8015eb37..b9f091b1e 100644 --- a/management/server/auth/manager_test.go +++ b/management/server/auth/manager_test.go @@ -17,10 +17,10 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/auth" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + nbauth "github.com/netbirdio/netbird/shared/auth" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) { @@ -131,7 +131,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) { } // this has been validated and parsed by ValidateAndParseToken - userAuth := nbcontext.UserAuth{ + userAuth := nbauth.UserAuth{ AccountId: account.Id, Domain: domain, UserId: userId, @@ -236,7 +236,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { tests := []struct { name string tokenFunc func() string - expected *nbcontext.UserAuth // nil indicates expected error + expected *nbauth.UserAuth // nil indicates expected error }{ { name: "Valid with custom claims", @@ -258,7 +258,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { tokenString, _ := token.SignedString(key) return tokenString }, - expected: &nbcontext.UserAuth{ + expected: &nbauth.UserAuth{ UserId: "user-id|123", AccountId: "account-id|567", Domain: "http://localhost", @@ -282,7 +282,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { tokenString, _ := token.SignedString(key) return tokenString }, - expected: &nbcontext.UserAuth{ + expected: &nbauth.UserAuth{ UserId: "user-id|123", }, }, diff --git a/management/server/context/auth.go b/management/server/context/auth.go index 5cb28ddb7..cc59b8a63 100644 --- a/management/server/context/auth.go +++ b/management/server/context/auth.go @@ -4,7 +4,8 @@ import ( "context" "fmt" "net/http" - "time" + + "github.com/netbirdio/netbird/shared/auth" ) type key int @@ -13,45 +14,22 @@ const ( UserAuthContextKey key = iota ) -type UserAuth struct { - // The account id the user is accessing - AccountId string - // The account domain - Domain string - // The account domain category, TBC values - DomainCategory string - // Indicates whether this user was invited, TBC logic - Invited bool - // Indicates whether this is a child account - IsChild bool - - // The user id - UserId string - // Last login time for this user - LastLogin time.Time - // The Groups the user belongs to on this account - Groups []string - - // Indicates whether this user has authenticated with a Personal Access Token - IsPAT bool -} - -func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) { +func GetUserAuthFromRequest(r *http.Request) (auth.UserAuth, error) { return GetUserAuthFromContext(r.Context()) } -func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request { +func SetUserAuthInRequest(r *http.Request, userAuth auth.UserAuth) *http.Request { return r.WithContext(SetUserAuthInContext(r.Context(), userAuth)) } -func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) { - if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok { +func GetUserAuthFromContext(ctx context.Context) (auth.UserAuth, error) { + if userAuth, ok := ctx.Value(UserAuthContextKey).(auth.UserAuth); ok { return userAuth, nil } - return UserAuth{}, fmt.Errorf("user auth not in context") + return auth.UserAuth{}, fmt.Errorf("user auth not in context") } -func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context { +func SetUserAuthInContext(ctx context.Context, userAuth auth.UserAuth) context.Context { //nolint ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId) //nolint diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 356a2f640..6b7a36c20 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -224,7 +224,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock()) - return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createDNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 4b9b79fdc..c5c48ef32 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -18,6 +18,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -236,7 +237,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: adminUser.Id, AccountId: accountID, Domain: "hotmail.com", diff --git a/management/server/http/handlers/dns/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go index 08a0b2afd..67638aea5 100644 --- a/management/server/http/handlers/dns/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -9,9 +9,9 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - "github.com/netbirdio/netbird/management/server/types" ) // dnsSettingsHandler is a handler that returns the DNS settings of the account diff --git a/management/server/http/handlers/dns/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go index 42b519c29..a027c067e 100644 --- a/management/server/http/handlers/dns/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -11,13 +11,14 @@ import ( "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" "github.com/gorilla/mux" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -107,7 +108,7 @@ func TestDNSSettingsHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, AccountId: testingDNSSettingsAccount.Id, Domain: testingDNSSettingsAccount.Domain, diff --git a/management/server/http/handlers/dns/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go index d49b6c7e0..4716782f3 100644 --- a/management/server/http/handlers/dns/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -19,6 +19,7 @@ import ( "github.com/gorilla/mux" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -193,7 +194,7 @@ func TestNameserversHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", AccountId: testNSGroupAccountID, Domain: "hotmail.com", diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go index a0695fa3f..923a24e31 100644 --- a/management/server/http/handlers/events/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -14,11 +14,12 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" ) func initEventsTestData(account string, events ...*activity.Event) *handler { @@ -188,7 +189,7 @@ func TestEvents_GetEvents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_account", diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index e861e873c..208a2e828 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -11,10 +11,10 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns groups of the account diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index 34694ec8c..b7dd3944a 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -19,12 +19,13 @@ import ( "github.com/netbirdio/netbird/management/server" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" ) var TestPeers = map[string]*nbpeer.Peer{ @@ -122,7 +123,7 @@ func TestGetGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -248,7 +249,7 @@ func TestWriteGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -330,7 +331,7 @@ func TestDeleteGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index d7b598a5d..f99eca794 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -12,15 +12,15 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/networks/types" - "github.com/netbirdio/netbird/shared/management/status" nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" ) // handler is a handler that returns networks of the account diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go index 59396dceb..c31729a39 100644 --- a/management/server/http/handlers/networks/resources_handler.go +++ b/management/server/http/handlers/networks/resources_handler.go @@ -8,10 +8,10 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" ) type resourceHandler struct { diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index 2e64c637f..c311a29fe 100644 --- a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -7,10 +7,10 @@ import ( "github.com/gorilla/mux" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" ) type routersHandler struct { diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 7a5a6d911..ddf2e2a70 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -21,6 +21,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/stretchr/testify/assert" @@ -296,7 +297,7 @@ func TestGetPeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "admin_user", Domain: "hotmail.com", AccountId: "test_id", @@ -444,7 +445,7 @@ func TestGetAccessiblePeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: tc.callerUserID, Domain: "hotmail.com", AccountId: "test_id", @@ -527,7 +528,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody))) req.Header.Set("Content-Type", "application/json") - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: tc.callerUserID, Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go index cedd5ac88..094a36e38 100644 --- a/management/server/http/handlers/policies/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -16,12 +16,13 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/util" ) @@ -113,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -206,7 +207,7 @@ func TestGetAllCountries(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index cb6995793..a2d656a47 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -9,11 +9,11 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index 4d6bad5e3..ab1639ab1 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -10,10 +10,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns policy of the account diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index fd39ae2a3..ca5a0a6ab 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -14,10 +14,11 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) func initPoliciesTestData(policies ...*types.Policy) *handler { @@ -103,7 +104,7 @@ func TestPoliciesGetPolicy(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -267,7 +268,7 @@ func TestPoliciesWritePolicy(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go index 3ebc4d1e1..744cde10b 100644 --- a/management/server/http/handlers/policies/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -9,9 +9,9 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index c644b533a..8c60d6fe8 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -16,9 +16,10 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -175,7 +176,7 @@ func TestGetPostureCheck(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -828,7 +829,7 @@ func TestPostureCheckUpdate(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index 466a7987f..a44d81e3e 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" @@ -493,7 +494,7 @@ func TestRoutesHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: testAccountID, diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index 2287dadfe..d267b6eea 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -10,10 +10,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns a list of setup keys of the account diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 7b46b486b..b137b6dd1 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -15,10 +15,11 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -163,7 +164,7 @@ func TestSetupKeysHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: adminUser.Id, Domain: "hotmail.com", AccountId: "testAccountId", diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go index bae07af4a..867db3ca9 100644 --- a/management/server/http/handlers/users/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -8,10 +8,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // patHandler is the nameserver group handler of the account diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go index 92544c56d..7cda14468 100644 --- a/management/server/http/handlers/users/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -17,10 +17,11 @@ import ( "github.com/netbirdio/netbird/management/server/util" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -173,7 +174,7 @@ func TestTokenHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index e08004218..37f0a6c1d 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -128,7 +129,7 @@ func initUsersTestData() *handler { return nil }, - GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { + GetCurrentUserInfoFunc: func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { switch userAuth.UserId { case "not-found": return nil, status.NewUserNotFoundError("not-found") @@ -225,7 +226,7 @@ func TestGetUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -335,7 +336,7 @@ func TestUpdateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -432,7 +433,7 @@ func TestCreateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) rr := httptest.NewRecorder() - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -481,7 +482,7 @@ func TestInviteUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req = mux.SetURLVars(req, tc.requestVars) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -540,7 +541,7 @@ func TestDeleteUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req = mux.SetURLVars(req, tc.requestVars) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -565,7 +566,7 @@ func TestCurrentUser(t *testing.T) { tt := []struct { name string expectedStatus int - requestAuth nbcontext.UserAuth + requestAuth auth.UserAuth expectedResult *api.User }{ { @@ -574,27 +575,27 @@ func TestCurrentUser(t *testing.T) { }, { name: "user not found", - requestAuth: nbcontext.UserAuth{UserId: "not-found"}, + requestAuth: auth.UserAuth{UserId: "not-found"}, expectedStatus: http.StatusNotFound, }, { name: "not of account", - requestAuth: nbcontext.UserAuth{UserId: "not-of-account"}, + requestAuth: auth.UserAuth{UserId: "not-of-account"}, expectedStatus: http.StatusForbidden, }, { name: "blocked user", - requestAuth: nbcontext.UserAuth{UserId: "blocked-user"}, + requestAuth: auth.UserAuth{UserId: "blocked-user"}, expectedStatus: http.StatusForbidden, }, { name: "service user", - requestAuth: nbcontext.UserAuth{UserId: "service-user"}, + requestAuth: auth.UserAuth{UserId: "service-user"}, expectedStatus: http.StatusForbidden, }, { name: "owner", - requestAuth: nbcontext.UserAuth{UserId: "owner"}, + requestAuth: auth.UserAuth{UserId: "owner"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "owner", @@ -613,7 +614,7 @@ func TestCurrentUser(t *testing.T) { }, { name: "regular user", - requestAuth: nbcontext.UserAuth{UserId: "regular-user"}, + requestAuth: auth.UserAuth{UserId: "regular-user"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "regular-user", @@ -632,7 +633,7 @@ func TestCurrentUser(t *testing.T) { }, { name: "admin user", - requestAuth: nbcontext.UserAuth{UserId: "admin-user"}, + requestAuth: auth.UserAuth{UserId: "admin-user"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "admin-user", @@ -651,7 +652,7 @@ func TestCurrentUser(t *testing.T) { }, { name: "restricted user", - requestAuth: nbcontext.UserAuth{UserId: "restricted-user"}, + requestAuth: auth.UserAuth{UserId: "restricted-user"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "restricted-user", @@ -783,7 +784,7 @@ func TestApproveUserEndpoint(t *testing.T) { req, err := http.NewRequest("POST", "/users/pending-user/approve", nil) require.NoError(t, err) - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ AccountId: existingAccountID, UserId: tc.requestingUser.Id, } @@ -841,7 +842,7 @@ func TestRejectUserEndpoint(t *testing.T) { req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil) require.NoError(t, err) - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ AccountId: existingAccountID, UserId: tc.requestingUser.Id, } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index bce917a25..9439165a4 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -10,22 +10,23 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/auth" + serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) -type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) -type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error +type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error) +type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error -type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) +type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - authManager auth.Manager + authManager serverauth.Manager ensureAccount EnsureAccountFunc getUserFromUserAuth GetUserFromUserAuthFunc syncUserJWTGroups SyncUserJWTGroupsFunc @@ -34,7 +35,7 @@ type AuthMiddleware struct { // NewAuthMiddleware instance constructor func NewAuthMiddleware( - authManager auth.Manager, + authManager serverauth.Manager, ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, @@ -61,18 +62,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { return } - auth := strings.Split(r.Header.Get("Authorization"), " ") - authType := strings.ToLower(auth[0]) + authHeader := strings.Split(r.Header.Get("Authorization"), " ") + authType := strings.ToLower(authHeader[0]) // fallback to token when receive pat as bearer - if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") { + if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") { authType = "token" - auth[0] = authType + authHeader[0] = authType } switch authType { case "bearer": - request, err := m.checkJWTFromRequest(r, auth) + request, err := m.checkJWTFromRequest(r, authHeader) if err != nil { log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) @@ -81,7 +82,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { h.ServeHTTP(w, request) case "token": - request, err := m.checkPATFromRequest(r, auth) + request, err := m.checkPATFromRequest(r, authHeader) if err != nil { log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) // Check if it's a status error, otherwise default to Unauthorized @@ -100,8 +101,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { } // CheckJWTFromRequest checks if the JWT is valid -func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) { - token, err := getTokenFromJWTRequest(auth) +func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { + token, err := getTokenFromJWTRequest(authHeaderParts) // If an error occurs, call the error handler and return an error if err != nil { @@ -151,8 +152,8 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h } // CheckPATFromRequest checks if the PAT is valid -func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) { - token, err := getTokenFromPATRequest(auth) +func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { + token, err := getTokenFromPATRequest(authHeaderParts) if err != nil { return r, fmt.Errorf("error extracting token: %w", err) } @@ -177,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h return r, err } - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ UserId: user.Id, AccountId: user.AccountID, Domain: accDomain, diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index d1bd9959f..7badc03e4 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -12,11 +12,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/auth" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" + nbauth "github.com/netbirdio/netbird/shared/auth" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) const ( @@ -75,9 +76,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use return nil, nil, "", "", fmt.Errorf("PAT invalid") } -func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { +func mockValidateAndParseToken(_ context.Context, token string) (nbauth.UserAuth, *jwt.Token, error) { if token == JWT { - return nbcontext.UserAuth{ + return nbauth.UserAuth{ UserId: userID, AccountId: accountID, Domain: testAccount.Domain, @@ -91,7 +92,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA Valid: true, }, nil } - return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid") + return nbauth.UserAuth{}, nil, fmt.Errorf("JWT invalid") } func mockMarkPATUsed(_ context.Context, token string) error { @@ -101,7 +102,7 @@ func mockMarkPATUsed(_ context.Context, token string) error { return fmt.Errorf("Should never get reached") } -func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { +func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbauth.UserAuth, token *jwt.Token) (nbauth.UserAuth, error) { if userAuth.IsChild || userAuth.IsPAT { return userAuth, nil } @@ -197,13 +198,13 @@ func TestAuthMiddleware_Handler(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, nil, @@ -255,13 +256,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, rateLimitConfig, @@ -306,13 +307,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, rateLimitConfig, @@ -348,13 +349,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, rateLimitConfig, @@ -391,13 +392,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, rateLimitConfig, @@ -454,13 +455,13 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, rateLimitConfig, @@ -508,13 +509,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name string path string authHeader string - expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status + expectedUserAuth *nbauth.UserAuth // nil expects 401 response status }{ { name: "Valid PAT Token", path: "/test", authHeader: "Token " + PAT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: accountID, UserId: userID, Domain: testAccount.Domain, @@ -526,7 +527,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name: "Valid PAT Token accesses child", path: "/test?account=xyz", authHeader: "Token " + PAT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: "xyz", UserId: userID, Domain: testAccount.Domain, @@ -539,7 +540,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name: "Valid JWT Token", path: "/test", authHeader: "Bearer " + JWT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: accountID, UserId: userID, Domain: testAccount.Domain, @@ -551,7 +552,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name: "Valid JWT Token with child", path: "/test?account=xyz", authHeader: "Bearer " + JWT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: "xyz", UserId: userID, Domain: testAccount.Domain, @@ -570,13 +571,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, nil, diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index ab3f5437a..ac165aeb2 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" @@ -18,8 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/auth" - nbcontext "github.com/netbirdio/netbird/management/server/context" + serverauth "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" http2 "github.com/netbirdio/netbird/management/server/http" @@ -33,6 +33,7 @@ import ( "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/auth" ) func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { @@ -71,14 +72,14 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee ctx := context.Background() requestBuffer := server.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock()) - am, err := server.BuildManager(ctx, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) if err != nil { t.Fatalf("Failed to create manager: %v", err) } // @note this is required so that PAT's validate from store, but JWT's are mocked - authManager := auth.NewManager(store, "", "", "", "", []string{}, false) - authManagerMock := &auth.MockManager{ + authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) + authManagerMock := &serverauth.MockManager{ ValidateAndParseTokenFunc: mockValidateAndParseToken, EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, MarkPATUsedFunc: authManager.MarkPATUsed, @@ -123,8 +124,8 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_m } } -func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { - userAuth := nbcontext.UserAuth{} +func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) { + userAuth := auth.UserAuth{} switch token { case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": diff --git a/management/server/idp/pocketid_test.go b/management/server/idp/pocketid_test.go index 49075a0d3..126a76919 100644 --- a/management/server/idp/pocketid_test.go +++ b/management/server/idp/pocketid_test.go @@ -1,138 +1,137 @@ package idp import ( - "context" - "testing" + "context" + "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/telemetry" ) - func TestNewPocketIdManager(t *testing.T) { - type test struct { - name string - inputConfig PocketIdClientConfig - assertErrFunc require.ErrorAssertionFunc - assertErrFuncMessage string - } + type test struct { + name string + inputConfig PocketIdClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } - defaultTestConfig := PocketIdClientConfig{ - APIToken: "api_token", - ManagementEndpoint: "http://localhost", - } + defaultTestConfig := PocketIdClientConfig{ + APIToken: "api_token", + ManagementEndpoint: "http://localhost", + } - tests := []test{ - { - name: "Good Configuration", - inputConfig: defaultTestConfig, - assertErrFunc: require.NoError, - assertErrFuncMessage: "shouldn't return error", - }, - { - name: "Missing ManagementEndpoint", - inputConfig: PocketIdClientConfig{ - APIToken: defaultTestConfig.APIToken, - ManagementEndpoint: "", - }, - assertErrFunc: require.Error, - assertErrFuncMessage: "should return error when field empty", - }, - { - name: "Missing APIToken", - inputConfig: PocketIdClientConfig{ - APIToken: "", - ManagementEndpoint: defaultTestConfig.ManagementEndpoint, - }, - assertErrFunc: require.Error, - assertErrFuncMessage: "should return error when field empty", - }, - } + tests := []test{ + { + name: "Good Configuration", + inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + }, + { + name: "Missing ManagementEndpoint", + inputConfig: PocketIdClientConfig{ + APIToken: defaultTestConfig.APIToken, + ManagementEndpoint: "", + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + { + name: "Missing APIToken", + inputConfig: PocketIdClientConfig{ + APIToken: "", + ManagementEndpoint: defaultTestConfig.ManagementEndpoint, + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{}) - tc.assertErrFunc(t, err, tc.assertErrFuncMessage) - }) - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{}) + tc.assertErrFunc(t, err, tc.assertErrFuncMessage) + }) + } } func TestPocketID_GetUserDataByID(t *testing.T) { - client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`} + client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`} - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client - md := AppMetadata{WTAccountID: "acc1"} - got, err := mgr.GetUserDataByID(context.Background(), "u1", md) - require.NoError(t, err) - assert.Equal(t, "u1", got.ID) - assert.Equal(t, "user1@example.com", got.Email) - assert.Equal(t, "User One", got.Name) - assert.Equal(t, "acc1", got.AppMetadata.WTAccountID) + md := AppMetadata{WTAccountID: "acc1"} + got, err := mgr.GetUserDataByID(context.Background(), "u1", md) + require.NoError(t, err) + assert.Equal(t, "u1", got.ID) + assert.Equal(t, "user1@example.com", got.Email) + assert.Equal(t, "User One", got.Name) + assert.Equal(t, "acc1", got.AppMetadata.WTAccountID) } func TestPocketID_GetAccount_WithPagination(t *testing.T) { - // Single page response with two users - client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + // Single page response with two users + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client - users, err := mgr.GetAccount(context.Background(), "accX") - require.NoError(t, err) - require.Len(t, users, 2) - assert.Equal(t, "u1", users[0].ID) - assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID) - assert.Equal(t, "u2", users[1].ID) + users, err := mgr.GetAccount(context.Background(), "accX") + require.NoError(t, err) + require.Len(t, users, 2) + assert.Equal(t, "u1", users[0].ID) + assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID) + assert.Equal(t, "u2", users[1].ID) } func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) { - client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client - accounts, err := mgr.GetAllAccounts(context.Background()) - require.NoError(t, err) - require.Len(t, accounts[UnsetAccountID], 2) + accounts, err := mgr.GetAllAccounts(context.Background()) + require.NoError(t, err) + require.Len(t, accounts[UnsetAccountID], 2) } func TestPocketID_CreateUser(t *testing.T) { - client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`} + client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`} - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client - ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com") - require.NoError(t, err) - assert.Equal(t, "newid", ud.ID) - assert.Equal(t, "new@example.com", ud.Email) - assert.Equal(t, "New User", ud.Name) - assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID) - if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) { - assert.True(t, *ud.AppMetadata.WTPendingInvite) - } - assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy) + ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com") + require.NoError(t, err) + assert.Equal(t, "newid", ud.ID) + assert.Equal(t, "new@example.com", ud.Email) + assert.Equal(t, "New User", ud.Name) + assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID) + if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) { + assert.True(t, *ud.AppMetadata.WTPendingInvite) + } + assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy) } func TestPocketID_InviteAndDeleteUser(t *testing.T) { - // Same mock for both calls; returns OK with empty JSON - client := &mockHTTPClient{code: 200, resBody: `{}`} + // Same mock for both calls; returns OK with empty JSON + client := &mockHTTPClient{code: 200, resBody: `{}`} - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client - err = mgr.InviteUserByID(context.Background(), "u1") - require.NoError(t, err) + err = mgr.InviteUserByID(context.Background(), "u1") + require.NoError(t, err) - err = mgr.DeleteUser(context.Background(), "u1") - require.NoError(t, err) + err = mgr.DeleteUser(context.Background(), "u1") + require.NoError(t, err) } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index fc67e01af..496be9caa 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -364,7 +364,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) - accountManager, err := BuildManager(ctx, store, networkMapController, nil, "", + accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { diff --git a/management/server/management_test.go b/management/server/management_test.go index 930ecfb5a..c485f16b4 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -209,6 +209,7 @@ func startServer( accountManager, err := server.BuildManager( context.Background(), + nil, str, networkMapController, nil, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 781d84f5f..0178e51f5 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -2,6 +2,7 @@ package mock_server import ( "context" + "github.com/netbirdio/netbird/shared/auth" "net" "net/netip" "time" @@ -12,7 +13,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/peers/ephemeral" @@ -34,7 +34,7 @@ type MockAccountManager struct { GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) - GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error @@ -84,7 +84,7 @@ type MockAccountManager struct { DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) - GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) + GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error) DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func(settings *types.Settings) string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) @@ -119,7 +119,7 @@ type MockAccountManager struct { GetStoreFunc func() store.Store UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) - GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + GetCurrentUserInfoFunc func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) @@ -470,7 +470,7 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, } // GetUser mock implementation of GetUser from server.AccountManager interface -func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { +func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) { if am.GetUserFromUserAuthFunc != nil { return am.GetUserFromUserAuthFunc(ctx, userAuth) } @@ -675,7 +675,7 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } -func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { +func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) { if am.GetAccountIDFromUserAuthFunc != nil { return am.GetAccountIDFromUserAuthFunc(ctx, userAuth) } @@ -937,7 +937,7 @@ func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, acco return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented") } -func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error { +func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error { return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented") } @@ -969,7 +969,7 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented") } -func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { +func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { if am.GetCurrentUserInfoFunc != nil { return am.GetCurrentUserInfoFunc(ctx, userAuth) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 35291b30c..51738c106 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -793,7 +793,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) - return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go index c6cec6f7e..e2dea2c6b 100644 --- a/management/server/networks/resources/manager_test.go +++ b/management/server/networks/resources/manager_test.go @@ -10,8 +10,8 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" ) func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 7874be858..6b8cf9412 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -8,11 +8,11 @@ import ( "github.com/rs/xid" - nbDomain "github.com/netbirdio/netbird/shared/management/domain" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/route" + nbDomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" ) diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go index 8054d05c6..6be90baa7 100644 --- a/management/server/networks/routers/manager_test.go +++ b/management/server/networks/routers/manager_test.go @@ -9,8 +9,8 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" ) func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) { diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index 72b15fd9a..e90c61a97 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -5,8 +5,8 @@ import ( "github.com/rs/xid" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/shared/management/http/api" ) type NetworkRouter struct { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 95c609595..21a6952a9 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1292,7 +1292,7 @@ func Test_RegisterPeerByUser(t *testing.T) { requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) - am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1377,7 +1377,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) - am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1530,7 +1530,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) - am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1610,7 +1610,7 @@ func Test_LoginPeer(t *testing.T) { requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) - am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index d65dc5045..f0bbbc32e 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -7,8 +7,8 @@ import ( "regexp" "github.com/hashicorp/go-version" - "github.com/netbirdio/netbird/shared/management/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/route_test.go b/management/server/route_test.go index 27fe033c8..7ff362bc6 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1292,7 +1292,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) - am, err := BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, nil, err } diff --git a/management/server/types/route_firewall_rule.go b/management/server/types/route_firewall_rule.go index 6eb391cb5..da29e1d87 100644 --- a/management/server/types/route_firewall_rule.go +++ b/management/server/types/route_firewall_rule.go @@ -1,8 +1,8 @@ package types import ( - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) // RouteFirewallRule a firewall rule applicable for a routed network. diff --git a/management/server/user.go b/management/server/user.go index be4e491a8..6b8bcbcad 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -7,12 +7,13 @@ import ( "strings" "time" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" - nbContext "github.com/netbirdio/netbird/management/server/context" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -175,9 +176,9 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id) } -// GetUser looks up a user by provided nbContext.UserAuths. +// GetUser looks up a user by provided auth.UserAuths. // Expects account to have been created already. -func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) { +func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { return nil, err @@ -970,7 +971,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou var peerIDs []string for _, peer := range peers { // nolint:staticcheck - ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key) + ctx = context.WithValue(ctx, nbcontext.PeerIDKey, peer.Key) if peer.UserID == "" { // we do not want to expire peers that are added via setup key @@ -1214,7 +1215,7 @@ func validateUserInvite(invite *types.UserInfo) error { } // GetCurrentUserInfo retrieves the account's current user info and permissions -func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { +func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { accountID, userID := userAuth.AccountId, userAuth.UserId user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) diff --git a/management/server/user_test.go b/management/server/user_test.go index 69b8c85ee..5ce15621e 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -11,12 +11,12 @@ import ( "golang.org/x/exp/maps" nbcache "github.com/netbirdio/netbird/management/server/cache" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -966,7 +966,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { permissionsManager: permissionsManager, } - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: mockUserID, AccountId: mockAccountID, } @@ -1573,33 +1573,33 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { tt := []struct { name string - userAuth nbcontext.UserAuth + userAuth auth.UserAuth expectedErr error expectedResult *users.UserInfoWithPermissions }{ { name: "not found", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "not-found"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "not-found"}, expectedErr: status.NewUserNotFoundError("not-found"), }, { name: "not part of account", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account2Owner"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account2Owner"}, expectedErr: status.NewUserNotPartOfAccountError(), }, { name: "blocked", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "blocked-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "blocked-user"}, expectedErr: status.NewUserBlockedError(), }, { name: "service user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "service-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "service-user"}, expectedErr: status.NewPermissionDeniedError(), }, { name: "owner user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account1Owner"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account1Owner"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "account1Owner", @@ -1619,7 +1619,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "regular user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "regular-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "regular-user"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "regular-user", @@ -1638,7 +1638,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "admin user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "admin-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "admin-user"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "admin-user", @@ -1657,7 +1657,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "settings blocked regular user", - userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"}, + userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "settings-blocked-user", @@ -1678,7 +1678,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { { name: "settings blocked regular user child account", - userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true}, + userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "settings-blocked-user", @@ -1698,7 +1698,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "settings blocked owner user", - userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "account2Owner"}, + userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "account2Owner"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "account2Owner", diff --git a/relay/server/peer.go b/relay/server/peer.go index c47f2e960..c5ff41857 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -9,10 +9,10 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/shared/relay/healthcheck" - "github.com/netbirdio/netbird/shared/relay/messages" "github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/server/store" + "github.com/netbirdio/netbird/shared/relay/healthcheck" + "github.com/netbirdio/netbird/shared/relay/messages" ) const ( diff --git a/management/server/auth/jwt/extractor.go b/shared/auth/jwt/extractor.go similarity index 92% rename from management/server/auth/jwt/extractor.go rename to shared/auth/jwt/extractor.go index d270d0ff1..a41d5f07a 100644 --- a/management/server/auth/jwt/extractor.go +++ b/shared/auth/jwt/extractor.go @@ -8,7 +8,7 @@ import ( "github.com/golang-jwt/jwt/v5" log "github.com/sirupsen/logrus" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" ) const ( @@ -87,9 +87,10 @@ func (c ClaimsExtractor) audienceClaim(claimName string) string { return url } -func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) { +// ToUserAuth extracts user authentication information from a JWT token +func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) { claims := token.Claims.(jwt.MapClaims) - userAuth := nbcontext.UserAuth{} + userAuth := auth.UserAuth{} userID, ok := claims[c.userIDClaim].(string) if !ok { @@ -122,6 +123,7 @@ func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, erro return userAuth, nil } +// ToGroups extracts group information from a JWT token func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string { claims := token.Claims.(jwt.MapClaims) userJWTGroups := make([]string, 0) diff --git a/management/server/auth/jwt/validator.go b/shared/auth/jwt/validator.go similarity index 100% rename from management/server/auth/jwt/validator.go rename to shared/auth/jwt/validator.go diff --git a/shared/auth/user.go b/shared/auth/user.go new file mode 100644 index 000000000..c1bae808e --- /dev/null +++ b/shared/auth/user.go @@ -0,0 +1,28 @@ +package auth + +import ( + "time" +) + +type UserAuth struct { + // The account id the user is accessing + AccountId string + // The account domain + Domain string + // The account domain category, TBC values + DomainCategory string + // Indicates whether this user was invited, TBC logic + Invited bool + // Indicates whether this is a child account + IsChild bool + + // The user id + UserId string + // Last login time for this user + LastLogin time.Time + // The Groups the user belongs to on this account + Groups []string + + // Indicates whether this user has authenticated with a Personal Access Token + IsPAT bool +} diff --git a/shared/context/keys.go b/shared/context/keys.go index 5345ee214..c5b5da044 100644 --- a/shared/context/keys.go +++ b/shared/context/keys.go @@ -5,4 +5,4 @@ const ( AccountIDKey = "accountID" UserIDKey = "userID" PeerIDKey = "peerID" -) \ No newline at end of file +) diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index d3f341529..f98e76ce7 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -117,7 +118,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) - accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } diff --git a/shared/management/operations/operation.go b/shared/management/operations/operation.go index b9b500362..b1ba12815 100644 --- a/shared/management/operations/operation.go +++ b/shared/management/operations/operation.go @@ -1,4 +1,4 @@ package operations // Operation represents a permission operation type -type Operation string \ No newline at end of file +type Operation string diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 0de00ec0c..ca12cf48c 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,19 +1,18 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.32.0 +// protoc v6.32.1 // source: management.proto package proto import ( - reflect "reflect" - sync "sync" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" durationpb "google.golang.org/protobuf/types/known/durationpb" timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" ) const ( @@ -268,7 +267,7 @@ func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { // Deprecated: Use DeviceAuthorizationFlowProvider.Descriptor instead. func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{23, 0} + return file_management_proto_rawDescGZIP(), []int{24, 0} } type EncryptedMessage struct { @@ -799,16 +798,21 @@ type Flags struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - RosenpassEnabled bool `protobuf:"varint,1,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - RosenpassPermissive bool `protobuf:"varint,2,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` - ServerSSHAllowed bool `protobuf:"varint,3,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` - DisableClientRoutes bool `protobuf:"varint,4,opt,name=disableClientRoutes,proto3" json:"disableClientRoutes,omitempty"` - DisableServerRoutes bool `protobuf:"varint,5,opt,name=disableServerRoutes,proto3" json:"disableServerRoutes,omitempty"` - DisableDNS bool `protobuf:"varint,6,opt,name=disableDNS,proto3" json:"disableDNS,omitempty"` - DisableFirewall bool `protobuf:"varint,7,opt,name=disableFirewall,proto3" json:"disableFirewall,omitempty"` - BlockLANAccess bool `protobuf:"varint,8,opt,name=blockLANAccess,proto3" json:"blockLANAccess,omitempty"` - BlockInbound bool `protobuf:"varint,9,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"` - LazyConnectionEnabled bool `protobuf:"varint,10,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` + RosenpassEnabled bool `protobuf:"varint,1,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` + RosenpassPermissive bool `protobuf:"varint,2,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` + ServerSSHAllowed bool `protobuf:"varint,3,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` + DisableClientRoutes bool `protobuf:"varint,4,opt,name=disableClientRoutes,proto3" json:"disableClientRoutes,omitempty"` + DisableServerRoutes bool `protobuf:"varint,5,opt,name=disableServerRoutes,proto3" json:"disableServerRoutes,omitempty"` + DisableDNS bool `protobuf:"varint,6,opt,name=disableDNS,proto3" json:"disableDNS,omitempty"` + DisableFirewall bool `protobuf:"varint,7,opt,name=disableFirewall,proto3" json:"disableFirewall,omitempty"` + BlockLANAccess bool `protobuf:"varint,8,opt,name=blockLANAccess,proto3" json:"blockLANAccess,omitempty"` + BlockInbound bool `protobuf:"varint,9,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"` + LazyConnectionEnabled bool `protobuf:"varint,10,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"` + EnableSSHRoot bool `protobuf:"varint,11,opt,name=enableSSHRoot,proto3" json:"enableSSHRoot,omitempty"` + EnableSSHSFTP bool `protobuf:"varint,12,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"` + EnableSSHLocalPortForwarding bool `protobuf:"varint,13,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"` + EnableSSHRemotePortForwarding bool `protobuf:"varint,14,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"` + DisableSSHAuth bool `protobuf:"varint,15,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"` } func (x *Flags) Reset() { @@ -913,6 +917,41 @@ func (x *Flags) GetLazyConnectionEnabled() bool { return false } +func (x *Flags) GetEnableSSHRoot() bool { + if x != nil { + return x.EnableSSHRoot + } + return false +} + +func (x *Flags) GetEnableSSHSFTP() bool { + if x != nil { + return x.EnableSSHSFTP + } + return false +} + +func (x *Flags) GetEnableSSHLocalPortForwarding() bool { + if x != nil { + return x.EnableSSHLocalPortForwarding + } + return false +} + +func (x *Flags) GetEnableSSHRemotePortForwarding() bool { + if x != nil { + return x.EnableSSHRemotePortForwarding + } + return false +} + +func (x *Flags) GetDisableSSHAuth() bool { + if x != nil { + return x.DisableSSHAuth + } + return false +} + // PeerSystemMeta is machine meta data like OS and version. type PeerSystemMeta struct { state protoimpl.MessageState @@ -1273,6 +1312,7 @@ type NetbirdConfig struct { Signal *HostConfig `protobuf:"bytes,3,opt,name=signal,proto3" json:"signal,omitempty"` Relay *RelayConfig `protobuf:"bytes,4,opt,name=relay,proto3" json:"relay,omitempty"` Flow *FlowConfig `protobuf:"bytes,5,opt,name=flow,proto3" json:"flow,omitempty"` + Jwt *JWTConfig `protobuf:"bytes,6,opt,name=jwt,proto3" json:"jwt,omitempty"` } func (x *NetbirdConfig) Reset() { @@ -1342,6 +1382,13 @@ func (x *NetbirdConfig) GetFlow() *FlowConfig { return nil } +func (x *NetbirdConfig) GetJwt() *JWTConfig { + if x != nil { + return x.Jwt + } + return nil +} + // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) type HostConfig struct { state protoimpl.MessageState @@ -1568,6 +1615,78 @@ func (x *FlowConfig) GetDnsCollection() bool { return false } +// JWTConfig represents JWT authentication configuration +type JWTConfig struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Issuer string `protobuf:"bytes,1,opt,name=issuer,proto3" json:"issuer,omitempty"` + Audience string `protobuf:"bytes,2,opt,name=audience,proto3" json:"audience,omitempty"` + KeysLocation string `protobuf:"bytes,3,opt,name=keysLocation,proto3" json:"keysLocation,omitempty"` + MaxTokenAge int64 `protobuf:"varint,4,opt,name=maxTokenAge,proto3" json:"maxTokenAge,omitempty"` +} + +func (x *JWTConfig) Reset() { + *x = JWTConfig{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *JWTConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*JWTConfig) ProtoMessage() {} + +func (x *JWTConfig) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[17] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use JWTConfig.ProtoReflect.Descriptor instead. +func (*JWTConfig) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{17} +} + +func (x *JWTConfig) GetIssuer() string { + if x != nil { + return x.Issuer + } + return "" +} + +func (x *JWTConfig) GetAudience() string { + if x != nil { + return x.Audience + } + return "" +} + +func (x *JWTConfig) GetKeysLocation() string { + if x != nil { + return x.KeysLocation + } + return "" +} + +func (x *JWTConfig) GetMaxTokenAge() int64 { + if x != nil { + return x.MaxTokenAge + } + return 0 +} + // ProtectedHostConfig is similar to HostConfig but has additional user and password // Mostly used for TURN servers type ProtectedHostConfig struct { @@ -1583,7 +1702,7 @@ type ProtectedHostConfig struct { func (x *ProtectedHostConfig) Reset() { *x = ProtectedHostConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[17] + mi := &file_management_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1596,7 +1715,7 @@ func (x *ProtectedHostConfig) String() string { func (*ProtectedHostConfig) ProtoMessage() {} func (x *ProtectedHostConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[17] + mi := &file_management_proto_msgTypes[18] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1609,7 +1728,7 @@ func (x *ProtectedHostConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProtectedHostConfig.ProtoReflect.Descriptor instead. func (*ProtectedHostConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{17} + return file_management_proto_rawDescGZIP(), []int{18} } func (x *ProtectedHostConfig) GetHostConfig() *HostConfig { @@ -1656,7 +1775,7 @@ type PeerConfig struct { func (x *PeerConfig) Reset() { *x = PeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[18] + mi := &file_management_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1669,7 +1788,7 @@ func (x *PeerConfig) String() string { func (*PeerConfig) ProtoMessage() {} func (x *PeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[18] + mi := &file_management_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1682,7 +1801,7 @@ func (x *PeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use PeerConfig.ProtoReflect.Descriptor instead. func (*PeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{18} + return file_management_proto_rawDescGZIP(), []int{19} } func (x *PeerConfig) GetAddress() string { @@ -1770,7 +1889,7 @@ type NetworkMap struct { func (x *NetworkMap) Reset() { *x = NetworkMap{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[19] + mi := &file_management_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1783,7 +1902,7 @@ func (x *NetworkMap) String() string { func (*NetworkMap) ProtoMessage() {} func (x *NetworkMap) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[19] + mi := &file_management_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1796,7 +1915,7 @@ func (x *NetworkMap) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkMap.ProtoReflect.Descriptor instead. func (*NetworkMap) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{19} + return file_management_proto_rawDescGZIP(), []int{20} } func (x *NetworkMap) GetSerial() uint64 { @@ -1904,7 +2023,7 @@ type RemotePeerConfig struct { func (x *RemotePeerConfig) Reset() { *x = RemotePeerConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[20] + mi := &file_management_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1917,7 +2036,7 @@ func (x *RemotePeerConfig) String() string { func (*RemotePeerConfig) ProtoMessage() {} func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[20] + mi := &file_management_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1930,7 +2049,7 @@ func (x *RemotePeerConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use RemotePeerConfig.ProtoReflect.Descriptor instead. func (*RemotePeerConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{20} + return file_management_proto_rawDescGZIP(), []int{21} } func (x *RemotePeerConfig) GetWgPubKey() string { @@ -1978,13 +2097,14 @@ type SSHConfig struct { SshEnabled bool `protobuf:"varint,1,opt,name=sshEnabled,proto3" json:"sshEnabled,omitempty"` // sshPubKey is a SSH public key of a peer to be added to authorized_hosts. // This property should be ignore if SSHConfig comes from PeerConfig. - SshPubKey []byte `protobuf:"bytes,2,opt,name=sshPubKey,proto3" json:"sshPubKey,omitempty"` + SshPubKey []byte `protobuf:"bytes,2,opt,name=sshPubKey,proto3" json:"sshPubKey,omitempty"` + JwtConfig *JWTConfig `protobuf:"bytes,3,opt,name=jwtConfig,proto3" json:"jwtConfig,omitempty"` } func (x *SSHConfig) Reset() { *x = SSHConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[21] + mi := &file_management_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1997,7 +2117,7 @@ func (x *SSHConfig) String() string { func (*SSHConfig) ProtoMessage() {} func (x *SSHConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[21] + mi := &file_management_proto_msgTypes[22] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2010,7 +2130,7 @@ func (x *SSHConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHConfig.ProtoReflect.Descriptor instead. func (*SSHConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{21} + return file_management_proto_rawDescGZIP(), []int{22} } func (x *SSHConfig) GetSshEnabled() bool { @@ -2027,6 +2147,13 @@ func (x *SSHConfig) GetSshPubKey() []byte { return nil } +func (x *SSHConfig) GetJwtConfig() *JWTConfig { + if x != nil { + return x.JwtConfig + } + return nil +} + // DeviceAuthorizationFlowRequest empty struct for future expansion type DeviceAuthorizationFlowRequest struct { state protoimpl.MessageState @@ -2037,7 +2164,7 @@ type DeviceAuthorizationFlowRequest struct { func (x *DeviceAuthorizationFlowRequest) Reset() { *x = DeviceAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2050,7 +2177,7 @@ func (x *DeviceAuthorizationFlowRequest) String() string { func (*DeviceAuthorizationFlowRequest) ProtoMessage() {} func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[22] + mi := &file_management_proto_msgTypes[23] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2063,7 +2190,7 @@ func (x *DeviceAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{22} + return file_management_proto_rawDescGZIP(), []int{23} } // DeviceAuthorizationFlow represents Device Authorization Flow information @@ -2082,7 +2209,7 @@ type DeviceAuthorizationFlow struct { func (x *DeviceAuthorizationFlow) Reset() { *x = DeviceAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2095,7 +2222,7 @@ func (x *DeviceAuthorizationFlow) String() string { func (*DeviceAuthorizationFlow) ProtoMessage() {} func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[23] + mi := &file_management_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2108,7 +2235,7 @@ func (x *DeviceAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use DeviceAuthorizationFlow.ProtoReflect.Descriptor instead. func (*DeviceAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{23} + return file_management_proto_rawDescGZIP(), []int{24} } func (x *DeviceAuthorizationFlow) GetProvider() DeviceAuthorizationFlowProvider { @@ -2135,7 +2262,7 @@ type PKCEAuthorizationFlowRequest struct { func (x *PKCEAuthorizationFlowRequest) Reset() { *x = PKCEAuthorizationFlowRequest{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2148,7 +2275,7 @@ func (x *PKCEAuthorizationFlowRequest) String() string { func (*PKCEAuthorizationFlowRequest) ProtoMessage() {} func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[24] + mi := &file_management_proto_msgTypes[25] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2161,7 +2288,7 @@ func (x *PKCEAuthorizationFlowRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlowRequest.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlowRequest) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{24} + return file_management_proto_rawDescGZIP(), []int{25} } // PKCEAuthorizationFlow represents Authorization Code Flow information @@ -2178,7 +2305,7 @@ type PKCEAuthorizationFlow struct { func (x *PKCEAuthorizationFlow) Reset() { *x = PKCEAuthorizationFlow{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2191,7 +2318,7 @@ func (x *PKCEAuthorizationFlow) String() string { func (*PKCEAuthorizationFlow) ProtoMessage() {} func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[25] + mi := &file_management_proto_msgTypes[26] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2204,7 +2331,7 @@ func (x *PKCEAuthorizationFlow) ProtoReflect() protoreflect.Message { // Deprecated: Use PKCEAuthorizationFlow.ProtoReflect.Descriptor instead. func (*PKCEAuthorizationFlow) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{25} + return file_management_proto_rawDescGZIP(), []int{26} } func (x *PKCEAuthorizationFlow) GetProviderConfig() *ProviderConfig { @@ -2250,7 +2377,7 @@ type ProviderConfig struct { func (x *ProviderConfig) Reset() { *x = ProviderConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2263,7 +2390,7 @@ func (x *ProviderConfig) String() string { func (*ProviderConfig) ProtoMessage() {} func (x *ProviderConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[26] + mi := &file_management_proto_msgTypes[27] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2276,7 +2403,7 @@ func (x *ProviderConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use ProviderConfig.ProtoReflect.Descriptor instead. func (*ProviderConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{26} + return file_management_proto_rawDescGZIP(), []int{27} } func (x *ProviderConfig) GetClientID() string { @@ -2384,7 +2511,7 @@ type Route struct { func (x *Route) Reset() { *x = Route{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2397,7 +2524,7 @@ func (x *Route) String() string { func (*Route) ProtoMessage() {} func (x *Route) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[27] + mi := &file_management_proto_msgTypes[28] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2410,7 +2537,7 @@ func (x *Route) ProtoReflect() protoreflect.Message { // Deprecated: Use Route.ProtoReflect.Descriptor instead. func (*Route) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{27} + return file_management_proto_rawDescGZIP(), []int{28} } func (x *Route) GetID() string { @@ -2498,7 +2625,7 @@ type DNSConfig struct { func (x *DNSConfig) Reset() { *x = DNSConfig{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2511,7 +2638,7 @@ func (x *DNSConfig) String() string { func (*DNSConfig) ProtoMessage() {} func (x *DNSConfig) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[28] + mi := &file_management_proto_msgTypes[29] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2524,7 +2651,7 @@ func (x *DNSConfig) ProtoReflect() protoreflect.Message { // Deprecated: Use DNSConfig.ProtoReflect.Descriptor instead. func (*DNSConfig) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{28} + return file_management_proto_rawDescGZIP(), []int{29} } func (x *DNSConfig) GetServiceEnable() bool { @@ -2568,7 +2695,7 @@ type CustomZone struct { func (x *CustomZone) Reset() { *x = CustomZone{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2581,7 +2708,7 @@ func (x *CustomZone) String() string { func (*CustomZone) ProtoMessage() {} func (x *CustomZone) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[29] + mi := &file_management_proto_msgTypes[30] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2594,7 +2721,7 @@ func (x *CustomZone) ProtoReflect() protoreflect.Message { // Deprecated: Use CustomZone.ProtoReflect.Descriptor instead. func (*CustomZone) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{29} + return file_management_proto_rawDescGZIP(), []int{30} } func (x *CustomZone) GetDomain() string { @@ -2627,7 +2754,7 @@ type SimpleRecord struct { func (x *SimpleRecord) Reset() { *x = SimpleRecord{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2640,7 +2767,7 @@ func (x *SimpleRecord) String() string { func (*SimpleRecord) ProtoMessage() {} func (x *SimpleRecord) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[30] + mi := &file_management_proto_msgTypes[31] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2653,7 +2780,7 @@ func (x *SimpleRecord) ProtoReflect() protoreflect.Message { // Deprecated: Use SimpleRecord.ProtoReflect.Descriptor instead. func (*SimpleRecord) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{30} + return file_management_proto_rawDescGZIP(), []int{31} } func (x *SimpleRecord) GetName() string { @@ -2706,7 +2833,7 @@ type NameServerGroup struct { func (x *NameServerGroup) Reset() { *x = NameServerGroup{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2719,7 +2846,7 @@ func (x *NameServerGroup) String() string { func (*NameServerGroup) ProtoMessage() {} func (x *NameServerGroup) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[31] + mi := &file_management_proto_msgTypes[32] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2732,7 +2859,7 @@ func (x *NameServerGroup) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServerGroup.ProtoReflect.Descriptor instead. func (*NameServerGroup) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31} + return file_management_proto_rawDescGZIP(), []int{32} } func (x *NameServerGroup) GetNameServers() []*NameServer { @@ -2777,7 +2904,7 @@ type NameServer struct { func (x *NameServer) Reset() { *x = NameServer{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2790,7 +2917,7 @@ func (x *NameServer) String() string { func (*NameServer) ProtoMessage() {} func (x *NameServer) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[32] + mi := &file_management_proto_msgTypes[33] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2803,7 +2930,7 @@ func (x *NameServer) ProtoReflect() protoreflect.Message { // Deprecated: Use NameServer.ProtoReflect.Descriptor instead. func (*NameServer) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{32} + return file_management_proto_rawDescGZIP(), []int{33} } func (x *NameServer) GetIP() string { @@ -2846,7 +2973,7 @@ type FirewallRule struct { func (x *FirewallRule) Reset() { *x = FirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2859,7 +2986,7 @@ func (x *FirewallRule) String() string { func (*FirewallRule) ProtoMessage() {} func (x *FirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[33] + mi := &file_management_proto_msgTypes[34] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2872,7 +2999,7 @@ func (x *FirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use FirewallRule.ProtoReflect.Descriptor instead. func (*FirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{33} + return file_management_proto_rawDescGZIP(), []int{34} } func (x *FirewallRule) GetPeerIP() string { @@ -2936,7 +3063,7 @@ type NetworkAddress struct { func (x *NetworkAddress) Reset() { *x = NetworkAddress{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2949,7 +3076,7 @@ func (x *NetworkAddress) String() string { func (*NetworkAddress) ProtoMessage() {} func (x *NetworkAddress) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[34] + mi := &file_management_proto_msgTypes[35] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2962,7 +3089,7 @@ func (x *NetworkAddress) ProtoReflect() protoreflect.Message { // Deprecated: Use NetworkAddress.ProtoReflect.Descriptor instead. func (*NetworkAddress) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{34} + return file_management_proto_rawDescGZIP(), []int{35} } func (x *NetworkAddress) GetNetIP() string { @@ -2990,7 +3117,7 @@ type Checks struct { func (x *Checks) Reset() { *x = Checks{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3003,7 +3130,7 @@ func (x *Checks) String() string { func (*Checks) ProtoMessage() {} func (x *Checks) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[35] + mi := &file_management_proto_msgTypes[36] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3016,7 +3143,7 @@ func (x *Checks) ProtoReflect() protoreflect.Message { // Deprecated: Use Checks.ProtoReflect.Descriptor instead. func (*Checks) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{35} + return file_management_proto_rawDescGZIP(), []int{36} } func (x *Checks) GetFiles() []string { @@ -3041,7 +3168,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3054,7 +3181,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[36] + mi := &file_management_proto_msgTypes[37] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3067,7 +3194,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{36} + return file_management_proto_rawDescGZIP(), []int{37} } func (m *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -3138,7 +3265,7 @@ type RouteFirewallRule struct { func (x *RouteFirewallRule) Reset() { *x = RouteFirewallRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3151,7 +3278,7 @@ func (x *RouteFirewallRule) String() string { func (*RouteFirewallRule) ProtoMessage() {} func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[37] + mi := &file_management_proto_msgTypes[38] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3164,7 +3291,7 @@ func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { // Deprecated: Use RouteFirewallRule.ProtoReflect.Descriptor instead. func (*RouteFirewallRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{37} + return file_management_proto_rawDescGZIP(), []int{38} } func (x *RouteFirewallRule) GetSourceRanges() []string { @@ -3255,7 +3382,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3268,7 +3395,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[38] + mi := &file_management_proto_msgTypes[39] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3281,7 +3408,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{38} + return file_management_proto_rawDescGZIP(), []int{39} } func (x *ForwardingRule) GetProtocol() RuleProtocol { @@ -3324,7 +3451,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} if protoimpl.UnsafeEnabled { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3337,7 +3464,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_management_proto_msgTypes[39] + mi := &file_management_proto_msgTypes[40] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3350,7 +3477,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{36, 0} + return file_management_proto_rawDescGZIP(), []int{37, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -3438,7 +3565,7 @@ var file_management_proto_rawDesc = []byte{ 0x73, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x78, 0x69, 0x73, 0x74, 0x12, 0x2a, 0x0a, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, - 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xc1, 0x03, 0x0a, 0x05, + 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xbf, 0x05, 0x0a, 0x05, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, @@ -3466,435 +3593,465 @@ var file_management_proto_rawDesc = []byte{ 0x63, 0x6b, 0x49, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x6c, 0x61, 0x7a, 0x79, 0x43, 0x6f, - 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, - 0xf2, 0x04, 0x0a, 0x0e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, - 0x74, 0x61, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, - 0x0a, 0x04, 0x67, 0x6f, 0x4f, 0x53, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, - 0x4f, 0x53, 0x12, 0x16, 0x0a, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, - 0x72, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, - 0x0a, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x4f, 0x53, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65, - 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, - 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x12, 0x24, 0x0a, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, - 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, - 0x69, 0x6f, 0x6e, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x46, 0x0a, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0f, - 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, - 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, - 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, - 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, - 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x28, - 0x0a, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, - 0x72, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, - 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, 0x76, 0x69, - 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, 0x69, 0x72, - 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, - 0x65, 0x6e, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x46, 0x69, 0x6c, 0x65, 0x52, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x05, 0x66, - 0x6c, 0x61, 0x67, 0x73, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x52, 0x05, 0x66, - 0x6c, 0x61, 0x67, 0x73, 0x22, 0xb4, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, - 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, - 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, - 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x2a, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x12, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, - 0x63, 0x6b, 0x73, 0x52, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x79, 0x0a, 0x11, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, - 0x65, 0x79, 0x12, 0x38, 0x0a, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, - 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, - 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, - 0xff, 0x01, 0x0a, 0x0d, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x2c, 0x0a, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, - 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, - 0x35, 0x0a, 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, - 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, - 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, - 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x12, 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, - 0x72, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, - 0x77, 0x22, 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, - 0x72, 0x69, 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, - 0x3b, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, - 0x44, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, - 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, - 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, - 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, - 0x72, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, - 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, - 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, - 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, - 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, - 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, - 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, - 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, - 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, - 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, - 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, - 0x72, 0x76, 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, - 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, - 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, - 0x6e, 0x74, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, - 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, - 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x7d, 0x0a, 0x13, 0x50, - 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, - 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, - 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, - 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x93, 0x02, 0x0a, 0x0a, 0x50, - 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, - 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, - 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, - 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, - 0x0a, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, - 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, - 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, - 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, - 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, - 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, - 0x0a, 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, - 0x22, 0xb9, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, - 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, - 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, + 0x24, 0x0a, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x52, 0x6f, 0x6f, 0x74, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, + 0x48, 0x52, 0x6f, 0x6f, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, + 0x53, 0x48, 0x53, 0x46, 0x54, 0x50, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x65, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x53, 0x46, 0x54, 0x50, 0x12, 0x42, 0x0a, 0x1c, 0x65, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x6f, 0x72, + 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x18, 0x0d, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x1c, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x4c, 0x6f, 0x63, 0x61, + 0x6c, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x12, + 0x44, 0x0a, 0x1d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x52, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, + 0x18, 0x0e, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, + 0x48, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, + 0x72, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x26, 0x0a, 0x0e, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, + 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x64, + 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x22, 0xf2, 0x04, + 0x0a, 0x0e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, + 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, + 0x67, 0x6f, 0x4f, 0x53, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, 0x4f, 0x53, + 0x12, 0x16, 0x0a, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x72, 0x65, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, + 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x4f, 0x53, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x62, + 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x12, 0x1c, 0x0a, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x24, + 0x0a, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x12, 0x46, 0x0a, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, + 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x0c, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, + 0x6d, 0x62, 0x65, 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, + 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x79, + 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x28, 0x0a, 0x0f, + 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x18, + 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, + 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, + 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, + 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, + 0x74, 0x12, 0x26, 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x10, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, + 0x6c, 0x65, 0x52, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x05, 0x66, 0x6c, 0x61, + 0x67, 0x73, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x52, 0x05, 0x66, 0x6c, 0x61, + 0x67, 0x73, 0x22, 0xb4, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2a, 0x0a, + 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, + 0x73, 0x52, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x79, 0x0a, 0x11, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, + 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, + 0x12, 0x38, 0x0a, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, + 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xa8, 0x02, + 0x0a, 0x0d, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x2c, 0x0a, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, 0x35, 0x0a, + 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, + 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x74, + 0x75, 0x72, 0x6e, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, 0x73, 0x69, + 0x67, 0x6e, 0x61, 0x6c, 0x12, 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65, + 0x6c, 0x61, 0x79, 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, + 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x12, + 0x27, 0x0a, 0x03, 0x6a, 0x77, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x03, 0x6a, 0x77, 0x74, 0x22, 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, + 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, + 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, + 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, + 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, + 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, + 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, + 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, + 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, + 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, + 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, + 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, + 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, + 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x22, 0x85, 0x01, 0x0a, 0x09, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x75, 0x64, 0x69, + 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, 0x69, + 0x65, 0x6e, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x6b, 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x6b, 0x65, 0x79, 0x73, + 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x6d, 0x61, 0x78, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x6d, + 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x22, 0x7d, 0x0a, 0x13, 0x50, 0x72, + 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x68, + 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, + 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, + 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x93, 0x02, 0x0a, 0x0a, 0x50, 0x65, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, + 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, + 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, + 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, + 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, + 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, + 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, + 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, 0x0a, + 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x22, + 0xb9, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, + 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, + 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, + 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, + 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, + 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, + 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, - 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, - 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, - 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, - 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, - 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, - 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, - 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, - 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, - 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, - 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, - 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, - 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, - 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, - 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, - 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, - 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, - 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, - 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, - 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, - 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, - 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, - 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, - 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, - 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, - 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, - 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, - 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, - 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, - 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, - 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, - 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, - 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, - 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, - 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, - 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, - 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, - 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, - 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, - 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xda, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, - 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, - 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x24, - 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, - 0x50, 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, - 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, - 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, - 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, - 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, - 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, - 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, - 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, - 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, - 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, - 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, - 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, - 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, - 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, - 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, - 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, - 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, - 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, - 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, - 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, - 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, - 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, - 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, - 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, - 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, - 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, - 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, - 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, - 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, - 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, - 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, - 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, - 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, - 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, - 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, - 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, - 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, - 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, - 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, - 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, - 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, - 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, - 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, - 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, - 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, - 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, - 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, - 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, - 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, - 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, + 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, + 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, + 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, + 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, + 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, + 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, + 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, + 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, + 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, + 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, + 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, + 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, + 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, + 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, + 0x62, 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, + 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, + 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, + 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, + 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, + 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, + 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, + 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, + 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, + 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, + 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, + 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, + 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, + 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, + 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, + 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, + 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, + 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, + 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, + 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, + 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, + 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, + 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x46, 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, + 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, + 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, + 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, + 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, + 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, + 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, + 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, + 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, + 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, + 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, + 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xda, 0x01, 0x0a, 0x09, + 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, + 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, + 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, + 0x65, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, + 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, + 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, + 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, + 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, + 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, + 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, + 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, + 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, + 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, + 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, + 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, + 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, + 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, + 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, + 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, + 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, + 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, + 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, + 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, + 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, + 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, + 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, + 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, + 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, + 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, + 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, + 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, + 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, + 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, + 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, + 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, + 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, + 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, + 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, + 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, + 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, + 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, + 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, + 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, + 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, + 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, + 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, + 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, + 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, + 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, - 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, - 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, + 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, + 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, - 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, - 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, + 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, + 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -3910,7 +4067,7 @@ func file_management_proto_rawDescGZIP() []byte { } var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 40) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 41) var file_management_proto_goTypes = []interface{}{ (RuleProtocol)(0), // 0: management.RuleProtocol (RuleDirection)(0), // 1: management.RuleDirection @@ -3934,107 +4091,110 @@ var file_management_proto_goTypes = []interface{}{ (*HostConfig)(nil), // 19: management.HostConfig (*RelayConfig)(nil), // 20: management.RelayConfig (*FlowConfig)(nil), // 21: management.FlowConfig - (*ProtectedHostConfig)(nil), // 22: management.ProtectedHostConfig - (*PeerConfig)(nil), // 23: management.PeerConfig - (*NetworkMap)(nil), // 24: management.NetworkMap - (*RemotePeerConfig)(nil), // 25: management.RemotePeerConfig - (*SSHConfig)(nil), // 26: management.SSHConfig - (*DeviceAuthorizationFlowRequest)(nil), // 27: management.DeviceAuthorizationFlowRequest - (*DeviceAuthorizationFlow)(nil), // 28: management.DeviceAuthorizationFlow - (*PKCEAuthorizationFlowRequest)(nil), // 29: management.PKCEAuthorizationFlowRequest - (*PKCEAuthorizationFlow)(nil), // 30: management.PKCEAuthorizationFlow - (*ProviderConfig)(nil), // 31: management.ProviderConfig - (*Route)(nil), // 32: management.Route - (*DNSConfig)(nil), // 33: management.DNSConfig - (*CustomZone)(nil), // 34: management.CustomZone - (*SimpleRecord)(nil), // 35: management.SimpleRecord - (*NameServerGroup)(nil), // 36: management.NameServerGroup - (*NameServer)(nil), // 37: management.NameServer - (*FirewallRule)(nil), // 38: management.FirewallRule - (*NetworkAddress)(nil), // 39: management.NetworkAddress - (*Checks)(nil), // 40: management.Checks - (*PortInfo)(nil), // 41: management.PortInfo - (*RouteFirewallRule)(nil), // 42: management.RouteFirewallRule - (*ForwardingRule)(nil), // 43: management.ForwardingRule - (*PortInfo_Range)(nil), // 44: management.PortInfo.Range - (*timestamppb.Timestamp)(nil), // 45: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 46: google.protobuf.Duration + (*JWTConfig)(nil), // 22: management.JWTConfig + (*ProtectedHostConfig)(nil), // 23: management.ProtectedHostConfig + (*PeerConfig)(nil), // 24: management.PeerConfig + (*NetworkMap)(nil), // 25: management.NetworkMap + (*RemotePeerConfig)(nil), // 26: management.RemotePeerConfig + (*SSHConfig)(nil), // 27: management.SSHConfig + (*DeviceAuthorizationFlowRequest)(nil), // 28: management.DeviceAuthorizationFlowRequest + (*DeviceAuthorizationFlow)(nil), // 29: management.DeviceAuthorizationFlow + (*PKCEAuthorizationFlowRequest)(nil), // 30: management.PKCEAuthorizationFlowRequest + (*PKCEAuthorizationFlow)(nil), // 31: management.PKCEAuthorizationFlow + (*ProviderConfig)(nil), // 32: management.ProviderConfig + (*Route)(nil), // 33: management.Route + (*DNSConfig)(nil), // 34: management.DNSConfig + (*CustomZone)(nil), // 35: management.CustomZone + (*SimpleRecord)(nil), // 36: management.SimpleRecord + (*NameServerGroup)(nil), // 37: management.NameServerGroup + (*NameServer)(nil), // 38: management.NameServer + (*FirewallRule)(nil), // 39: management.FirewallRule + (*NetworkAddress)(nil), // 40: management.NetworkAddress + (*Checks)(nil), // 41: management.Checks + (*PortInfo)(nil), // 42: management.PortInfo + (*RouteFirewallRule)(nil), // 43: management.RouteFirewallRule + (*ForwardingRule)(nil), // 44: management.ForwardingRule + (*PortInfo_Range)(nil), // 45: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 46: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 47: google.protobuf.Duration } var file_management_proto_depIdxs = []int32{ 14, // 0: management.SyncRequest.meta:type_name -> management.PeerSystemMeta 18, // 1: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig - 23, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig - 25, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig - 24, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap - 40, // 5: management.SyncResponse.Checks:type_name -> management.Checks + 24, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig + 26, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig + 25, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap + 41, // 5: management.SyncResponse.Checks:type_name -> management.Checks 14, // 6: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta 14, // 7: management.LoginRequest.meta:type_name -> management.PeerSystemMeta 10, // 8: management.LoginRequest.peerKeys:type_name -> management.PeerKeys - 39, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress + 40, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress 11, // 10: management.PeerSystemMeta.environment:type_name -> management.Environment 12, // 11: management.PeerSystemMeta.files:type_name -> management.File 13, // 12: management.PeerSystemMeta.flags:type_name -> management.Flags 18, // 13: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig - 23, // 14: management.LoginResponse.peerConfig:type_name -> management.PeerConfig - 40, // 15: management.LoginResponse.Checks:type_name -> management.Checks - 45, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 24, // 14: management.LoginResponse.peerConfig:type_name -> management.PeerConfig + 41, // 15: management.LoginResponse.Checks:type_name -> management.Checks + 46, // 16: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 19, // 17: management.NetbirdConfig.stuns:type_name -> management.HostConfig - 22, // 18: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig + 23, // 18: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig 19, // 19: management.NetbirdConfig.signal:type_name -> management.HostConfig 20, // 20: management.NetbirdConfig.relay:type_name -> management.RelayConfig 21, // 21: management.NetbirdConfig.flow:type_name -> management.FlowConfig - 3, // 22: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 46, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration - 19, // 24: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 26, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 23, // 26: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 25, // 27: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 32, // 28: management.NetworkMap.Routes:type_name -> management.Route - 33, // 29: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 25, // 30: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 38, // 31: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 42, // 32: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule - 43, // 33: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule - 26, // 34: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 4, // 35: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 31, // 36: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 31, // 37: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 36, // 38: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 34, // 39: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 35, // 40: management.CustomZone.Records:type_name -> management.SimpleRecord - 37, // 41: management.NameServerGroup.NameServers:type_name -> management.NameServer - 1, // 42: management.FirewallRule.Direction:type_name -> management.RuleDirection - 2, // 43: management.FirewallRule.Action:type_name -> management.RuleAction - 0, // 44: management.FirewallRule.Protocol:type_name -> management.RuleProtocol - 41, // 45: management.FirewallRule.PortInfo:type_name -> management.PortInfo - 44, // 46: management.PortInfo.range:type_name -> management.PortInfo.Range - 2, // 47: management.RouteFirewallRule.action:type_name -> management.RuleAction - 0, // 48: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol - 41, // 49: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo - 0, // 50: management.ForwardingRule.protocol:type_name -> management.RuleProtocol - 41, // 51: management.ForwardingRule.destinationPort:type_name -> management.PortInfo - 41, // 52: management.ForwardingRule.translatedPort:type_name -> management.PortInfo - 5, // 53: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 54: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 17, // 55: management.ManagementService.GetServerKey:input_type -> management.Empty - 17, // 56: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 57: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 58: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 59: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 60: management.ManagementService.Logout:input_type -> management.EncryptedMessage - 5, // 61: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 62: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 16, // 63: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 17, // 64: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 65: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 66: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 17, // 67: management.ManagementService.SyncMeta:output_type -> management.Empty - 17, // 68: management.ManagementService.Logout:output_type -> management.Empty - 61, // [61:69] is the sub-list for method output_type - 53, // [53:61] is the sub-list for method input_type - 53, // [53:53] is the sub-list for extension type_name - 53, // [53:53] is the sub-list for extension extendee - 0, // [0:53] is the sub-list for field type_name + 22, // 22: management.NetbirdConfig.jwt:type_name -> management.JWTConfig + 3, // 23: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 47, // 24: management.FlowConfig.interval:type_name -> google.protobuf.Duration + 19, // 25: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig + 27, // 26: management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 24, // 27: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 26, // 28: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 33, // 29: management.NetworkMap.Routes:type_name -> management.Route + 34, // 30: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 26, // 31: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 39, // 32: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 43, // 33: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 44, // 34: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule + 27, // 35: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 22, // 36: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig + 4, // 37: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 32, // 38: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 32, // 39: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 37, // 40: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 35, // 41: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 36, // 42: management.CustomZone.Records:type_name -> management.SimpleRecord + 38, // 43: management.NameServerGroup.NameServers:type_name -> management.NameServer + 1, // 44: management.FirewallRule.Direction:type_name -> management.RuleDirection + 2, // 45: management.FirewallRule.Action:type_name -> management.RuleAction + 0, // 46: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 42, // 47: management.FirewallRule.PortInfo:type_name -> management.PortInfo + 45, // 48: management.PortInfo.range:type_name -> management.PortInfo.Range + 2, // 49: management.RouteFirewallRule.action:type_name -> management.RuleAction + 0, // 50: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 42, // 51: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 0, // 52: management.ForwardingRule.protocol:type_name -> management.RuleProtocol + 42, // 53: management.ForwardingRule.destinationPort:type_name -> management.PortInfo + 42, // 54: management.ForwardingRule.translatedPort:type_name -> management.PortInfo + 5, // 55: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 56: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 17, // 57: management.ManagementService.GetServerKey:input_type -> management.Empty + 17, // 58: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 59: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 60: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 61: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 62: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 5, // 63: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 64: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 16, // 65: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 17, // 66: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 67: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 68: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 17, // 69: management.ManagementService.SyncMeta:output_type -> management.Empty + 17, // 70: management.ManagementService.Logout:output_type -> management.Empty + 63, // [63:71] is the sub-list for method output_type + 55, // [55:63] is the sub-list for method input_type + 55, // [55:55] is the sub-list for extension type_name + 55, // [55:55] is the sub-list for extension extendee + 0, // [0:55] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -4248,7 +4408,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProtectedHostConfig); i { + switch v := v.(*JWTConfig); i { case 0: return &v.state case 1: @@ -4260,7 +4420,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PeerConfig); i { + switch v := v.(*ProtectedHostConfig); i { case 0: return &v.state case 1: @@ -4272,7 +4432,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkMap); i { + switch v := v.(*PeerConfig); i { case 0: return &v.state case 1: @@ -4284,7 +4444,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RemotePeerConfig); i { + switch v := v.(*NetworkMap); i { case 0: return &v.state case 1: @@ -4296,7 +4456,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SSHConfig); i { + switch v := v.(*RemotePeerConfig); i { case 0: return &v.state case 1: @@ -4308,7 +4468,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlowRequest); i { + switch v := v.(*SSHConfig); i { case 0: return &v.state case 1: @@ -4320,7 +4480,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlow); i { + switch v := v.(*DeviceAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -4332,7 +4492,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlowRequest); i { + switch v := v.(*DeviceAuthorizationFlow); i { case 0: return &v.state case 1: @@ -4344,7 +4504,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlow); i { + switch v := v.(*PKCEAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -4356,7 +4516,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProviderConfig); i { + switch v := v.(*PKCEAuthorizationFlow); i { case 0: return &v.state case 1: @@ -4368,7 +4528,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*ProviderConfig); i { case 0: return &v.state case 1: @@ -4380,7 +4540,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DNSConfig); i { + switch v := v.(*Route); i { case 0: return &v.state case 1: @@ -4392,7 +4552,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CustomZone); i { + switch v := v.(*DNSConfig); i { case 0: return &v.state case 1: @@ -4404,7 +4564,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SimpleRecord); i { + switch v := v.(*CustomZone); i { case 0: return &v.state case 1: @@ -4416,7 +4576,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServerGroup); i { + switch v := v.(*SimpleRecord); i { case 0: return &v.state case 1: @@ -4428,7 +4588,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServer); i { + switch v := v.(*NameServerGroup); i { case 0: return &v.state case 1: @@ -4440,7 +4600,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*FirewallRule); i { + switch v := v.(*NameServer); i { case 0: return &v.state case 1: @@ -4452,7 +4612,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkAddress); i { + switch v := v.(*FirewallRule); i { case 0: return &v.state case 1: @@ -4464,7 +4624,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Checks); i { + switch v := v.(*NetworkAddress); i { case 0: return &v.state case 1: @@ -4476,7 +4636,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PortInfo); i { + switch v := v.(*Checks); i { case 0: return &v.state case 1: @@ -4488,7 +4648,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RouteFirewallRule); i { + switch v := v.(*PortInfo); i { case 0: return &v.state case 1: @@ -4500,7 +4660,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ForwardingRule); i { + switch v := v.(*RouteFirewallRule); i { case 0: return &v.state case 1: @@ -4512,6 +4672,18 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ForwardingRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[40].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*PortInfo_Range); i { case 0: return &v.state @@ -4524,7 +4696,7 @@ func file_management_proto_init() { } } } - file_management_proto_msgTypes[36].OneofWrappers = []interface{}{ + file_management_proto_msgTypes[37].OneofWrappers = []interface{}{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } @@ -4534,7 +4706,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 5, - NumMessages: 40, + NumMessages: 41, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 3982ea2af..16737cf58 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -146,6 +146,12 @@ message Flags { bool blockInbound = 9; bool lazyConnectionEnabled = 10; + + bool enableSSHRoot = 11; + bool enableSSHSFTP = 12; + bool enableSSHLocalPortForwarding = 13; + bool enableSSHRemotePortForwarding = 14; + bool disableSSHAuth = 15; } // PeerSystemMeta is machine meta data like OS and version. @@ -202,6 +208,8 @@ message NetbirdConfig { RelayConfig relay = 4; FlowConfig flow = 5; + + JWTConfig jwt = 6; } // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) @@ -240,6 +248,14 @@ message FlowConfig { bool dnsCollection = 8; } +// JWTConfig represents JWT authentication configuration +message JWTConfig { + string issuer = 1; + string audience = 2; + string keysLocation = 3; + int64 maxTokenAge = 4; +} + // ProtectedHostConfig is similar to HostConfig but has additional user and password // Mostly used for TURN servers message ProtectedHostConfig { @@ -335,6 +351,8 @@ message SSHConfig { // sshPubKey is a SSH public key of a peer to be added to authorized_hosts. // This property should be ignore if SSHConfig comes from PeerConfig. bytes sshPubKey = 2; + + JWTConfig jwtConfig = 3; } // DeviceAuthorizationFlowRequest empty struct for future expansion diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 967e18d79..c057ef089 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -11,8 +11,8 @@ import ( "github.com/quic-go/quic-go" log "github.com/sirupsen/logrus" - quictls "github.com/netbirdio/netbird/shared/relay/tls" nbnet "github.com/netbirdio/netbird/client/net" + quictls "github.com/netbirdio/netbird/shared/relay/tls" ) type Dialer struct { diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 66fff3447..37b189e05 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -14,9 +14,9 @@ import ( "github.com/coder/websocket" log "github.com/sirupsen/logrus" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/relay/constants.go b/shared/relay/constants.go index 3c7c3cd29..0f2a27610 100644 --- a/shared/relay/constants.go +++ b/shared/relay/constants.go @@ -3,4 +3,4 @@ package relay const ( // WebSocketURLPath is the path for the websocket relay connection WebSocketURLPath = "/relay" -) \ No newline at end of file +) diff --git a/version/url_windows.go b/version/url_windows.go index 14fdb7ae6..a0fb6e5dd 100644 --- a/version/url_windows.go +++ b/version/url_windows.go @@ -6,7 +6,7 @@ import ( ) const ( - urlWinExe = "https://pkgs.netbird.io/windows/x64" + urlWinExe = "https://pkgs.netbird.io/windows/x64" urlWinExeArm = "https://pkgs.netbird.io/windows/arm64" ) @@ -18,11 +18,11 @@ func DownloadUrl() string { if err != nil { return downloadURL } - + url := urlWinExe if runtime.GOARCH == "arm64" { url = urlWinExeArm } - + return url } From 4eeb2d8debcb7c6e2c88f9c8fbc94958f8a9f836 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Mon, 17 Nov 2025 18:20:30 +0100 Subject: [PATCH 078/120] [management] added exception on not appending route firewall rules if we have all wildcard (#4801) --- management/server/types/networkmapbuilder.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go index 41aaa7fc8..6361e2e93 100644 --- a/management/server/types/networkmapbuilder.go +++ b/management/server/types/networkmapbuilder.go @@ -22,10 +22,11 @@ import ( ) const ( - allPeers = "0.0.0.0" - fw = "fw:" - rfw = "route-fw:" - nr = "network-resource-" + allPeers = "0.0.0.0" + allWildcard = "0.0.0.0/0" + v6AllWildcard = "::/0" + fw = "fw:" + rfw = "route-fw:" ) type NetworkMapCache struct { @@ -1640,6 +1641,10 @@ func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView, } if string(rule.RouteID) == update.RuleID { + if hasWildcard := slices.Contains(rule.SourceRanges, allWildcard) || slices.Contains(rule.SourceRanges, v6AllWildcard); hasWildcard { + break + } + sourceIP := update.AddSourceIP if strings.Contains(sourceIP, ":") { From 60f4d5f9b0ab6058ccfc7805c6143d1292b4bcec Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 18 Nov 2025 12:41:17 +0100 Subject: [PATCH 079/120] [client] Revert migrate deprecated grpc client code #4805 --- client/grpc/dialer.go | 42 +++++++----------------------------------- 1 file changed, 7 insertions(+), 35 deletions(-) diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 7763f2417..54966b50e 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "errors" "fmt" "runtime" "time" @@ -12,7 +11,6 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" @@ -20,9 +18,6 @@ import ( "github.com/netbirdio/netbird/util/embeddedroots" ) -// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready -var ErrConnectionShutdown = errors.New("connection shutdown before ready") - // Backoff returns a backoff configuration for gRPC calls func Backoff(ctx context.Context) backoff.BackOff { b := backoff.NewExponentialBackOff() @@ -31,26 +26,6 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } -// waitForConnectionReady blocks until the connection becomes ready or fails. -// Returns an error if the connection times out, is cancelled, or enters shutdown state. -func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error { - conn.Connect() - - state := conn.GetState() - for state != connectivity.Ready && state != connectivity.Shutdown { - if !conn.WaitForStateChange(ctx, state) { - return fmt.Errorf("wait state change from %s: %w", state, ctx.Err()) - } - state = conn.GetState() - } - - if state == connectivity.Shutdown { - return ErrConnectionShutdown - } - - return nil -} - // CreateConnection creates a gRPC client connection with the appropriate transport options. // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { @@ -68,25 +43,22 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone })) } - conn, err := grpc.NewClient( + connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + conn, err := grpc.DialContext( + connCtx, addr, transportOption, WithCustomDialer(tlsEnabled, component), + grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, Timeout: 10 * time.Second, }), ) if err != nil { - return nil, fmt.Errorf("new client: %w", err) - } - - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - if err := waitForConnectionReady(ctx, conn); err != nil { - _ = conn.Close() - return nil, err + return nil, fmt.Errorf("dial context: %w", err) } return conn, nil From 05cbead39b21c5b2c69a2df020ede0746058faad Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 18 Nov 2025 17:15:57 +0100 Subject: [PATCH 080/120] [management] Fix direct peer networks route (#4802) --- management/server/types/account.go | 11 ++ management/server/types/networkmapbuilder.go | 152 +++++++++++++++---- 2 files changed, 130 insertions(+), 33 deletions(-) diff --git a/management/server/types/account.go b/management/server/types/account.go index 8797e1fa3..51313819f 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1242,6 +1242,13 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID } } } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + _, distPeer := distributionPeers[rule.SourceResource.ID] + _, valid := validatedPeersMap[rule.SourceResource.ID] + if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, rule.SourceResource.ID) { + distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} + } + } distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) for pID := range distPeersWithPolicy { @@ -1587,6 +1594,10 @@ func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[st sourcePeers[peer] = struct{}{} } } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers[rule.SourceResource.ID] = struct{}{} + } } } diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go index 6361e2e93..5790f1646 100644 --- a/management/server/types/networkmapbuilder.go +++ b/management/server/types/networkmapbuilder.go @@ -858,6 +858,14 @@ func (b *NetworkMapBuilder) getRulePeers( } } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + _, distPeer := distributionPeers[rule.SourceResource.ID] + _, valid := validatedPeersMap[rule.SourceResource.ID] + if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, rule.SourceResource.ID) { + distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} + } + } + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) for pID := range distPeersWithPolicy { peer := b.cache.globalPeers[pID] @@ -1287,24 +1295,54 @@ func (b *NetworkMapBuilder) calculateIncrementalUpdates(account *Account, newPee if !rule.Enabled { continue } + var peerInSources, peerInDestinations bool - peerInSources := b.isPeerInGroups(rule.Sources, peerGroups) - peerInDestinations := b.isPeerInGroups(rule.Destinations, peerGroups) + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID == newPeerID { + peerInSources = true + } else { + peerInSources = b.isPeerInGroups(rule.Sources, peerGroups) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID == newPeerID { + peerInDestinations = true + } else { + peerInDestinations = b.isPeerInGroups(rule.Destinations, peerGroups) + } if peerInSources { - b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + if len(rule.Destinations) > 0 { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionIN) + } } if peerInDestinations { - b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + if len(rule.Sources) > 0 { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionOUT) + } } if rule.Bidirectional { if peerInSources { - b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + if len(rule.Destinations) > 0 { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionOUT) + } } if peerInDestinations { - b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + if len(rule.Sources) > 0 { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionIN) + } } } } @@ -1566,42 +1604,90 @@ func (b *NetworkMapBuilder) addUpdateForPeersInGroups( if _, ok := b.validatedPeers[peerID]; !ok { continue } - delta := updates[peerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: peerID, - AddConnectedPeer: newPeerID, - AddFirewallRules: make([]*FirewallRuleDelta, 0), - } - updates[peerID] = delta + targetPeer := b.cache.globalPeers[peerID] + if targetPeer == nil { + continue } + peerIPForRule := fr.PeerIP if all { - fr.PeerIP = allPeers + peerIPForRule = allPeers } - if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 { - expandedRules := expandPortsAndRanges(*fr, rule, b.cache.globalPeers[peerID]) - for _, expandedRule := range expandedRules { - ruleID := b.generateFirewallRuleID(expandedRule) - delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ - Rule: expandedRule, - RuleID: ruleID, - Direction: direction, - }) - } - } else { - ruleID := b.generateFirewallRuleID(fr) - delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ - Rule: fr, - RuleID: ruleID, - Direction: direction, - }) - } + b.addOrUpdateFirewallRuleInDelta(updates, peerID, newPeerID, rule, direction, fr, peerIPForRule, targetPeer) } } } +func (b *NetworkMapBuilder) addUpdateForDirectPeerResource( + updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string, + rule *PolicyRule, direction int, +) { + if targetPeerID == newPeerID { + return + } + + if _, ok := b.validatedPeers[targetPeerID]; !ok { + return + } + + newPeer := b.cache.globalPeers[newPeerID] + if newPeer == nil { + return + } + + targetPeer := b.cache.globalPeers[targetPeerID] + if targetPeer == nil { + return + } + + fr := &FirewallRule{ + PolicyID: rule.ID, + PeerIP: newPeer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + b.addOrUpdateFirewallRuleInDelta(updates, targetPeerID, newPeerID, rule, direction, fr, fr.PeerIP, targetPeer) +} + +func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta( + updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string, + rule *PolicyRule, direction int, baseRule *FirewallRule, peerIP string, targetPeer *nbpeer.Peer, +) { + delta := updates[targetPeerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: targetPeerID, + AddConnectedPeer: newPeerID, + AddFirewallRules: make([]*FirewallRuleDelta, 0), + } + updates[targetPeerID] = delta + } + + baseRule.PeerIP = peerIP + + if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 { + expandedRules := expandPortsAndRanges(*baseRule, rule, targetPeer) + for _, expandedRule := range expandedRules { + ruleID := b.generateFirewallRuleID(expandedRule) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: expandedRule, + RuleID: ruleID, + Direction: direction, + }) + } + } else { + ruleID := b.generateFirewallRuleID(baseRule) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: baseRule, + RuleID: ruleID, + Direction: direction, + }) + } +} + func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) { if delta.AddConnectedPeer != "" || len(delta.AddFirewallRules) > 0 { if aclView := b.cache.peerACLs[peerID]; aclView != nil { From 3351b38434c51151a03b6e4f0dc023b8bbb28f59 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 19 Nov 2025 11:52:18 +0100 Subject: [PATCH 081/120] [management] pass config to controller (#4807) --- client/cmd/testutil_test.go | 2 +- client/internal/engine.go | 1 + client/internal/engine_test.go | 2 +- client/server/server_test.go | 2 +- .../network_map/controller/controller.go | 9 +- management/internals/server/controllers.go | 2 +- .../internals/shared/grpc/conversion.go | 30 +- management/internals/shared/grpc/server.go | 4 +- management/server/account_test.go | 3 +- management/server/dns_test.go | 3 +- .../testing/testing_tools/channel/channel.go | 3 +- management/server/management_proto_test.go | 2 +- management/server/management_test.go | 2 +- management/server/nameserver_test.go | 3 +- management/server/peer_test.go | 12 +- management/server/route_test.go | 3 +- shared/management/client/client_test.go | 2 +- shared/management/proto/management.pb.go | 849 +++++++++--------- shared/management/proto/management.proto | 2 - 19 files changed, 464 insertions(+), 472 deletions(-) diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index e7b0279e8..5477653a2 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -116,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), config) accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 1deb3d3cf..94c948398 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1192,6 +1192,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { + //nolint forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) if forwarderPort == 0 { forwarderPort = nbdns.ForwarderClientPort diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 3b7ff0eba..d6ab7391e 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1624,7 +1624,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) - networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, "", err diff --git a/client/server/server_test.go b/client/server/server_test.go index 96d4c0af0..f8592bc7a 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -316,7 +316,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) - networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { return nil, "", err diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index ad25494c7..49bb9cef3 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -19,6 +19,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" @@ -47,6 +48,7 @@ type Controller struct { updateAccountPeersBufferInterval atomic.Int64 // dnsDomain is used for peer resolution. This is appended to the peer's name dnsDomain string + config *config.Config requestBuffer account.RequestBuffer @@ -68,7 +70,7 @@ type bufferUpdate struct { var _ network_map.Controller = (*Controller)(nil) -func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller) *Controller { +func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, config *config.Config) *Controller { nMetrics, err := newMetrics(metrics.UpdateChannelMetrics()) if err != nil { log.Fatal(fmt.Errorf("error creating metrics: %w", err)) @@ -95,6 +97,7 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App integratedPeerValidator: integratedPeerValidator, settingsManager: settingsManager, dnsDomain: dnsDomain, + config: config, proxyController: proxyController, @@ -205,7 +208,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin peerGroups := account.GetPeerGroups(p.ID) start = time.Now() - update := grpc.ToSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) + update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) c.metrics.CountToSyncResponseDuration(time.Since(start)) c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update}) @@ -323,7 +326,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe peerGroups := account.GetPeerGroups(peerId) dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) - update := grpc.ToSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) + update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update}) return nil diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index b61e33688..38ec6fde6 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -70,7 +70,7 @@ func (s *BaseServer) EphemeralManager() ephemeral.Manager { func (s *BaseServer) NetworkMapController() network_map.Controller { return Create(s, func() *nmapcontroller.Controller { - return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController()) + return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController(), s.config) }) } diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 7f64034df..148db4e37 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -83,7 +83,7 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken return nbConfig } -func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, config *nbconfig.Config) *proto.PeerConfig { +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig { netmask, _ := network.Net.Mask.Size() fqdn := peer.FQDN(dnsName) @@ -92,7 +92,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } if peer.SSHEnabled { - sshConfig.JwtConfig = buildJWTConfig(config) + sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig) } return &proto.PeerConfig{ @@ -104,9 +104,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } } -func ToSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { +func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { response := &proto.SyncResponse{ - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, config), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), Routes: toProtocolRoutes(networkMap.Routes), @@ -363,35 +363,29 @@ func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameSe } // buildJWTConfig constructs JWT configuration for SSH servers from management server config -func buildJWTConfig(config *nbconfig.Config) *proto.JWTConfig { - if config == nil { +func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig { + if config == nil || config.AuthAudience == "" { return nil } - if config.HttpConfig == nil || config.HttpConfig.AuthAudience == "" { - return nil - } - - issuer := strings.TrimSpace(config.HttpConfig.AuthIssuer) - if issuer == "" { - if config.DeviceAuthorizationFlow != nil { - if d := deriveIssuerFromTokenEndpoint(config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint); d != "" { - issuer = d - } + issuer := strings.TrimSpace(config.AuthIssuer) + if issuer == "" || deviceFlowConfig != nil { + if d := deriveIssuerFromTokenEndpoint(deviceFlowConfig.ProviderConfig.TokenEndpoint); d != "" { + issuer = d } } if issuer == "" { return nil } - keysLocation := strings.TrimSpace(config.HttpConfig.AuthKeysLocation) + keysLocation := strings.TrimSpace(config.AuthKeysLocation) if keysLocation == "" { keysLocation = strings.TrimSuffix(issuer, "/") + "/.well-known/jwks.json" } return &proto.JWTConfig{ Issuer: issuer, - Audience: config.HttpConfig.AuthAudience, + Audience: config.AuthAudience, KeysLocation: keysLocation, } } diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 4364272a0..0aadadf84 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -646,7 +646,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config), + PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow), Checks: toProtocolChecks(ctx, postureChecks), } @@ -713,7 +713,7 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer return status.Errorf(codes.Internal, "failed to get peer groups %s", err) } - plainResp := ToSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) + plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/account_test.go b/management/server/account_test.go index 10d718bbf..340e8db18 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -25,6 +25,7 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/server/config" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" @@ -2958,7 +2959,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) manager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, nil, err diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 6b7a36c20..99b09566a 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -12,6 +12,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -222,7 +223,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), &config.Config{}) return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index ac165aeb2..e292a7d6c 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" @@ -71,7 +72,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee ctx := context.Background() requestBuffer := server.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), &config.Config{}) am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) if err != nil { t.Fatalf("Failed to create manager: %v", err) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 496be9caa..42311d944 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -363,7 +363,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) diff --git a/management/server/management_test.go b/management/server/management_test.go index c485f16b4..2350b225b 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -205,7 +205,7 @@ func startServer( ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(ctx, str) - networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) accountManager, err := server.BuildManager( context.Background(), diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 51738c106..a4574c978 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -13,6 +13,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -791,7 +792,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{}) return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 21a6952a9..2d09f5200 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1058,7 +1058,6 @@ func testUpdateAccountPeers(t *testing.T) { for _, channel := range peerChannels { update := <-channel - assert.Nil(t, update.Update.NetbirdConfig) assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers)) assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules)) } @@ -1177,7 +1176,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &cache.DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := grpc.ToSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) + response := grpc.ToSyncResponse(context.Background(), config, config.HttpConfig, config.DeviceAuthorizationFlow, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config @@ -1228,6 +1227,7 @@ func TestToSyncResponse(t *testing.T) { assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID) // assert network map DNSConfig assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable) + //nolint assert.Equal(t, int64(dnsForwarderPort), response.NetworkMap.DNSConfig.ForwarderPort) assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones)) assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups)) @@ -1290,7 +1290,7 @@ func Test_RegisterPeerByUser(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1375,7 +1375,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1528,7 +1528,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1608,7 +1608,7 @@ func Test_LoginPeer(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) diff --git a/management/server/route_test.go b/management/server/route_test.go index 7ff362bc6..5c8b636bc 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -1290,7 +1291,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{}) am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index f98e76ce7..9e08317f6 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -117,7 +117,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index ca12cf48c..45211b7f4 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.32.1 +// protoc v6.33.0 // source: management.proto package proto @@ -1312,7 +1312,6 @@ type NetbirdConfig struct { Signal *HostConfig `protobuf:"bytes,3,opt,name=signal,proto3" json:"signal,omitempty"` Relay *RelayConfig `protobuf:"bytes,4,opt,name=relay,proto3" json:"relay,omitempty"` Flow *FlowConfig `protobuf:"bytes,5,opt,name=flow,proto3" json:"flow,omitempty"` - Jwt *JWTConfig `protobuf:"bytes,6,opt,name=jwt,proto3" json:"jwt,omitempty"` } func (x *NetbirdConfig) Reset() { @@ -1382,13 +1381,6 @@ func (x *NetbirdConfig) GetFlow() *FlowConfig { return nil } -func (x *NetbirdConfig) GetJwt() *JWTConfig { - if x != nil { - return x.Jwt - } - return nil -} - // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) type HostConfig struct { state protoimpl.MessageState @@ -2619,7 +2611,8 @@ type DNSConfig struct { ServiceEnable bool `protobuf:"varint,1,opt,name=ServiceEnable,proto3" json:"ServiceEnable,omitempty"` NameServerGroups []*NameServerGroup `protobuf:"bytes,2,rep,name=NameServerGroups,proto3" json:"NameServerGroups,omitempty"` CustomZones []*CustomZone `protobuf:"bytes,3,rep,name=CustomZones,proto3" json:"CustomZones,omitempty"` - ForwarderPort int64 `protobuf:"varint,4,opt,name=ForwarderPort,proto3" json:"ForwarderPort,omitempty"` + // Deprecated: Do not use. + ForwarderPort int64 `protobuf:"varint,4,opt,name=ForwarderPort,proto3" json:"ForwarderPort,omitempty"` } func (x *DNSConfig) Reset() { @@ -2675,6 +2668,7 @@ func (x *DNSConfig) GetCustomZones() []*CustomZone { return nil } +// Deprecated: Do not use. func (x *DNSConfig) GetForwarderPort() int64 { if x != nil { return x.ForwarderPort @@ -3668,7 +3662,7 @@ var file_management_proto_rawDesc = []byte{ 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xa8, 0x02, + 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xff, 0x01, 0x0a, 0x0d, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2c, 0x0a, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, @@ -3684,374 +3678,372 @@ var file_management_proto_rawDesc = []byte{ 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, - 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x12, - 0x27, 0x0a, 0x03, 0x6a, 0x77, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x03, 0x6a, 0x77, 0x74, 0x22, 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, - 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, - 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, - 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, - 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, - 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, - 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, - 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, - 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, - 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, - 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, - 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, - 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, - 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, - 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, - 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, - 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, - 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x22, 0x85, 0x01, 0x0a, 0x09, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x75, 0x64, 0x69, - 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, 0x69, - 0x65, 0x6e, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x6b, 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x6b, 0x65, 0x79, 0x73, - 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x6d, 0x61, 0x78, 0x54, - 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x6d, - 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x22, 0x7d, 0x0a, 0x13, 0x50, 0x72, - 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x68, - 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, - 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, - 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x93, 0x02, 0x0a, 0x0a, 0x50, 0x65, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, - 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, - 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, - 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, - 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, + 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x22, + 0x98, 0x01, 0x0a, 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, + 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, + 0x12, 0x3b, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, + 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, + 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, + 0x54, 0x54, 0x50, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, + 0x12, 0x08, 0x0a, 0x04, 0x44, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, + 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, + 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, + 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, + 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, + 0x6f, 0x77, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, + 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, + 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, + 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, + 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, + 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, + 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, + 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, + 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, + 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x85, 0x01, 0x0a, 0x09, 0x4a, 0x57, + 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, + 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, + 0x1a, 0x0a, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x6b, + 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0c, 0x6b, 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x20, 0x0a, 0x0b, 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, + 0x65, 0x22, 0x7d, 0x0a, 0x13, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, + 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, + 0x22, 0x93, 0x02, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, + 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, - 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, 0x0a, - 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x22, - 0xb9, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, - 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, - 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, - 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, - 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, - 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, - 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, - 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, - 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, - 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, - 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, - 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, - 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, - 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, - 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, - 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, - 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, - 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, - 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, - 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, - 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, - 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, - 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, - 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, - 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, - 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, - 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, - 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, - 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, - 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, - 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, - 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, - 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, - 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, - 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, - 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, - 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, - 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, - 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, - 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, - 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, - 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, - 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, - 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, - 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, - 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, - 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, - 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, - 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, - 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, - 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, - 0x46, 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, - 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, - 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, - 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, - 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, - 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, - 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, - 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, - 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, - 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, - 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xda, 0x01, 0x0a, 0x09, - 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, - 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, - 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, - 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, - 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, - 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, - 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, - 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, - 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, - 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, - 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, - 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, - 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, - 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, - 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, - 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, - 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, - 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, - 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, - 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, - 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, - 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, - 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, - 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, - 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, - 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, - 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, - 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, - 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, - 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, - 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, - 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, - 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, - 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, - 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, - 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, - 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, - 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, - 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, - 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, - 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, - 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, - 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, - 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, - 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, - 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, - 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, - 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, - 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, - 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, - 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, - 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, - 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, - 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, + 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, + 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, + 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, + 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x22, 0xb9, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, + 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, + 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, + 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, + 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, + 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, + 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, + 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, + 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, + 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, + 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, + 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, + 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, + 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, + 0x65, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, + 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, + 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, + 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, + 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, + 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, + 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, 0x0c, + 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, + 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, + 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, 0x6a, + 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, + 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, + 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, + 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, + 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, + 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, + 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x22, 0xb8, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, + 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, + 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, + 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, + 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, + 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, + 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, + 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, + 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, + 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, + 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, + 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, + 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, + 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, + 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, + 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, + 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, + 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, + 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, + 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, + 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, + 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, + 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, + 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, + 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x46, 0x6f, 0x72, + 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, + 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, + 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, + 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, + 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, + 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, + 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, + 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, + 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, + 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, + 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, + 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, + 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, + 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, + 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, + 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, + 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, + 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, + 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, - 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, - 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, - 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, - 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, - 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, - 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, - 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, - 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, - 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, - 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, - 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, - 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, - 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, + 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, + 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, + 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, + 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, + 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, + 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, + 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, + 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, + 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, + 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, + 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, + 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, + 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, + 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, + 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, + 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, + 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, + 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, + 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, + 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, + 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, + 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, + 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, + 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, + 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, + 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, + 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, + 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, + 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, + 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, + 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, + 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, + 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, + 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, + 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, - 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, + 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, - 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, + 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, + 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, + 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, + 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, + 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, + 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, + 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, + 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, + 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -4141,60 +4133,59 @@ var file_management_proto_depIdxs = []int32{ 19, // 19: management.NetbirdConfig.signal:type_name -> management.HostConfig 20, // 20: management.NetbirdConfig.relay:type_name -> management.RelayConfig 21, // 21: management.NetbirdConfig.flow:type_name -> management.FlowConfig - 22, // 22: management.NetbirdConfig.jwt:type_name -> management.JWTConfig - 3, // 23: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 47, // 24: management.FlowConfig.interval:type_name -> google.protobuf.Duration - 19, // 25: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 27, // 26: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 24, // 27: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 26, // 28: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 33, // 29: management.NetworkMap.Routes:type_name -> management.Route - 34, // 30: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 26, // 31: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 39, // 32: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 43, // 33: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule - 44, // 34: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule - 27, // 35: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 22, // 36: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig - 4, // 37: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 32, // 38: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 32, // 39: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 37, // 40: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 35, // 41: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 36, // 42: management.CustomZone.Records:type_name -> management.SimpleRecord - 38, // 43: management.NameServerGroup.NameServers:type_name -> management.NameServer - 1, // 44: management.FirewallRule.Direction:type_name -> management.RuleDirection - 2, // 45: management.FirewallRule.Action:type_name -> management.RuleAction - 0, // 46: management.FirewallRule.Protocol:type_name -> management.RuleProtocol - 42, // 47: management.FirewallRule.PortInfo:type_name -> management.PortInfo - 45, // 48: management.PortInfo.range:type_name -> management.PortInfo.Range - 2, // 49: management.RouteFirewallRule.action:type_name -> management.RuleAction - 0, // 50: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol - 42, // 51: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo - 0, // 52: management.ForwardingRule.protocol:type_name -> management.RuleProtocol - 42, // 53: management.ForwardingRule.destinationPort:type_name -> management.PortInfo - 42, // 54: management.ForwardingRule.translatedPort:type_name -> management.PortInfo - 5, // 55: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 56: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 17, // 57: management.ManagementService.GetServerKey:input_type -> management.Empty - 17, // 58: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 59: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 60: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 61: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 62: management.ManagementService.Logout:input_type -> management.EncryptedMessage - 5, // 63: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 64: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 16, // 65: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 17, // 66: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 67: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 68: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 17, // 69: management.ManagementService.SyncMeta:output_type -> management.Empty - 17, // 70: management.ManagementService.Logout:output_type -> management.Empty - 63, // [63:71] is the sub-list for method output_type - 55, // [55:63] is the sub-list for method input_type - 55, // [55:55] is the sub-list for extension type_name - 55, // [55:55] is the sub-list for extension extendee - 0, // [0:55] is the sub-list for field type_name + 3, // 22: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 47, // 23: management.FlowConfig.interval:type_name -> google.protobuf.Duration + 19, // 24: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig + 27, // 25: management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 24, // 26: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 26, // 27: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 33, // 28: management.NetworkMap.Routes:type_name -> management.Route + 34, // 29: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 26, // 30: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 39, // 31: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 43, // 32: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 44, // 33: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule + 27, // 34: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 22, // 35: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig + 4, // 36: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 32, // 37: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 32, // 38: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 37, // 39: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 35, // 40: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 36, // 41: management.CustomZone.Records:type_name -> management.SimpleRecord + 38, // 42: management.NameServerGroup.NameServers:type_name -> management.NameServer + 1, // 43: management.FirewallRule.Direction:type_name -> management.RuleDirection + 2, // 44: management.FirewallRule.Action:type_name -> management.RuleAction + 0, // 45: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 42, // 46: management.FirewallRule.PortInfo:type_name -> management.PortInfo + 45, // 47: management.PortInfo.range:type_name -> management.PortInfo.Range + 2, // 48: management.RouteFirewallRule.action:type_name -> management.RuleAction + 0, // 49: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 42, // 50: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 0, // 51: management.ForwardingRule.protocol:type_name -> management.RuleProtocol + 42, // 52: management.ForwardingRule.destinationPort:type_name -> management.PortInfo + 42, // 53: management.ForwardingRule.translatedPort:type_name -> management.PortInfo + 5, // 54: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 55: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 17, // 56: management.ManagementService.GetServerKey:input_type -> management.Empty + 17, // 57: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 58: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 59: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 60: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 61: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 5, // 62: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 63: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 16, // 64: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 17, // 65: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 66: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 67: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 17, // 68: management.ManagementService.SyncMeta:output_type -> management.Empty + 17, // 69: management.ManagementService.Logout:output_type -> management.Empty + 62, // [62:70] is the sub-list for method output_type + 54, // [54:62] is the sub-list for method input_type + 54, // [54:54] is the sub-list for extension type_name + 54, // [54:54] is the sub-list for extension extendee + 0, // [0:54] is the sub-list for field type_name } func init() { file_management_proto_init() } diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 16737cf58..8a297e75e 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -208,8 +208,6 @@ message NetbirdConfig { RelayConfig relay = 4; FlowConfig flow = 5; - - JWTConfig jwt = 6; } // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) From 68f56b797d43a2bc9c3a68eed9c33866c3fcf0d6 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 19 Nov 2025 13:16:47 +0100 Subject: [PATCH 082/120] [management] Add native ssh port rule on 22 (#4810) Implements feature-aware firewall rule expansion: derives peer-supported features (native SSH, portRanges) from peer version, prefers explicit Ports over PortRanges when expanding, conditionally appends a native SSH (22022) rule when policy and peer support allow, and adds helpers plus tests for SSH expansion behavior. --- management/server/types/account.go | 90 ++++++-- management/server/types/account_test.go | 266 ++++++++++++++++++++++++ 2 files changed, 341 insertions(+), 15 deletions(-) diff --git a/management/server/types/account.go b/management/server/types/account.go index 51313819f..9e86d8936 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -40,8 +40,20 @@ const ( // firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules. firewallRuleMinPortRangesVer = "0.48.0" + // firewallRuleMinNativeSSHVer defines the minimum peer version that supports native SSH features in the firewall rules. + firewallRuleMinNativeSSHVer = "0.60.0" + + // nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections. + nativeSSHPortString = "22022" + // defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections. + defaultSSHPortString = "22" ) +type supportedFeatures struct { + nativeSSH bool + portRanges bool +} + type LookupMap map[string]struct{} // AccountMeta is a struct that contains a stripped down version of the Account object. @@ -1650,22 +1662,24 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error { // expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule { + features := peerSupportedFirewallFeatures(peer.Meta.WtVersion) + var expanded []*FirewallRule - if len(rule.Ports) > 0 { - for _, port := range rule.Ports { - fr := base - fr.Port = port - expanded = append(expanded, &fr) - } - return expanded + for _, port := range rule.Ports { + fr := base + fr.Port = port + expanded = append(expanded, &fr) } - supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion) for _, portRange := range rule.PortRanges { + // prefer PolicyRule.Ports + if len(rule.Ports) > 0 { + break + } fr := base - if supportPortRanges { + if features.portRanges { fr.PortRange = portRange } else { // Peer doesn't support port ranges, only allow single-port ranges @@ -1677,17 +1691,63 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer expanded = append(expanded, &fr) } + if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) { + expanded = addNativeSSHRule(base, expanded) + } + return expanded } -// peerSupportsPortRanges checks if the peer version supports port ranges. -func peerSupportsPortRanges(peerVer string) bool { - if strings.Contains(peerVer, "dev") { - return true +// addNativeSSHRule adds a native SSH rule (port 22022) to the expanded rules if the base rule has port 22 configured. +func addNativeSSHRule(base FirewallRule, expanded []*FirewallRule) []*FirewallRule { + shouldAdd := false + for _, fr := range expanded { + if isPortInRule(nativeSSHPortString, 22022, fr) { + return expanded + } + if isPortInRule(defaultSSHPortString, 22, fr) { + shouldAdd = true + } + } + if !shouldAdd { + return expanded } - meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) - return err == nil && meetMinVer + fr := base + fr.Port = nativeSSHPortString + return append(expanded, &fr) +} + +func isPortInRule(portString string, portInt uint16, rule *FirewallRule) bool { + return rule.Port == portString || (rule.PortRange.Start <= portInt && portInt <= rule.PortRange.End) +} + +// shouldCheckRulesForNativeSSH determines whether specific policy rules should be checked for native SSH support. +// While users can add the nativeSSHPortString, we look for cases when they used port 22 and based on SSH enabled +// in both management and client, we indicate to add the native port. +func shouldCheckRulesForNativeSSH(supportsNative bool, rule *PolicyRule, peer *nbpeer.Peer) bool { + return supportsNative && peer.SSHEnabled && peer.Meta.Flags.ServerSSHAllowed && rule.Protocol == PolicyRuleProtocolTCP +} + +// peerSupportedFirewallFeatures checks if the peer version supports port ranges. +func peerSupportedFirewallFeatures(peerVer string) supportedFeatures { + if strings.Contains(peerVer, "dev") { + return supportedFeatures{true, true} + } + + var features supportedFeatures + + meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinNativeSSHVer, peerVer) + features.nativeSSH = err == nil && meetMinVer + + if features.nativeSSH { + features.portRanges = true + } else { + meetMinVer, err = posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) + features.portRanges = err == nil && meetMinVer + } + + return features } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index 32538933a..f9aa6a1c2 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -839,6 +839,272 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { assert.Len(t, sourcePeers, 2, "expected source peers don't match") } +func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) { + tests := []struct { + name string + peer *nbpeer.Peer + rule *PolicyRule + base FirewallRule + expectedPorts []string + }{ + { + name: "adds port 22022 when SSH enabled on modern peer with port 22", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + { + name: "adds port 22022 once when port 22 is duplicated within policy", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22", "80", "22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "80", "22", "22022"}, + }, + { + name: "does not add 22022 for peer with old version", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22"}, + }, + { + name: "does not add 22022 when SSHEnabled is false", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: false, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22"}, + }, + { + name: "does not add 22022 when ServerSSHAllowed is false", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: false}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22"}, + }, + { + name: "does not add 22022 for UDP protocol", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolUDP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "udp"}, + expectedPorts: []string{"22"}, + }, + { + name: "does not add 22022 when port 22 not in rule", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"80", "443"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"80", "443"}, + }, + { + name: "does not duplicate 22022 when already present", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22", "22022"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + { + name: "does not duplicate 22022 when already within a port range", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + PortRanges: []RulePortRange{{Start: 20, End: 32000}}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"20-32000"}, + }, + { + name: "adds 22022 when port 22 in port range", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + PortRanges: []RulePortRange{{Start: 20, End: 25}}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"20-25", "22022"}, + }, + { + name: "adds single 22022 once when port 22 in multiple port ranges", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.60.0", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + PortRanges: []RulePortRange{{Start: 20, End: 25}, {Start: 10, End: 100}}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"20-25", "10-100", "22022"}, + }, + { + name: "dev suffix version supports all features", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.50.0-dev", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + { + name: "dev suffix version supports all features", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "dev", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + { + name: "development suffix version supports all features", + peer: &nbpeer.Peer{ + ID: "peer1", + SSHEnabled: true, + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "development", + Flags: nbpeer.Flags{ServerSSHAllowed: true}, + }, + }, + rule: &PolicyRule{ + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + }, + base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"}, + expectedPorts: []string{"22", "22022"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := expandPortsAndRanges(tt.base, tt.rule, tt.peer) + + var ports []string + for _, fr := range result { + if fr.Port != "" { + ports = append(ports, fr.Port) + } else if fr.PortRange.Start > 0 { + ports = append(ports, fmt.Sprintf("%d-%d", fr.PortRange.Start, fr.PortRange.End)) + } + } + + assert.Equal(t, tt.expectedPorts, ports, "expanded ports should match expected") + }) + } +} + func Test_FilterZoneRecordsForPeers(t *testing.T) { tests := []struct { name string From 131136439764a67187b7fbfadef1215b5cc66780 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 20 Nov 2025 17:09:22 +0100 Subject: [PATCH 083/120] [client] Increase ssh detection timeout (#4827) --- client/cmd/ssh.go | 17 +++++++++++++---- client/ssh/client/client.go | 9 ++++++--- client/ssh/config/manager.go | 7 +------ client/ssh/detection/detection.go | 18 +++++++++--------- client/ssh/server/jwt_test.go | 6 +++--- client/wasm/cmd/main.go | 24 ++++++++++++++---------- 6 files changed, 46 insertions(+), 35 deletions(-) diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 70c7dbcff..92857c637 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -749,7 +749,9 @@ func sshProxyFn(cmd *cobra.Command, args []string) error { if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile { logOutput = firstLogFile } - if err := util.InitLog(logLevel, logOutput); err != nil { + + proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + if err := util.InitLog(proxyLogLevel, logOutput); err != nil { return fmt.Errorf("init log: %w", err) } @@ -788,7 +790,8 @@ var sshDetectCmd = &cobra.Command{ } func sshDetectFn(cmd *cobra.Command, args []string) error { - if err := util.InitLog(logLevel, "console"); err != nil { + detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + if err := util.InitLog(detectLogLevel, "console"); err != nil { os.Exit(detection.ServerTypeRegular.ExitCode()) } @@ -797,15 +800,21 @@ func sshDetectFn(cmd *cobra.Command, args []string) error { port, err := strconv.Atoi(portStr) if err != nil { + log.Debugf("invalid port %q: %v", portStr, err) os.Exit(detection.ServerTypeRegular.ExitCode()) } - dialer := &net.Dialer{Timeout: detection.Timeout} - serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port) + ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout) + + dialer := &net.Dialer{} + serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port) if err != nil { + log.Debugf("SSH server detection failed: %v", err) + cancel() os.Exit(detection.ServerTypeRegular.ExitCode()) } + cancel() os.Exit(serverType.ExitCode()) return nil } diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 882056374..31b80317a 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -343,10 +343,13 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo return nil, fmt.Errorf("parse port %s: %w", portStr, err) } - dialer := &net.Dialer{Timeout: detection.Timeout} - serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port) + detectionCtx, cancel := context.WithTimeout(ctx, config.Timeout) + defer cancel() + + dialer := &net.Dialer{} + serverType, err := detection.DetectSSHServerType(detectionCtx, dialer, host, port) if err != nil { - return nil, fmt.Errorf("SSH server detection failed: %w", err) + return nil, fmt.Errorf("SSH server detection: %w", err) } if !serverType.RequiresJWT() { diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index 03a136de3..cc47fd2d2 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -189,12 +189,7 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) { hostLine := strings.Join(deduplicatedPatterns, " ") config := fmt.Sprintf("Host %s\n", hostLine) - - if runtime.GOOS == "windows" { - config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) - } else { - config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath) - } + config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) config += " PreferredAuthentications password,publickey,keyboard-interactive\n" config += " PasswordAuthentication yes\n" config += " PubkeyAuthentication yes\n" diff --git a/client/ssh/detection/detection.go b/client/ssh/detection/detection.go index 487f4665a..f23ea4c37 100644 --- a/client/ssh/detection/detection.go +++ b/client/ssh/detection/detection.go @@ -3,6 +3,7 @@ package detection import ( "bufio" "context" + "fmt" "net" "strconv" "strings" @@ -19,8 +20,8 @@ const ( // JWTRequiredMarker is appended to responses when JWT is required JWTRequiredMarker = "NetBird-JWT-Required" - // Timeout is the timeout for SSH server detection - Timeout = 5 * time.Second + // DefaultTimeout is the default timeout for SSH server detection + DefaultTimeout = 5 * time.Second ) type ServerType string @@ -61,21 +62,20 @@ func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port i conn, err := dialer.DialContext(ctx, "tcp", targetAddr) if err != nil { - log.Debugf("SSH connection failed for detection: %v", err) - return ServerTypeRegular, nil + return ServerTypeRegular, fmt.Errorf("connect to %s: %w", targetAddr, err) } defer conn.Close() - if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil { - log.Debugf("set read deadline: %v", err) - return ServerTypeRegular, nil + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetReadDeadline(deadline); err != nil { + return ServerTypeRegular, fmt.Errorf("set read deadline: %w", err) + } } reader := bufio.NewReader(conn) serverBanner, err := reader.ReadString('\n') if err != nil { - log.Debugf("read SSH banner: %v", err) - return ServerTypeRegular, nil + return ServerTypeRegular, fmt.Errorf("read SSH banner: %w", err) } serverBanner = strings.TrimSpace(serverBanner) diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go index e22bdfb06..1f3bac76d 100644 --- a/client/ssh/server/jwt_test.go +++ b/client/ssh/server/jwt_test.go @@ -58,7 +58,7 @@ func TestJWTEnforcement(t *testing.T) { require.NoError(t, err) port, err := strconv.Atoi(portStr) require.NoError(t, err) - dialer := &net.Dialer{Timeout: detection.Timeout} + dialer := &net.Dialer{} serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port) if err != nil { t.Logf("Detection failed: %v", err) @@ -93,7 +93,7 @@ func TestJWTEnforcement(t *testing.T) { portNoJWT, err := strconv.Atoi(portStrNoJWT) require.NoError(t, err) - dialer := &net.Dialer{Timeout: detection.Timeout} + dialer := &net.Dialer{} serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT) require.NoError(t, err) assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType) @@ -218,7 +218,7 @@ func TestJWTDetection(t *testing.T) { port, err := strconv.Atoi(portStr) require.NoError(t, err) - dialer := &net.Dialer{Timeout: detection.Timeout} + dialer := &net.Dialer{} serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port) require.NoError(t, err) assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType) diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index 4dc14a1ca..238e272fa 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -19,9 +19,10 @@ import ( ) const ( - clientStartTimeout = 30 * time.Second - clientStopTimeout = 10 * time.Second - defaultLogLevel = "warn" + clientStartTimeout = 30 * time.Second + clientStopTimeout = 10 * time.Second + defaultLogLevel = "warn" + defaultSSHDetectionTimeout = 20 * time.Second ) func main() { @@ -207,11 +208,19 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func { host := args[0].String() port := args[1].Int() + timeoutMs := int(defaultSSHDetectionTimeout.Milliseconds()) + if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() { + timeoutMs = args[2].Int() + if timeoutMs <= 0 { + return js.ValueOf("error: timeout must be positive") + } + } + return createPromise(func(resolve, reject js.Value) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) defer cancel() - serverType, err := detectSSHServerType(ctx, client, host, port) + serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port) if err != nil { reject.Invoke(err.Error()) return @@ -222,11 +231,6 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func { }) } -// detectSSHServerType detects SSH server type using NetBird network connection -func detectSSHServerType(ctx context.Context, client *netbird.Client, host string, port int) (sshdetection.ServerType, error) { - return sshdetection.DetectSSHServerType(ctx, client, host, port) -} - // createClientObject wraps the NetBird client in a JavaScript object func createClientObject(client *netbird.Client) js.Value { obj := make(map[string]interface{}) From 32146e576d9e817a4fb84a87631521eceed5b503 Mon Sep 17 00:00:00 2001 From: Diego Romar Date: Fri, 21 Nov 2025 09:36:33 -0300 Subject: [PATCH 084/120] [android] allow selection/deselection of network resources on android peers (#4607) --- client/android/client.go | 115 ++++++- client/android/network_domains.go | 56 ++++ client/android/networks.go | 14 +- client/android/peer_notifier.go | 7 + client/android/peer_routes.go | 20 ++ client/android/platform_files.go | 10 + client/android/route_command.go | 67 ++++ client/iface/device/device_android.go | 52 ++- .../iface/device/device_netstack_android.go | 7 + client/iface/device/renewable_tun.go | 309 ++++++++++++++++++ client/iface/device_android.go | 1 + client/iface/iface_create.go | 4 + client/iface/iface_create_android.go | 7 + client/iface/iface_create_darwin.go | 4 + client/internal/connect.go | 2 + client/internal/engine.go | 14 +- client/internal/engine_test.go | 4 + client/internal/iface_common.go | 1 + 18 files changed, 666 insertions(+), 28 deletions(-) create mode 100644 client/android/network_domains.go create mode 100644 client/android/peer_routes.go create mode 100644 client/android/platform_files.go create mode 100644 client/android/route_command.go create mode 100644 client/iface/device/renewable_tun.go diff --git a/client/android/client.go b/client/android/client.go index 86fb1445d..2943702c6 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -4,10 +4,13 @@ package android import ( "context" + "fmt" "os" "slices" "sync" + "golang.org/x/exp/maps" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/device" @@ -16,10 +19,13 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) // ConnectionListener export internal Listener for mobile @@ -62,17 +68,18 @@ type Client struct { deviceName string uiVersion string networkChangeListener listener.NetworkChangeListener + stateFile string connectClient *internal.ConnectClient } // NewClient instantiate a new Client -func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { +func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { execWorkaround(androidSDKVersion) net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket) return &Client{ - cfgFile: cfgFile, + cfgFile: platformFiles.ConfigurationFilePath(), deviceName: deviceName, uiVersion: uiVersion, tunAdapter: tunAdapter, @@ -80,6 +87,7 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi recorder: peer.NewRecorder(""), ctxCancelLock: &sync.Mutex{}, networkChangeListener: networkChangeListener, + stateFile: platformFiles.StateFilePath(), } } @@ -115,7 +123,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile) } // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). @@ -142,7 +150,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile) } // Stop the internal client and free the resources @@ -156,6 +164,19 @@ func (c *Client) Stop() { c.ctxCancel() } +func (c *Client) RenewTun(fd int) error { + if c.connectClient == nil { + return fmt.Errorf("engine not running") + } + + e := c.connectClient.Engine() + if e == nil { + return fmt.Errorf("engine not initialized") + } + + return e.RenewTun(fd) +} + // SetTraceLogLevel configure the logger to trace level func (c *Client) SetTraceLogLevel() { log.SetLevel(log.TraceLevel) @@ -177,6 +198,7 @@ func (c *Client) PeersList() *PeerInfoArray { p.IP, p.FQDN, p.ConnStatus.String(), + PeerRoutes{routes: maps.Keys(p.GetRoutes())}, } peerInfos[n] = pi } @@ -201,31 +223,43 @@ func (c *Client) Networks() *NetworkArray { return nil } + routeSelector := routeManager.GetRouteSelector() + if routeSelector == nil { + log.Error("could not get route selector") + return nil + } + networkArray := &NetworkArray{ items: make([]Network, 0), } + resolvedDomains := c.recorder.GetResolvedDomainsStates() + for id, routes := range routeManager.GetClientRoutesWithNetID() { if len(routes) == 0 { continue } r := routes[0] + domains := c.getNetworkDomainsFromRoute(r, resolvedDomains) netStr := r.Network.String() + if r.IsDynamic() { netStr = r.Domains.SafeString() } - peer, err := c.recorder.GetPeer(routes[0].Peer) + routePeer, err := c.recorder.GetPeer(routes[0].Peer) if err != nil { log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err) continue } network := Network{ - Name: string(id), - Network: netStr, - Peer: peer.FQDN, - Status: peer.ConnStatus.String(), + Name: string(id), + Network: netStr, + Peer: routePeer.FQDN, + Status: routePeer.ConnStatus.String(), + IsSelected: routeSelector.IsSelected(id), + Domains: domains, } networkArray.Add(network) } @@ -253,6 +287,69 @@ func (c *Client) RemoveConnectionListener() { c.recorder.RemoveConnectionListener() } +func (c *Client) toggleRoute(command routeCommand) error { + return command.toggleRoute() +} + +func (c *Client) getRouteManager() (routemanager.Manager, error) { + client := c.connectClient + if client == nil { + return nil, fmt.Errorf("not connected") + } + + engine := client.Engine() + if engine == nil { + return nil, fmt.Errorf("engine is not running") + } + + manager := engine.GetRouteManager() + if manager == nil { + return nil, fmt.Errorf("could not get route manager") + } + + return manager, nil +} + +func (c *Client) SelectRoute(route string) error { + manager, err := c.getRouteManager() + if err != nil { + return err + } + + return c.toggleRoute(selectRouteCommand{route: route, manager: manager}) +} + +func (c *Client) DeselectRoute(route string) error { + manager, err := c.getRouteManager() + if err != nil { + return err + } + + return c.toggleRoute(deselectRouteCommand{route: route, manager: manager}) +} + +// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain +// with its resolved IP addresses from the provided resolvedDomains map. +func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains { + domains := NetworkDomains{} + + for _, d := range route.Domains { + networkDomain := NetworkDomain{ + Address: d.SafeString(), + } + + if info, exists := resolvedDomains[d]; exists { + for _, prefix := range info.Prefixes { + networkDomain.addResolvedIP(prefix.Addr().String()) + } + } + + domains.Add(&networkDomain) + } + + return domains +} + func exportEnvList(list *EnvList) { if list == nil { return diff --git a/client/android/network_domains.go b/client/android/network_domains.go new file mode 100644 index 000000000..a459bdc23 --- /dev/null +++ b/client/android/network_domains.go @@ -0,0 +1,56 @@ +//go:build android + +package android + +import "fmt" + +type ResolvedIPs struct { + resolvedIPs []string +} + +func (r *ResolvedIPs) Add(ipAddress string) { + r.resolvedIPs = append(r.resolvedIPs, ipAddress) +} + +func (r *ResolvedIPs) Get(i int) (string, error) { + if i < 0 || i >= len(r.resolvedIPs) { + return "", fmt.Errorf("%d is out of range", i) + } + return r.resolvedIPs[i], nil +} + +func (r *ResolvedIPs) Size() int { + return len(r.resolvedIPs) +} + +type NetworkDomain struct { + Address string + resolvedIPs ResolvedIPs +} + +func (d *NetworkDomain) addResolvedIP(resolvedIP string) { + d.resolvedIPs.Add(resolvedIP) +} + +func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs { + return &d.resolvedIPs +} + +type NetworkDomains struct { + domains []*NetworkDomain +} + +func (n *NetworkDomains) Add(domain *NetworkDomain) { + n.domains = append(n.domains, domain) +} + +func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) { + if i < 0 || i >= len(n.domains) { + return nil, fmt.Errorf("%d is out of range", i) + } + return n.domains[i], nil +} + +func (n *NetworkDomains) Size() int { + return len(n.domains) +} diff --git a/client/android/networks.go b/client/android/networks.go index aa130420b..3c3a25939 100644 --- a/client/android/networks.go +++ b/client/android/networks.go @@ -3,10 +3,16 @@ package android type Network struct { - Name string - Network string - Peer string - Status string + Name string + Network string + Peer string + Status string + IsSelected bool + Domains NetworkDomains +} + +func (n Network) GetNetworkDomains() *NetworkDomains { + return &n.Domains } type NetworkArray struct { diff --git a/client/android/peer_notifier.go b/client/android/peer_notifier.go index 1f5564c72..b03947da1 100644 --- a/client/android/peer_notifier.go +++ b/client/android/peer_notifier.go @@ -1,3 +1,5 @@ +//go:build android + package android // PeerInfo describe information about the peers. It designed for the UI usage @@ -5,6 +7,11 @@ type PeerInfo struct { IP string FQDN string ConnStatus string // Todo replace to enum + Routes PeerRoutes +} + +func (p *PeerInfo) GetPeerRoutes() *PeerRoutes { + return &p.Routes } // PeerInfoArray is a wrapper of []PeerInfo diff --git a/client/android/peer_routes.go b/client/android/peer_routes.go new file mode 100644 index 000000000..bb46d609f --- /dev/null +++ b/client/android/peer_routes.go @@ -0,0 +1,20 @@ +//go:build android + +package android + +import "fmt" + +type PeerRoutes struct { + routes []string +} + +func (p *PeerRoutes) Get(i int) (string, error) { + if i < 0 || i >= len(p.routes) { + return "", fmt.Errorf("%d is out of range", i) + } + return p.routes[i], nil +} + +func (p *PeerRoutes) Size() int { + return len(p.routes) +} diff --git a/client/android/platform_files.go b/client/android/platform_files.go new file mode 100644 index 000000000..f0c369750 --- /dev/null +++ b/client/android/platform_files.go @@ -0,0 +1,10 @@ +//go:build android + +package android + +// PlatformFiles groups paths to files used internally by the engine that can't be created/modified +// at their default locations due to android OS restrictions. +type PlatformFiles interface { + ConfigurationFilePath() string + StateFilePath() string +} diff --git a/client/android/route_command.go b/client/android/route_command.go new file mode 100644 index 000000000..b47d5ca6c --- /dev/null +++ b/client/android/route_command.go @@ -0,0 +1,67 @@ +//go:build android + +package android + +import ( + "fmt" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/internal/routemanager" + "github.com/netbirdio/netbird/route" +) + +func executeRouteToggle(id string, manager routemanager.Manager, + operationName string, + routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error { + netID := route.NetID(id) + routes := []route.NetID{netID} + + log.Debugf("%s with id: %s", operationName, id) + + if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil { + log.Debugf("error when %s: %s", operationName, err) + return fmt.Errorf("error %s: %w", operationName, err) + } + + manager.TriggerSelection(manager.GetClientRoutes()) + + return nil +} + +type routeCommand interface { + toggleRoute() error +} + +type selectRouteCommand struct { + route string + manager routemanager.Manager +} + +func (s selectRouteCommand) toggleRoute() error { + routeSelector := s.manager.GetRouteSelector() + if routeSelector == nil { + return fmt.Errorf("no route selector available") + } + + routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error { + return routeSelector.SelectRoutes(routes, true, allRoutes) + } + + return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation) +} + +type deselectRouteCommand struct { + route string + manager routemanager.Manager +} + +func (d deselectRouteCommand) toggleRoute() error { + routeSelector := d.manager.GetRouteSelector() + if routeSelector == nil { + return fmt.Errorf("no route selector available") + } + + return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes) +} diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index 48346fc0f..198343fbd 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -3,6 +3,7 @@ package device import ( + "fmt" "strings" log "github.com/sirupsen/logrus" @@ -19,11 +20,12 @@ import ( // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform type WGTunDevice struct { - address wgaddr.Address - port int - key string - mtu uint16 - iceBind *bind.ICEBind + address wgaddr.Address + port int + key string + mtu uint16 + iceBind *bind.ICEBind + // todo: review if we can eliminate the TunAdapter tunAdapter TunAdapter disableDNS bool @@ -32,17 +34,19 @@ type WGTunDevice struct { filteredDevice *FilteredDevice udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer + renewableTun *RenewableTUN } func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { return &WGTunDevice{ - address: address, - port: port, - key: key, - mtu: mtu, - iceBind: iceBind, - tunAdapter: tunAdapter, - disableDNS: disableDNS, + address: address, + port: port, + key: key, + mtu: mtu, + iceBind: iceBind, + tunAdapter: tunAdapter, + disableDNS: disableDNS, + renewableTun: NewRenewableTUN(), } } @@ -65,14 +69,17 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string return nil, err } - tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd) + unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd) if err != nil { _ = unix.Close(fd) log.Errorf("failed to create Android interface: %s", err) return nil, err } + + t.renewableTun.AddDevice(unmonitoredTUN) + t.name = name - t.filteredDevice = newDeviceFilter(tunDevice) + t.filteredDevice = newDeviceFilter(t.renewableTun) log.Debugf("attaching to interface %v", name) t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] ")) @@ -104,6 +111,23 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { return udpMux, nil } +func (t *WGTunDevice) RenewTun(fd int) error { + if t.device == nil { + return fmt.Errorf("device not initialized") + } + + unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd) + if err != nil { + _ = unix.Close(fd) + log.Errorf("failed to renew Android interface: %s", err) + return err + } + + t.renewableTun.AddDevice(unmonitoredTUN) + + return nil +} + func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error { // todo implement return nil diff --git a/client/iface/device/device_netstack_android.go b/client/iface/device/device_netstack_android.go index 45ae8ba7d..f1a77d40a 100644 --- a/client/iface/device/device_netstack_android.go +++ b/client/iface/device/device_netstack_android.go @@ -2,6 +2,13 @@ package device +import "fmt" + func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { return t.create() } + +func (t *TunNetstackDevice) RenewTun(fd int) error { + // Doesn't make sense in Android for Netstack. + return fmt.Errorf("this function has not been implemented in Netstack for Android") +} diff --git a/client/iface/device/renewable_tun.go b/client/iface/device/renewable_tun.go new file mode 100644 index 000000000..a501eebbb --- /dev/null +++ b/client/iface/device/renewable_tun.go @@ -0,0 +1,309 @@ +//go:build android + +package device + +import ( + "io" + "os" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun" +) + +// closeAwareDevice wraps a tun.Device along with a flag +// indicating whether its Close method was called. +// +// It also redirects tun.Device's Events() to a separate goroutine +// and closes it when Close is called. +// +// The WaitGroup and CloseOnce fields are used to ensure that the +// goroutine is awaited and closed only once. +type closeAwareDevice struct { + isClosed atomic.Bool + tun.Device + closeEventCh chan struct{} + wg sync.WaitGroup + closeOnce sync.Once +} + +func newClosableDevice(tunDevice tun.Device) *closeAwareDevice { + return &closeAwareDevice{ + Device: tunDevice, + isClosed: atomic.Bool{}, + closeEventCh: make(chan struct{}), + } +} + +// redirectEvents redirects the Events() method of the underlying tun.Device +// to the given channel (RenewableTUN's events channel). +func (c *closeAwareDevice) redirectEvents(out chan tun.Event) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + for { + select { + case ev, ok := <-c.Device.Events(): + if !ok { + return + } + + if ev == tun.EventDown { + continue + } + + select { + case out <- ev: + case <-c.closeEventCh: + return + } + case <-c.closeEventCh: + return + } + } + }() +} + +// Close calls the underlying Device's Close method +// after setting isClosed to true. +func (c *closeAwareDevice) Close() (err error) { + c.closeOnce.Do(func() { + c.isClosed.Store(true) + close(c.closeEventCh) + err = c.Device.Close() + c.wg.Wait() + }) + + return err +} + +func (c *closeAwareDevice) IsClosed() bool { + return c.isClosed.Load() +} + +type RenewableTUN struct { + devices []*closeAwareDevice + mu sync.Mutex + cond *sync.Cond + events chan tun.Event + closed atomic.Bool +} + +func NewRenewableTUN() *RenewableTUN { + r := &RenewableTUN{ + devices: make([]*closeAwareDevice, 0), + mu: sync.Mutex{}, + events: make(chan tun.Event, 16), + } + r.cond = sync.NewCond(&r.mu) + return r +} + +func (r *RenewableTUN) File() *os.File { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return nil + } + continue + } + + file := dev.File() + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return file + } +} + +// Read reads from an underlying tun.Device kept in the r.devices slice. +// If no device is available, it waits for one to be added via AddDevice(). +// +// On error, it retries reading from the newest device instead of returning the error +// if the device is closed; if not, it propagates the error. +func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + for { + dev := r.peekLast() + if dev == nil { + // wait until AddDevice() signals a new device via cond.Broadcast() + if !r.waitForDevice() { // returns false if the renewable TUN itself is closed + return 0, io.EOF + } + continue + } + + n, err = dev.Read(bufs, sizes, offset) + if err == nil { + return n, nil + } + + // swap in progress; retry on the newest instead of returning the error + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + return n, err // propagate non-swap error + } +} + +// Write writes to underlying tun.Device kept in the r.devices slice. +// If no device is available, it waits for one to be added via AddDevice(). +// +// On error, it retries writing to the newest device instead of returning the error +// if the device is closed; if not, it propagates the error. +func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return 0, io.EOF + } + continue + } + + n, err := dev.Write(bufs, offset) + if err == nil { + return n, nil + } + + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + + return n, err + } +} + +func (r *RenewableTUN) MTU() (int, error) { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return 0, io.EOF + } + continue + } + mtu, err := dev.MTU() + if err == nil { + return mtu, nil + } + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + return 0, err + } +} + +func (r *RenewableTUN) Name() (string, error) { + for { + dev := r.peekLast() + if dev == nil { + if !r.waitForDevice() { + return "", io.EOF + } + continue + } + name, err := dev.Name() + if err == nil { + return name, nil + } + if dev.IsClosed() { + time.Sleep(1 * time.Millisecond) + continue + } + return "", err + } +} + +// Events returns a channel that is fed events from the underlying tun.Device's events channel +// once it is added. +func (r *RenewableTUN) Events() <-chan tun.Event { + return r.events +} + +func (r *RenewableTUN) Close() error { + // Attempts to set the RenewableTUN closed flag to true. + // If it's already true, returns immediately. + if !r.closed.CompareAndSwap(false, true) { + return nil // already closed: idempotent + } + r.mu.Lock() + devices := r.devices + r.devices = nil + r.cond.Broadcast() + r.mu.Unlock() + + var lastErr error + + log.Debugf("closing %d devices", len(devices)) + for _, device := range devices { + if err := device.Close(); err != nil { + log.Debugf("error closing a device: %v", err) + lastErr = err + } + } + + close(r.events) + return lastErr +} + +func (r *RenewableTUN) BatchSize() int { + return 1 +} + +func (r *RenewableTUN) AddDevice(device tun.Device) { + r.mu.Lock() + if r.closed.Load() { + r.mu.Unlock() + _ = device.Close() + return + } + + var toClose *closeAwareDevice + if len(r.devices) > 0 { + toClose = r.devices[len(r.devices)-1] + } + + cad := newClosableDevice(device) + cad.redirectEvents(r.events) + + r.devices = []*closeAwareDevice{cad} + r.cond.Broadcast() + + r.mu.Unlock() + + if toClose != nil { + if err := toClose.Close(); err != nil { + log.Debugf("error closing last device: %v", err) + } + } +} + +func (r *RenewableTUN) waitForDevice() bool { + r.mu.Lock() + defer r.mu.Unlock() + + for len(r.devices) == 0 && !r.closed.Load() { + r.cond.Wait() + } + return !r.closed.Load() +} + +func (r *RenewableTUN) peekLast() *closeAwareDevice { + r.mu.Lock() + defer r.mu.Unlock() + + if len(r.devices) == 0 { + return nil + } + + return r.devices[len(r.devices)-1] +} diff --git a/client/iface/device_android.go b/client/iface/device_android.go index cdfcea48d..3899bf426 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -21,5 +21,6 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + RenewTun(fd int) error GetICEBind() device.EndpointManager } diff --git a/client/iface/iface_create.go b/client/iface/iface_create.go index 5e17c6d41..13ae9393c 100644 --- a/client/iface/iface_create.go +++ b/client/iface/iface_create.go @@ -24,3 +24,7 @@ func (w *WGIface) Create() error { func (w *WGIface) CreateOnAndroid([]string, string, []string) error { return fmt.Errorf("this function has not implemented on non mobile") } + +func (w *WGIface) RenewTun(fd int) error { + return fmt.Errorf("this function has not been implemented on non-android") +} diff --git a/client/iface/iface_create_android.go b/client/iface/iface_create_android.go index 373a9c95a..d2d9eb70e 100644 --- a/client/iface/iface_create_android.go +++ b/client/iface/iface_create_android.go @@ -6,6 +6,7 @@ import ( // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. +// todo: review does this function really necessary or can we merge it with iOS func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { w.mu.Lock() defer w.mu.Unlock() @@ -22,3 +23,9 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s func (w *WGIface) Create() error { return fmt.Errorf("this function has not implemented on this platform") } + +func (w *WGIface) RenewTun(fd int) error { + w.mu.Lock() + defer w.mu.Unlock() + return w.tun.RenewTun(fd) +} diff --git a/client/iface/iface_create_darwin.go b/client/iface/iface_create_darwin.go index 1d91bce54..0b7cd36ef 100644 --- a/client/iface/iface_create_darwin.go +++ b/client/iface/iface_create_darwin.go @@ -39,3 +39,7 @@ func (w *WGIface) Create() error { func (w *WGIface) CreateOnAndroid([]string, string, []string) error { return fmt.Errorf("this function has not implemented on this platform") } + +func (w *WGIface) RenewTun(fd int) error { + return fmt.Errorf("this function has not been implemented on this platform") +} diff --git a/client/internal/connect.go b/client/internal/connect.go index 6ad5f264b..5a5f4f63c 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -74,6 +74,7 @@ func (c *ConnectClient) RunOnAndroid( networkChangeListener listener.NetworkChangeListener, dnsAddresses []netip.AddrPort, dnsReadyListener dns.ReadyListener, + stateFilePath string, ) error { // in case of non Android os these variables will be nil mobileDependency := MobileDependency{ @@ -82,6 +83,7 @@ func (c *ConnectClient) RunOnAndroid( NetworkChangeListener: networkChangeListener, HostDNSAddresses: dnsAddresses, DnsReadyListener: dnsReadyListener, + StateFilePath: stateFilePath, } return c.run(mobileDependency, nil) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 94c948398..76265dd77 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -255,7 +255,7 @@ func NewEngine( sm := profilemanager.NewServiceManager("") path := sm.GetStatePath() - if runtime.GOOS == "ios" { + if runtime.GOOS == "ios" || runtime.GOOS == "android" { if !fileExists(mobileDep.StateFilePath) { err := createFile(mobileDep.StateFilePath) if err != nil { @@ -1831,6 +1831,18 @@ func (e *Engine) GetWgAddr() netip.Addr { return e.wgInterface.Address().IP } +func (e *Engine) RenewTun(fd int) error { + e.syncMsgMux.Lock() + wgInterface := e.wgInterface + e.syncMsgMux.Unlock() + + if wgInterface == nil { + return fmt.Errorf("wireguard interface not initialized") + } + + return wgInterface.RenewTun(fd) +} + // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag func (e *Engine) updateDNSForwarder( enabled bool, diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index d6ab7391e..9252ce13e 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -110,6 +110,10 @@ type MockWGIface struct { LastActivitiesFunc func() map[string]monotime.Time } +func (m *MockWGIface) RenewTun(_ int) error { + return nil +} + func (m *MockWGIface) RemoveEndpointAddress(_ string) error { return nil } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 98fe01912..90b06cbd1 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -20,6 +20,7 @@ import ( type wgIfaceBase interface { Create() error CreateOnAndroid(routeRange []string, ip string, domains []string) error + RenewTun(fd int) error IsUserspaceBind() bool Name() string Address() wgaddr.Address From 7fb1a2fe312f8549095465469e64a4af3b155a59 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Sat, 22 Nov 2025 01:23:33 +0100 Subject: [PATCH 085/120] [management] removed TestBufferUpdateAccountPeers because it was incorrect (#4839) --- .../network_map/controller/controller_test.go | 135 ------------------ 1 file changed, 135 deletions(-) diff --git a/management/internals/controllers/network_map/controller/controller_test.go b/management/internals/controllers/network_map/controller/controller_test.go index baaffe677..90e7b6e18 100644 --- a/management/internals/controllers/network_map/controller/controller_test.go +++ b/management/internals/controllers/network_map/controller/controller_test.go @@ -1,16 +1,9 @@ package controller import ( - "context" - "sync" - "sync/atomic" "testing" - "time" - - "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/internals/controllers/network_map" - "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) @@ -114,131 +107,3 @@ func TestComputeForwarderPort(t *testing.T) { t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result) } } - -func TestBufferUpdateAccountPeers(t *testing.T) { - const ( - peersCount = 1000 - updateAccountInterval = 50 * time.Millisecond - ) - - var ( - deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32 - uapLastRun, dpLastRun atomic.Int64 - - totalNewRuns, totalOldRuns int - ) - - uap := func(ctx context.Context, accountID string) { - updatePeersDeleted.Store(deletedPeers.Load()) - updatePeersRuns.Add(1) - uapLastRun.Store(time.Now().UnixMilli()) - time.Sleep(100 * time.Millisecond) - } - - t.Run("new approach", func(t *testing.T) { - updatePeersRuns.Store(0) - updatePeersDeleted.Store(0) - deletedPeers.Store(0) - - var mustore sync.Map - bufupd := func(ctx context.Context, accountID string) { - mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{}) - b := mu.(*bufferUpdate) - - if !b.mu.TryLock() { - b.update.Store(true) - return - } - - if b.next != nil { - b.next.Stop() - } - - go func() { - defer b.mu.Unlock() - uap(ctx, accountID) - if !b.update.Load() { - return - } - b.update.Store(false) - b.next = time.AfterFunc(updateAccountInterval, func() { - uap(ctx, accountID) - }) - }() - } - dp := func(ctx context.Context, accountID, peerID, userID string) error { - deletedPeers.Add(1) - dpLastRun.Store(time.Now().UnixMilli()) - time.Sleep(10 * time.Millisecond) - bufupd(ctx, accountID) - return nil - } - - am := mock_server.MockAccountManager{ - UpdateAccountPeersFunc: uap, - BufferUpdateAccountPeersFunc: bufupd, - DeletePeerFunc: dp, - } - empty := "" - for range peersCount { - //nolint - am.DeletePeer(context.Background(), empty, empty, empty) - } - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") - assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") - assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") - - totalNewRuns = int(updatePeersRuns.Load()) - }) - - t.Run("old approach", func(t *testing.T) { - updatePeersRuns.Store(0) - updatePeersDeleted.Store(0) - deletedPeers.Store(0) - - var mustore sync.Map - bufupd := func(ctx context.Context, accountID string) { - mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{}) - b := mu.(*sync.Mutex) - - if !b.TryLock() { - return - } - - go func() { - time.Sleep(updateAccountInterval) - b.Unlock() - uap(ctx, accountID) - }() - } - dp := func(ctx context.Context, accountID, peerID, userID string) error { - deletedPeers.Add(1) - dpLastRun.Store(time.Now().UnixMilli()) - time.Sleep(10 * time.Millisecond) - bufupd(ctx, accountID) - return nil - } - - am := mock_server.MockAccountManager{ - UpdateAccountPeersFunc: uap, - BufferUpdateAccountPeersFunc: bufupd, - DeletePeerFunc: dp, - } - empty := "" - for range peersCount { - //nolint - am.DeletePeer(context.Background(), empty, empty, empty) - } - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") - assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") - assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") - - totalOldRuns = int(updatePeersRuns.Load()) - }) - assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) - t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) -} From 290fe2d8b91483073f1c9dba439f464a910cfc61 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 22 Nov 2025 10:10:18 +0100 Subject: [PATCH 086/120] [client/management/signal/relay] Update go.mod to use Go 1.24.10 and upgrade x/crypto dependencies (#4828) Upgrade Go toolchain and golang.org/x/* deps to 1.24.10, standardize GitHub Actions to derive Go version from go.mod and adjust checkout ordering, raise WASM size limit to 55 MB, update FreeBSD tarball and gomobile refs, fix a few format-string/logging calls, treat usernames ending with $ as system accounts, and add Windows tests. --- .github/workflows/golang-test-darwin.yml | 7 +- .github/workflows/golang-test-freebsd.yml | 2 +- .github/workflows/golang-test-linux.yml | 71 +++++------ .github/workflows/golang-test-windows.yml | 2 +- .github/workflows/golangci-lint.yml | 2 +- .github/workflows/mobile-build-validation.yml | 8 +- .github/workflows/release.yml | 8 +- .../workflows/test-infrastructure-files.yml | 8 +- .github/workflows/wasm-build-validation.yml | 8 +- client/internal/dns/upstream.go | 2 +- client/ssh/testutil/user_helpers.go | 3 +- client/ssh/testutil/user_helpers_test.go | 115 ++++++++++++++++++ go.mod | 22 ++-- go.sum | 40 +++--- .../policies/posture_checks_handler_test.go | 2 +- management/server/posture_checks.go | 2 +- 16 files changed, 210 insertions(+), 92 deletions(-) create mode 100644 client/ssh/testutil/user_helpers_test.go diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 4571ce753..9c4c35d21 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -15,13 +15,14 @@ jobs: name: "Client / Unit" runs-on: macos-latest steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - name: Cache Go modules uses: actions/cache@v4 diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml index cdd0910a4..b03313bbd 100644 --- a/.github/workflows/golang-test-freebsd.yml +++ b/.github/workflows/golang-test-freebsd.yml @@ -25,7 +25,7 @@ jobs: release: "14.2" prepare: | pkg install -y curl pkgconf xorg - GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz" + GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz" GO_URL="https://go.dev/dl/$GO_TARBALL" curl -vLO "$GO_URL" tar -C /usr/local -vxzf "$GO_TARBALL" diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index ba36c013b..c09bfab39 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -30,7 +30,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - name: Get Go environment @@ -106,15 +106,15 @@ jobs: arch: [ '386','amd64' ] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV @@ -151,15 +151,15 @@ jobs: needs: [ build-cache ] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment id: go-env run: | @@ -200,7 +200,7 @@ jobs: -e GOCACHE=${CONTAINER_GOCACHE} \ -e GOMODCACHE=${CONTAINER_GOMODCACHE} \ -e CONTAINER=${CONTAINER} \ - golang:1.23-alpine \ + golang:1.24-alpine \ sh -c ' \ apk update; apk add --no-cache \ ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ @@ -220,15 +220,15 @@ jobs: raceFlag: "-race" runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386 @@ -270,15 +270,15 @@ jobs: arch: [ '386','amd64' ] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386 @@ -321,15 +321,15 @@ jobs: store: [ 'sqlite', 'postgres', 'mysql' ] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV @@ -408,15 +408,16 @@ jobs: -v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \ -p 9090:9090 \ prom/prometheus - - name: Install Go - uses: actions/setup-go@v5 - with: - go-version: "1.23.x" - cache: false - name: Checkout code uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version-file: "go.mod" + cache: false + - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV @@ -497,15 +498,15 @@ jobs: -p 9090:9090 \ prom/prometheus + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV @@ -561,15 +562,15 @@ jobs: store: [ 'sqlite', 'postgres'] runs-on: ubuntu-22.04 steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - - name: Checkout code - uses: actions/checkout@v4 - - name: Get Go environment run: | echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 2083c0721..43357c45f 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -24,7 +24,7 @@ jobs: uses: actions/setup-go@v5 id: go with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - name: Get Go environment diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 2845b05a5..c524f6f6b 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -46,7 +46,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" cache: false - name: Install dependencies if: matrix.os == 'ubuntu-latest' diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml index c7d43695b..8325fbf2d 100644 --- a/.github/workflows/mobile-build-validation.yml +++ b/.github/workflows/mobile-build-validation.yml @@ -20,7 +20,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: Setup Android SDK uses: android-actions/setup-android@v3 with: @@ -39,7 +39,7 @@ jobs: - name: Setup NDK run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620" - name: install gomobile - run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed + run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab - name: gomobile init run: gomobile init - name: build android netbird lib @@ -56,9 +56,9 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: install gomobile - run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed + run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab - name: gomobile init run: gomobile init - name: build iOS netbird lib diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e9741f541..a9bc1b979 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ concurrency: jobs: release: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest-m env: flags: "" steps: @@ -40,7 +40,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version-file: "go.mod" cache: false - name: Cache Go modules uses: actions/cache@v4 @@ -136,7 +136,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version-file: "go.mod" cache: false - name: Cache Go modules uses: actions/cache@v4 @@ -200,7 +200,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version-file: "go.mod" cache: false - name: Cache Go modules uses: actions/cache@v4 diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index 3855baba2..f4513e0e1 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -67,10 +67,13 @@ jobs: - name: Install curl run: sudo apt-get install -y curl + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: Cache Go modules uses: actions/cache@v4 @@ -80,9 +83,6 @@ jobs: restore-keys: | ${{ runner.os }}-go- - - name: Checkout code - uses: actions/checkout@v4 - - name: Setup MySQL privileges if: matrix.store == 'mysql' run: | diff --git a/.github/workflows/wasm-build-validation.yml b/.github/workflows/wasm-build-validation.yml index e4ac799bc..4100e16dd 100644 --- a/.github/workflows/wasm-build-validation.yml +++ b/.github/workflows/wasm-build-validation.yml @@ -20,7 +20,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: Install golangci-lint @@ -45,7 +45,7 @@ jobs: - name: Install Go uses: actions/setup-go@v5 with: - go-version: "1.23.x" + go-version-file: "go.mod" - name: Build Wasm client run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd env: @@ -60,8 +60,8 @@ jobs: echo "Size: ${SIZE} bytes (${SIZE_MB} MB)" - if [ ${SIZE} -gt 52428800 ]; then - echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!" + if [ ${SIZE} -gt 57671680 ]; then + echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!" exit 1 fi diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index c19e0acb5..2a92fd6d8 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -197,7 +197,7 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add timeoutMsg += " " + peerInfo } timeoutMsg += fmt.Sprintf(" - error: %v", err) - logger.Warnf(timeoutMsg) + logger.Warn(timeoutMsg) } func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { diff --git a/client/ssh/testutil/user_helpers.go b/client/ssh/testutil/user_helpers.go index 0c1222078..8960d8dd0 100644 --- a/client/ssh/testutil/user_helpers.go +++ b/client/ssh/testutil/user_helpers.go @@ -72,7 +72,8 @@ func IsSystemAccount(username string) bool { return true } } - return false + + return strings.HasSuffix(username, "$") } // RegisterTestUserCleanup registers a test user for cleanup diff --git a/client/ssh/testutil/user_helpers_test.go b/client/ssh/testutil/user_helpers_test.go new file mode 100644 index 000000000..db2f5f06d --- /dev/null +++ b/client/ssh/testutil/user_helpers_test.go @@ -0,0 +1,115 @@ +package testutil + +import ( + "os/user" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestUserCurrentBehavior validates user.Current() behavior on Windows. +// When running as SYSTEM on a domain-joined machine, user.Current() returns: +// - Username: Computer account name (e.g., "DOMAIN\MACHINE$") +// - SID: SYSTEM SID (S-1-5-18) +func TestUserCurrentBehavior(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Windows-specific test") + } + + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + t.Logf("Current user - Username: %s, SID: %s", currentUser.Username, currentUser.Uid) + + // When running as SYSTEM, validate expected behavior + if currentUser.Uid == "S-1-5-18" { + t.Run("SYSTEM_account_behavior", func(t *testing.T) { + // SID must be S-1-5-18 for SYSTEM + require.Equal(t, "S-1-5-18", currentUser.Uid, + "SYSTEM account must have SID S-1-5-18") + + // Username can be either "NT AUTHORITY\SYSTEM" (standalone) + // or "DOMAIN\MACHINE$" (domain-joined) + username := currentUser.Username + isNTAuthority := strings.Contains(strings.ToUpper(username), "NT AUTHORITY") + isComputerAccount := strings.HasSuffix(username, "$") + + assert.True(t, isNTAuthority || isComputerAccount, + "Username should be either 'NT AUTHORITY\\SYSTEM' or computer account (ending with $), got: %s", + username) + + if isComputerAccount { + t.Logf("SYSTEM as computer account: %s", username) + } else if isNTAuthority { + t.Logf("SYSTEM as NT AUTHORITY\\SYSTEM") + } + }) + } + + // Validate that IsSystemAccount correctly identifies system accounts + t.Run("IsSystemAccount_validation", func(t *testing.T) { + // Test with current user if it's a system account + if currentUser.Uid == "S-1-5-18" || // SYSTEM + currentUser.Uid == "S-1-5-19" || // LOCAL SERVICE + currentUser.Uid == "S-1-5-20" { // NETWORK SERVICE + + result := IsSystemAccount(currentUser.Username) + assert.True(t, result, + "IsSystemAccount should recognize system account: %s (SID: %s)", + currentUser.Username, currentUser.Uid) + } + + // Test explicit cases + testCases := []struct { + username string + expected bool + reason string + }{ + {"NT AUTHORITY\\SYSTEM", true, "NT AUTHORITY\\SYSTEM"}, + {"system", true, "system"}, + {"SYSTEM", true, "SYSTEM (case insensitive)"}, + {"NT AUTHORITY\\LOCAL SERVICE", true, "LOCAL SERVICE"}, + {"NT AUTHORITY\\NETWORK SERVICE", true, "NETWORK SERVICE"}, + {"DOMAIN\\MACHINE$", true, "computer account (ends with $)"}, + {"WORKGROUP\\WIN2K19-C2$", true, "computer account (ends with $)"}, + {"Administrator", false, "Administrator is not a system account"}, + {"alice", false, "regular user"}, + {"DOMAIN\\alice", false, "domain user"}, + } + + for _, tc := range testCases { + t.Run(tc.username, func(t *testing.T) { + result := IsSystemAccount(tc.username) + assert.Equal(t, tc.expected, result, + "IsSystemAccount(%q) should be %v because: %s", + tc.username, tc.expected, tc.reason) + }) + } + }) +} + +// TestComputerAccountDetection validates computer account detection. +func TestComputerAccountDetection(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Windows-specific test") + } + + computerAccounts := []string{ + "MACHINE$", + "WIN2K19-C2$", + "DOMAIN\\MACHINE$", + "WORKGROUP\\SERVER$", + "server.domain.com$", + } + + for _, account := range computerAccounts { + t.Run(account, func(t *testing.T) { + result := IsSystemAccount(account) + assert.True(t, result, + "Computer account %q should be recognized as system account", account) + }) + } +} diff --git a/go.mod b/go.mod index 45a36190d..a72d09dc3 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/netbirdio/netbird -go 1.23.1 +go 1.24.10 require ( cunicu.li/go-rosenpass v0.4.0 @@ -17,8 +17,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.41.0 - golang.org/x/sys v0.35.0 + golang.org/x/crypto v0.45.0 + golang.org/x/sys v0.38.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -105,12 +105,12 @@ require ( go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 - golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/mod v0.26.0 - golang.org/x/net v0.42.0 + golang.org/x/mobile v0.0.0-20251113184115-a159579294ab + golang.org/x/mod v0.30.0 + golang.org/x/net v0.47.0 golang.org/x/oauth2 v0.30.0 - golang.org/x/sync v0.16.0 - golang.org/x/term v0.34.0 + golang.org/x/sync v0.18.0 + golang.org/x/term v0.37.0 golang.org/x/time v0.12.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 @@ -251,9 +251,9 @@ require ( go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/image v0.24.0 // indirect - golang.org/x/text v0.28.0 // indirect - golang.org/x/tools v0.35.0 // indirect + golang.org/x/image v0.33.0 // indirect + golang.org/x/text v0.31.0 // indirect + golang.org/x/tools v0.39.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect diff --git a/go.sum b/go.sum index ec68a8f59..011a5f199 100644 --- a/go.sum +++ b/go.sum @@ -600,19 +600,19 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= -golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= -golang.org/x/image v0.24.0 h1:AN7zRgVsbvmTfNyqIbbOraYL8mSwcKncEj8ofjgzcMQ= -golang.org/x/image v0.24.0/go.mod h1:4b/ITuLfqYq1hqZcjofwctIhi7sZh2WaCjvsBNjjya8= +golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ= +golang.org/x/image v0.33.0/go.mod h1:DD3OsTYT9chzuzTQt+zMcOlBHgfoKQb1gry8p76Y1sc= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a h1:sYbmY3FwUWCBTodZL1S3JUuOvaW6kM2o+clDzzDNBWg= -golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a/go.mod h1:Ede7gF0KGoHlj822RtphAHK1jLdrcuRBZg0sF1Q+SPc= +golang.org/x/mobile v0.0.0-20251113184115-a159579294ab h1:Iqyc+2zr7aGyLuEadIm0KRJP0Wwt+fhlXLa51Fxf1+Q= +golang.org/x/mobile v0.0.0-20251113184115-a159579294ab/go.mod h1:Eq3Nh/5pFSWug2ohiudJ1iyU59SO78QFuh4qTTN++I0= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -622,8 +622,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= -golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -647,8 +647,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= -golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= @@ -665,8 +665,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -703,8 +703,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -717,8 +717,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= -golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= -golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -730,8 +730,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -749,8 +749,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index 8c60d6fe8..35198da32 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -46,7 +46,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH testPostureChecks[postureChecks.ID] = postureChecks if err := postureChecks.Validate(); err != nil { - return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint + return nil, status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint } return postureChecks, nil diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index ac8ea35de..9a743eb8c 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -158,7 +158,7 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.St // validatePostureChecks validates the posture checks. func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error { if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return status.Errorf(status.InvalidArgument, "%s", err.Error()) //nolint } // If the posture check already has an ID, verify its existence in the store. From 131d7a36943892ee33631158fb6cd5e878f5a4fd Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 22 Nov 2025 18:57:07 +0100 Subject: [PATCH 087/120] [client] Make mss clamping optional for nftables (#4843) --- client/firewall/nftables/router_linux.go | 35 +++++++++--------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 6192c92aa..e4debc179 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -91,11 +91,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou var err error r.filterTable, err = r.loadFilterTable() if err != nil { - if errors.Is(err, errFilterTableNotFound) { - log.Warnf("table 'filter' not found for forward rules") - } else { - return nil, fmt.Errorf("load filter table: %w", err) - } + log.Warnf("failed to load filter table, skipping accept rules: %v", err) } return r, nil @@ -175,7 +171,7 @@ func (r *router) removeNatPreroutingRules() error { func (r *router) loadFilterTable() (*nftables.Table, error) { tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { - return nil, fmt.Errorf("unable to list tables: %v", err) + return nil, fmt.Errorf("list tables: %w", err) } for _, table := range tables { @@ -193,8 +189,6 @@ func (r *router) createContainers() error { Table: r.workTable, }) - insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) - prio := *nftables.ChainPriorityNATSource - 1 r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingNat, @@ -236,9 +230,12 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeFilter, }) - // Add the single NAT rule that matches on mark - if err := r.addPostroutingRules(); err != nil { - return fmt.Errorf("add single nat rule: %v", err) + insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + + r.addPostroutingRules() + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("initialize tables: %v", err) } if err := r.addMSSClampingRules(); err != nil { @@ -250,11 +247,7 @@ func (r *router) createContainers() error { } if err := r.refreshRulesMap(); err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - - if err := r.conn.Flush(); err != nil { - return fmt.Errorf("initialize tables: %v", err) + log.Errorf("failed to refresh rules: %s", err) } return nil @@ -695,7 +688,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { } // addPostroutingRules adds the masquerade rules -func (r *router) addPostroutingRules() error { +func (r *router) addPostroutingRules() { // First masquerade rule for traffic coming in from WireGuard interface exprs := []expr.Any{ // Match on the first fwmark @@ -761,8 +754,6 @@ func (r *router) addPostroutingRules() error { Chain: r.chains[chainNameRoutingNat], Exprs: exprs2, }) - - return nil } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. @@ -839,7 +830,7 @@ func (r *router) addMSSClampingRules() error { Exprs: exprsOut, }) - return nil + return r.conn.Flush() } // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls @@ -1068,7 +1059,7 @@ func (r *router) acceptFilterRulesNftables() error { } r.conn.InsertRule(inputRule) - return nil + return r.conn.Flush() } func (r *router) removeAcceptFilterRules() error { @@ -1196,7 +1187,7 @@ func (r *router) refreshRulesMap() error { for _, chain := range r.chains { rules, err := r.conn.GetRules(chain.Table, chain) if err != nil { - return fmt.Errorf(" unable to list rules: %v", err) + return fmt.Errorf("list rules: %w", err) } for _, rule := range rules { if len(rule.UserData) > 0 { From ba2e9b6d88e57a760341ac83473ede264ea32204 Mon Sep 17 00:00:00 2001 From: Aziz Hasanain Date: Mon, 24 Nov 2025 14:12:51 +0300 Subject: [PATCH 088/120] [management] Fix SSH JWT issuer derivation for IDPs with path components (#4844) --- management/internals/shared/grpc/conversion.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 148db4e37..2b15fe4b8 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -369,7 +369,7 @@ func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfi } issuer := strings.TrimSpace(config.AuthIssuer) - if issuer == "" || deviceFlowConfig != nil { + if issuer == "" && deviceFlowConfig != nil { if d := deriveIssuerFromTokenEndpoint(deviceFlowConfig.ProviderConfig.TokenEndpoint); d != "" { issuer = d } From 20973063d865414aa551bd3daa73078a66a1ea9b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 24 Nov 2025 17:50:08 +0100 Subject: [PATCH 089/120] [client] Support disable search domain for custom zones (#4826) Two new boolean flags, SearchDomainDisabled and SkipPTRProcess, are added to CustomZone and its protobuf; they are propagated through the engine to DNS host logic. Host matching now uses SearchDomainDisabled directly, and PTR collection skips zones with SkipPTRProcess; reverse zones are initialized with SearchDomainDisabled: true. --- client/internal/dns.go | 8 +- client/internal/dns/host.go | 8 +- client/internal/engine.go | 4 +- dns/dns.go | 4 + shared/management/proto/management.pb.go | 312 ++++++++++++----------- shared/management/proto/management.proto | 2 + 6 files changed, 183 insertions(+), 155 deletions(-) diff --git a/client/internal/dns.go b/client/internal/dns.go index 5e604bec5..3c68e4d00 100644 --- a/client/internal/dns.go +++ b/client/internal/dns.go @@ -76,6 +76,9 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple var records []nbdns.SimpleRecord for _, zone := range config.CustomZones { + if zone.SkipPTRProcess { + continue + } for _, record := range zone.Records { if record.Type != int(dns.TypeA) { continue @@ -106,8 +109,9 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) { records := collectPTRRecords(config, network) reverseZone := nbdns.CustomZone{ - Domain: zoneName, - Records: records, + Domain: zoneName, + Records: records, + SearchDomainDisabled: true, } config.CustomZones = append(config.CustomZones, reverseZone) diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index fa474afde..f7dc46a6b 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -11,11 +11,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" ) -const ( - ipv4ReverseZone = ".in-addr.arpa." - ipv6ReverseZone = ".ip6.arpa." -) - type hostManager interface { applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNS() error @@ -110,10 +105,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H } for _, customZone := range dnsConfig.CustomZones { - matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone) config.Domains = append(config.Domains, DomainConfig{ Domain: strings.ToLower(dns.Fqdn(customZone.Domain)), - MatchOnly: matchOnly, + MatchOnly: customZone.SearchDomainDisabled, }) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 76265dd77..0ff1006cd 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1207,7 +1207,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns for _, zone := range protoDNSConfig.GetCustomZones() { dnsZone := nbdns.CustomZone{ - Domain: zone.GetDomain(), + Domain: zone.GetDomain(), + SearchDomainDisabled: zone.GetSearchDomainDisabled(), + SkipPTRProcess: zone.GetSkipPTRProcess(), } for _, record := range zone.Records { dnsRecord := nbdns.SimpleRecord{ diff --git a/dns/dns.go b/dns/dns.go index cf089d4ed..aa0e16eb1 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -45,6 +45,10 @@ type CustomZone struct { Domain string // Records custom zone records Records []SimpleRecord + // SearchDomainDisabled indicates whether to add match domains to a search domains list or not + SearchDomainDisabled bool + // SkipPTRProcess indicates whether a client should process PTR records from custom zones + SkipPTRProcess bool } // SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 45211b7f4..2e4cf2644 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v6.33.0 +// protoc v6.32.1 // source: management.proto package proto @@ -2682,8 +2682,10 @@ type CustomZone struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"` - Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"` + Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"` + Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"` + SearchDomainDisabled bool `protobuf:"varint,3,opt,name=SearchDomainDisabled,proto3" json:"SearchDomainDisabled,omitempty"` + SkipPTRProcess bool `protobuf:"varint,4,opt,name=SkipPTRProcess,proto3" json:"SkipPTRProcess,omitempty"` } func (x *CustomZone) Reset() { @@ -2732,6 +2734,20 @@ func (x *CustomZone) GetRecords() []*SimpleRecord { return nil } +func (x *CustomZone) GetSearchDomainDisabled() bool { + if x != nil { + return x.SearchDomainDisabled + } + return false +} + +func (x *CustomZone) GetSkipPTRProcess() bool { + if x != nil { + return x.SkipPTRProcess + } + return false +} + // SimpleRecord represents a dns.SimpleRecord type SimpleRecord struct { state protoimpl.MessageState @@ -3893,157 +3909,163 @@ var file_management_proto_rawDesc = []byte{ 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, - 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, - 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, - 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, - 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, - 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, - 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, - 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, - 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, - 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, - 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, - 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, - 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, - 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, - 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, - 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, - 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, - 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, - 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x6f, 0x72, 0x74, 0x22, 0xb4, 0x01, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, + 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x12, 0x32, + 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, + 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, + 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x53, 0x6b, 0x69, 0x70, 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, + 0x63, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x53, 0x6b, 0x69, 0x70, + 0x50, 0x54, 0x52, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, + 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, + 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, + 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, + 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, + 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, + 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, + 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, + 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, + 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, + 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, + 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, + 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, + 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, + 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, + 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, + 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, + 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, + 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, + 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, + 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, + 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, + 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, + 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, + 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, + 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, + 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, + 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, + 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, - 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, - 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, - 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, - 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, - 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, - 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, - 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, - 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, - 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, - 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, - 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, - 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, - 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, - 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, - 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, - 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, - 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, - 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, - 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, - 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, - 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, - 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, - 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, - 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, - 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, - 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, - 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, - 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, - 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, + 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, + 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, + 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, + 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, + 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, + 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, + 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, + 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, + 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, - 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, - 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, - 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, - 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, - 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, - 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, - 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, - 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, - 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, - 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, - 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, - 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, - 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, - 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, - 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, - 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, + 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, + 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, + 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, + 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, + 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, + 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, + 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, + 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, + 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, + 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, + 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, + 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, + 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, + 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, + 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, - 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, - 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, + 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, + 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, + 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, + 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, + 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, + 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, + 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 8a297e75e..dc60b026d 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -433,6 +433,8 @@ message DNSConfig { message CustomZone { string Domain = 1; repeated SimpleRecord Records = 2; + bool SearchDomainDisabled = 3; + bool SkipPTRProcess = 4; } // SimpleRecord represents a dns.SimpleRecord From 7285fef0f041483719b0d926fab70d315e7d3314 Mon Sep 17 00:00:00 2001 From: shuuri-labs <61762328+shuuri-labs@users.noreply.github.com> Date: Tue, 25 Nov 2025 15:51:16 +0100 Subject: [PATCH 090/120] feat: Add support for displaying device code (UserCode) on Android TV SSO flow (#4800) - Modified URLOpener interface to pass userCode alongside URL in login.go - added ability to force device auth flow --- client/android/client.go | 4 ++-- client/android/login.go | 16 ++++++++-------- client/cmd/login.go | 2 +- client/internal/auth/oauth.go | 9 +++++++-- client/ios/NetBirdSDK/client.go | 2 +- client/server/server.go | 4 ++-- 6 files changed, 21 insertions(+), 16 deletions(-) diff --git a/client/android/client.go b/client/android/client.go index 2943702c6..0d5474c4b 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -92,7 +92,7 @@ func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName st } // Run start the internal client. It is a blocker function -func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { +func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { exportEnvList(envList) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, @@ -115,7 +115,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead c.ctxCancelLock.Unlock() auth := NewAuthWithConfig(ctx, cfg) - err = auth.login(urlOpener) + err = auth.login(urlOpener, isAndroidTV) if err != nil { return err } diff --git a/client/android/login.go b/client/android/login.go index 16df24ba8..4d4c7a650 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -32,7 +32,7 @@ type ErrListener interface { // URLOpener it is a callback interface. The Open function will be triggered if // the backend want to show an url for the user type URLOpener interface { - Open(string) + Open(url string, userCode string) OnLoginSuccess() } @@ -148,9 +148,9 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string } // Login try register the client on the server -func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) { +func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) { go func() { - err := a.login(urlOpener) + err := a.login(urlOpener, isAndroidTV) if err != nil { resultListener.OnError(err) } else { @@ -159,7 +159,7 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) { }() } -func (a *Auth) login(urlOpener URLOpener) error { +func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error { var needsLogin bool // check if we need to generate JWT token @@ -173,7 +173,7 @@ func (a *Auth) login(urlOpener URLOpener) error { jwtToken := "" if needsLogin { - tokenInfo, err := a.foregroundGetTokenInfo(urlOpener) + tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV) if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } @@ -199,8 +199,8 @@ func (a *Auth) login(urlOpener URLOpener) error { return nil } -func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "") +func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) { + oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "") if err != nil { return nil, err } @@ -210,7 +210,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) } - go urlOpener.Open(flowInfo.VerificationURIComplete) + go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode) waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout) diff --git a/client/cmd/login.go b/client/cmd/login.go index b0c877faa..2ddcccc8a 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -332,7 +332,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro hint = profileState.Email } - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint) + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint) if err != nil { return nil, err } diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 9fbd6cf5f..85a166005 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -60,14 +60,19 @@ func (t TokenInfo) GetTokenToUse() string { return t.AccessToken } +func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool { + return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient +} + // NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration // // It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow, // and if that also fails, the authentication process is deemed unsuccessful // // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow -func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) { - if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { +// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV) +func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) { + if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) { return authenticateWithDeviceCodeFlow(ctx, config, hint) } diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index b0d377c21..6d969bb12 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string { ConfigPath: c.cfgFile, }) - oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "") + oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "") if err != nil { return err.Error() } diff --git a/client/server/server.go b/client/server/server.go index a930e8a02..49000c092 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -504,7 +504,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro if msg.Hint != nil { hint = *msg.Hint } - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint) + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, false, hint) if err != nil { state.Set(internal.StatusLoginFailed) return nil, err @@ -1235,7 +1235,7 @@ func (s *Server) RequestJWTAuth( } isDesktop := isUnixRunningDesktop() - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, hint) + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, false, hint) if err != nil { return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err) } From f31bba87b44c19472bed2162e3e189bf335127c9 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 26 Nov 2025 17:07:44 +0300 Subject: [PATCH 091/120] [management] Preserve validator settings on account settings update (#4862) --- management/server/account.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/management/server/account.go b/management/server/account.go index 3e498536c..716d5ab5d 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -325,6 +325,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } } + newSettings.Extra.IntegratedValidatorGroups = oldSettings.Extra.IntegratedValidatorGroups + newSettings.Extra.IntegratedValidator = oldSettings.Extra.IntegratedValidator + if err = transaction.SaveAccountSettings(ctx, accountID, newSettings); err != nil { return err } From 02200d790bb8a5466f3b824ea6c6c2e624240b4a Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:06:47 +0100 Subject: [PATCH 092/120] [client] Open browser for ssh automatically (#4838) --- client/cmd/login.go | 12 +----------- client/cmd/ssh.go | 33 ++++++++++++++++++++++++++++++- client/ssh/client/client.go | 18 ++++++++++++----- client/ssh/common.go | 36 ++++++++++++++++++++++++++++------ client/ssh/proxy/proxy.go | 30 +++++++++++++++------------- client/ssh/proxy/proxy_test.go | 2 +- util/common.go | 15 +++++++++++++- 7 files changed, 107 insertions(+), 39 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index 2ddcccc8a..a34bb7c70 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -4,14 +4,12 @@ import ( "context" "fmt" "os" - "os/exec" "os/user" "runtime" "strings" "time" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" @@ -373,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro cmd.Println("") if !noBrowser { - if err := openBrowser(verificationURIComplete); err != nil { + if err := util.OpenBrowser(verificationURIComplete); err != nil { cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } } -// openBrowser opens the URL in a browser, respecting the BROWSER environment variable. -func openBrowser(url string) error { - if browser := os.Getenv("BROWSER"); browser != "" { - return exec.Command(browser, url).Start() - } - return open.Run(url) -} - // isUnixRunningDesktop checks if a Linux OS is running desktop environment func isUnixRunningDesktop() bool { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 92857c637..525bcdef1 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -51,6 +51,7 @@ var ( identityFile string skipCachedToken bool requestPTY bool + sshNoBrowser bool ) var ( @@ -81,6 +82,7 @@ func init() { sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)") _ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used") sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") + sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc) sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport") sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport") @@ -185,6 +187,21 @@ func getEnvOrDefault(flagName, defaultValue string) string { return defaultValue } +// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes +func getBoolEnvOrDefault(flagName string, defaultValue bool) bool { + if envValue := os.Getenv("WT_" + flagName); envValue != "" { + if parsed, err := strconv.ParseBool(envValue); err == nil { + return parsed + } + } + if envValue := os.Getenv("NB_" + flagName); envValue != "" { + if parsed, err := strconv.ParseBool(envValue); err == nil { + return parsed + } + } + return defaultValue +} + // resetSSHGlobals sets SSH globals to their default values func resetSSHGlobals() { port = sshserver.DefaultSSHPort @@ -196,6 +213,7 @@ func resetSSHGlobals() { strictHostKeyChecking = true knownHostsFile = "" identityFile = "" + sshNoBrowser = false } // parseCustomSSHFlags extracts -L, -R flags and returns filtered args @@ -370,6 +388,7 @@ type sshFlags struct { KnownHostsFile string IdentityFile string SkipCachedToken bool + NoBrowser bool ConfigPath string LogLevel string LocalForwards []string @@ -381,6 +400,7 @@ type sshFlags struct { func createSSHFlagSet() (*flag.FlagSet, *sshFlags) { defaultConfigPath := getEnvOrDefault("CONFIG", configPath) defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false) fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError) fs.SetOutput(nil) @@ -401,6 +421,7 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) { fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file") fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file") fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") + fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc) fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location") fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location") @@ -449,6 +470,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error { knownHostsFile = flags.KnownHostsFile identityFile = flags.IdentityFile skipCachedToken = flags.SkipCachedToken + sshNoBrowser = flags.NoBrowser if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) { configPath = flags.ConfigPath @@ -508,6 +530,7 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error { DaemonAddr: daemonAddr, SkipCachedToken: skipCachedToken, InsecureSkipVerify: !strictHostKeyChecking, + NoBrowser: sshNoBrowser, }) if err != nil { @@ -763,7 +786,15 @@ func sshProxyFn(cmd *cobra.Command, args []string) error { return fmt.Errorf("invalid port: %s", portStr) } - proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr()) + // Check env var for browser setting since this command is invoked via SSH ProxyCommand + // where command-line flags cannot be passed. Default is to open browser. + noBrowser := getBoolEnvOrDefault("NO_BROWSER", false) + var browserOpener func(string) error + if !noBrowser { + browserOpener = util.OpenBrowser + } + + proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener) if err != nil { return fmt.Errorf("create SSH proxy: %w", err) } diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 31b80317a..aab222093 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/proto" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/util" ) const ( @@ -278,6 +279,7 @@ type DialOptions struct { DaemonAddr string SkipCachedToken bool InsecureSkipVerify bool + NoBrowser bool } // Dial connects to the given ssh server with specified options @@ -307,7 +309,7 @@ func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, er config.Auth = append(config.Auth, authMethod) } - return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken) + return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken, opts.NoBrowser) } // dialSSH establishes an SSH connection without JWT authentication @@ -333,7 +335,7 @@ func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig } // dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection -func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) { +func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache, noBrowser bool) (*Client, error) { host, portStr, err := net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf("parse address %s: %w", addr, err) @@ -359,7 +361,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout) defer cancel() - jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache) + jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache, noBrowser) if err != nil { return nil, fmt.Errorf("request JWT token: %w", err) } @@ -369,7 +371,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo } // requestJWTToken requests a JWT token from the NetBird daemon -func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) { +func requestJWTToken(ctx context.Context, daemonAddr string, skipCache, noBrowser bool) (string, error) { hint := profilemanager.GetLoginHint() conn, err := connectToDaemon(daemonAddr) @@ -379,7 +381,13 @@ func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (st defer conn.Close() client := proto.NewDaemonServiceClient(conn) - return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint) + + var browserOpener func(string) error + if !noBrowser { + browserOpener = util.OpenBrowser + } + + return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint, browserOpener) } // verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon diff --git a/client/ssh/common.go b/client/ssh/common.go index 3beb12806..6574437b5 100644 --- a/client/ssh/common.go +++ b/client/ssh/common.go @@ -67,8 +67,31 @@ func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKe return VerifyHostKey(storedKeyData, presentedKey, peerAddress) } +// printAuthInstructions prints authentication instructions to stderr +func printAuthInstructions(stderr io.Writer, authResponse *proto.RequestJWTAuthResponse, browserWillOpen bool) { + _, _ = fmt.Fprintln(stderr, "SSH authentication required.") + + if browserWillOpen { + _, _ = fmt.Fprintln(stderr, "Please do the SSO login in your browser.") + _, _ = fmt.Fprintln(stderr, "If your browser didn't open automatically, use this URL to log in:") + _, _ = fmt.Fprintln(stderr) + } + + _, _ = fmt.Fprintf(stderr, "%s\n", authResponse.VerificationURIComplete) + + if authResponse.UserCode != "" { + _, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode) + } + + if browserWillOpen { + _, _ = fmt.Fprintln(stderr) + } + + _, _ = fmt.Fprintln(stderr, "Waiting for authentication...") +} + // RequestJWTToken requests or retrieves a JWT token for SSH authentication -func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) { +func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string, openBrowser func(string) error) (string, error) { req := &proto.RequestJWTAuthRequest{} if hint != "" { req.Hint = &hint @@ -84,12 +107,13 @@ func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdo } if stderr != nil { - _, _ = fmt.Fprintln(stderr, "SSH authentication required.") - _, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete) - if authResponse.UserCode != "" { - _, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode) + printAuthInstructions(stderr, authResponse, openBrowser != nil) + } + + if openBrowser != nil { + if err := openBrowser(authResponse.VerificationURIComplete); err != nil { + log.Debugf("open browser: %v", err) } - _, _ = fmt.Fprintln(stderr, "Waiting for authentication...") } tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{ diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index bc8a84b89..4e807e33c 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -35,15 +35,16 @@ const ( ) type SSHProxy struct { - daemonAddr string - targetHost string - targetPort int - stderr io.Writer - conn *grpc.ClientConn - daemonClient proto.DaemonServiceClient + daemonAddr string + targetHost string + targetPort int + stderr io.Writer + conn *grpc.ClientConn + daemonClient proto.DaemonServiceClient + browserOpener func(string) error } -func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) { +func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) { grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://") grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { @@ -51,12 +52,13 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHP } return &SSHProxy{ - daemonAddr: daemonAddr, - targetHost: targetHost, - targetPort: targetPort, - stderr: stderr, - conn: grpcConn, - daemonClient: proto.NewDaemonServiceClient(grpcConn), + daemonAddr: daemonAddr, + targetHost: targetHost, + targetPort: targetPort, + stderr: stderr, + conn: grpcConn, + daemonClient: proto.NewDaemonServiceClient(grpcConn), + browserOpener: browserOpener, }, nil } @@ -70,7 +72,7 @@ func (p *SSHProxy) Close() error { func (p *SSHProxy) Connect(ctx context.Context) error { hint := profilemanager.GetLoginHint() - jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint) + jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint, p.browserOpener) if err != nil { return fmt.Errorf(jwtAuthErrorMsg, err) } diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go index c5036da37..582f9c07b 100644 --- a/client/ssh/proxy/proxy_test.go +++ b/client/ssh/proxy/proxy_test.go @@ -153,7 +153,7 @@ func TestSSHProxy_Connect(t *testing.T) { validToken := generateValidJWT(t, privateKey, issuer, audience) mockDaemon.setJWTToken(validToken) - proxyInstance, err := New(mockDaemon.addr, host, port, nil) + proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil) require.NoError(t, err) clientConn, proxyConn := net.Pipe() diff --git a/util/common.go b/util/common.go index 27adb9d13..89903b609 100644 --- a/util/common.go +++ b/util/common.go @@ -1,6 +1,19 @@ package util -import "os" +import ( + "os" + "os/exec" + + "github.com/skratchdot/open-golang/open" +) + +// OpenBrowser opens the URL in a browser, respecting the BROWSER environment variable. +func OpenBrowser(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + return open.Run(url) +} // SliceDiff returns the elements in slice `x` that are not in slice `y` func SliceDiff(x, y []string) []string { From aca0398105fd0662c09533e1368a8682310efd94 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 26 Nov 2025 16:07:45 +0100 Subject: [PATCH 093/120] [client] Add excluded port range handling for PKCE flow (#4853) --- client/internal/auth/pkce_flow.go | 46 +++++- client/internal/auth/pkce_flow_other.go | 8 + client/internal/auth/pkce_flow_test.go | 147 ++++++++++++++++++ client/internal/auth/pkce_flow_windows.go | 86 ++++++++++ .../internal/auth/pkce_flow_windows_test.go | 116 ++++++++++++++ 5 files changed, 398 insertions(+), 5 deletions(-) create mode 100644 client/internal/auth/pkce_flow_other.go create mode 100644 client/internal/auth/pkce_flow_windows.go create mode 100644 client/internal/auth/pkce_flow_windows_test.go diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 48873f640..39da2a79f 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "net/url" + "strconv" "strings" "time" @@ -46,9 +47,10 @@ type PKCEAuthorizationFlow struct { func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) { var availableRedirectURL string - // find the first available redirect URL + excludedRanges := getSystemExcludedPortRanges() + for _, redirectURL := range config.RedirectURLs { - if !isRedirectURLPortUsed(redirectURL) { + if !isRedirectURLPortUsed(redirectURL, excludedRanges) { availableRedirectURL = redirectURL break } @@ -282,15 +284,22 @@ func createCodeChallenge(codeVerifier string) string { return base64.RawURLEncoding.EncodeToString(sha2[:]) } -// isRedirectURLPortUsed checks if the port used in the redirect URL is in use. -func isRedirectURLPortUsed(redirectURL string) bool { +// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows. +func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool { parsedURL, err := url.Parse(redirectURL) if err != nil { log.Errorf("failed to parse redirect URL: %v", err) return true } - addr := fmt.Sprintf(":%s", parsedURL.Port()) + port := parsedURL.Port() + + if isPortInExcludedRange(port, excludedRanges) { + log.Warnf("port %s is in Windows excluded port range, skipping", port) + return true + } + + addr := fmt.Sprintf(":%s", port) conn, err := net.DialTimeout("tcp", addr, 3*time.Second) if err != nil { return false @@ -304,6 +313,33 @@ func isRedirectURLPortUsed(redirectURL string) bool { return true } +// excludedPortRange represents a range of excluded ports. +type excludedPortRange struct { + start int + end int +} + +// isPortInExcludedRange checks if the given port is in any of the excluded ranges. +func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool { + if len(excludedRanges) == 0 { + return false + } + + portNum, err := strconv.Atoi(port) + if err != nil { + log.Debugf("invalid port number %s: %v", port, err) + return false + } + + for _, r := range excludedRanges { + if portNum >= r.start && portNum <= r.end { + return true + } + } + + return false +} + func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) { tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl) if err != nil { diff --git a/client/internal/auth/pkce_flow_other.go b/client/internal/auth/pkce_flow_other.go new file mode 100644 index 000000000..96df41539 --- /dev/null +++ b/client/internal/auth/pkce_flow_other.go @@ -0,0 +1,8 @@ +//go:build !windows + +package auth + +// getSystemExcludedPortRanges returns nil on non-Windows platforms. +func getSystemExcludedPortRanges() []excludedPortRange { + return nil +} diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go index b2347d12d..380a360e5 100644 --- a/client/internal/auth/pkce_flow_test.go +++ b/client/internal/auth/pkce_flow_test.go @@ -2,8 +2,11 @@ package auth import ( "context" + "fmt" + "net" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/internal" @@ -69,3 +72,147 @@ func TestPromptLogin(t *testing.T) { }) } } + +func TestIsPortInExcludedRange(t *testing.T) { + tests := []struct { + name string + port string + excludedRanges []excludedPortRange + expectedBlocked bool + }{ + { + name: "Port in excluded range", + port: "8080", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: true, + }, + { + name: "Port at start of range", + port: "8000", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: true, + }, + { + name: "Port at end of range", + port: "8100", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: true, + }, + { + name: "Port before range", + port: "7999", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + { + name: "Port after range", + port: "8101", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + { + name: "Empty excluded ranges", + port: "8080", + excludedRanges: []excludedPortRange{}, + expectedBlocked: false, + }, + { + name: "Nil excluded ranges", + port: "8080", + excludedRanges: nil, + expectedBlocked: false, + }, + { + name: "Multiple ranges - port in second range", + port: "9050", + excludedRanges: []excludedPortRange{ + {start: 8000, end: 8100}, + {start: 9000, end: 9100}, + }, + expectedBlocked: true, + }, + { + name: "Multiple ranges - port not in any range", + port: "8500", + excludedRanges: []excludedPortRange{ + {start: 8000, end: 8100}, + {start: 9000, end: 9100}, + }, + expectedBlocked: false, + }, + { + name: "Invalid port string", + port: "invalid", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + { + name: "Empty port string", + port: "", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedBlocked: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isPortInExcludedRange(tt.port, tt.excludedRanges) + assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch") + }) + } +} + +func TestIsRedirectURLPortUsed(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + _ = listener.Close() + }() + + usedPort := listener.Addr().(*net.TCPAddr).Port + + tests := []struct { + name string + redirectURL string + excludedRanges []excludedPortRange + expectedUsed bool + }{ + { + name: "Port in excluded range", + redirectURL: "http://127.0.0.1:8080/", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedUsed: true, + }, + { + name: "Port actually in use", + redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort), + excludedRanges: nil, + expectedUsed: true, + }, + { + name: "Port not in use and not excluded", + redirectURL: "http://127.0.0.1:65432/", + excludedRanges: nil, + expectedUsed: false, + }, + { + name: "Invalid URL without port", + redirectURL: "not-a-valid-url", + excludedRanges: nil, + expectedUsed: false, + }, + { + name: "Port excluded even if not in use", + redirectURL: "http://127.0.0.1:8050/", + excludedRanges: []excludedPortRange{{start: 8000, end: 8100}}, + expectedUsed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges) + assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch") + }) + } +} diff --git a/client/internal/auth/pkce_flow_windows.go b/client/internal/auth/pkce_flow_windows.go new file mode 100644 index 000000000..cf3f8718f --- /dev/null +++ b/client/internal/auth/pkce_flow_windows.go @@ -0,0 +1,86 @@ +//go:build windows + +package auth + +import ( + "bufio" + "fmt" + "os/exec" + "strconv" + "strings" + + log "github.com/sirupsen/logrus" +) + +// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh. +func getSystemExcludedPortRanges() []excludedPortRange { + ranges, err := getExcludedPortRangesFromNetsh() + if err != nil { + log.Debugf("failed to get Windows excluded port ranges: %v", err) + return nil + } + + return ranges +} + +// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command. +func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) { + cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp") + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("netsh command: %w", err) + } + + return parseExcludedPortRanges(string(output)) +} + +// parseExcludedPortRanges parses the output of the netsh command to extract port ranges. +func parseExcludedPortRanges(output string) ([]excludedPortRange, error) { + var ranges []excludedPortRange + scanner := bufio.NewScanner(strings.NewReader(output)) + + foundHeader := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") { + foundHeader = true + continue + } + + if !foundHeader { + continue + } + + if strings.Contains(line, "----------") { + continue + } + + if line == "" { + continue + } + + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + + startPort, err := strconv.Atoi(fields[0]) + if err != nil { + continue + } + + endPort, err := strconv.Atoi(fields[1]) + if err != nil { + continue + } + + ranges = append(ranges, excludedPortRange{start: startPort, end: endPort}) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("scan output: %w", err) + } + + return ranges, nil +} diff --git a/client/internal/auth/pkce_flow_windows_test.go b/client/internal/auth/pkce_flow_windows_test.go new file mode 100644 index 000000000..dd455b2fe --- /dev/null +++ b/client/internal/auth/pkce_flow_windows_test.go @@ -0,0 +1,116 @@ +//go:build windows + +package auth + +import ( + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal" +) + +func TestParseExcludedPortRanges(t *testing.T) { + tests := []struct { + name string + netshOutput string + expectedRanges []excludedPortRange + expectError bool + }{ + { + name: "Valid netsh output with multiple ranges", + netshOutput: ` +Protocol tcp Dynamic Port Range +--------------------------------- +Start Port : 49152 +Number of Ports : 16384 + +Protocol tcp Excluded Port Ranges +--------------------------------- +Start Port End Port +---------- -------- + 5357 5357 * + 50000 50059 * +`, + expectedRanges: []excludedPortRange{ + {start: 5357, end: 5357}, + {start: 50000, end: 50059}, + }, + expectError: false, + }, + { + name: "Empty output", + netshOutput: ` +Protocol tcp Dynamic Port Range +--------------------------------- +Start Port : 49152 +Number of Ports : 16384 +`, + expectedRanges: nil, + expectError: false, + }, + { + name: "Single range", + netshOutput: ` +Protocol tcp Excluded Port Ranges +--------------------------------- +Start Port End Port +---------- -------- + 8080 8090 +`, + expectedRanges: []excludedPortRange{ + {start: 8080, end: 8090}, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ranges, err := parseExcludedPortRanges(tt.netshOutput) + + if tt.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedRanges, ranges) + } + }) + } +} + +func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) { + ranges := getSystemExcludedPortRanges() + t.Logf("Found %d excluded port ranges on this system", len(ranges)) + + listener1, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + _ = listener1.Close() + }() + usedPort1 := listener1.Addr().(*net.TCPAddr).Port + + availablePort := 65432 + + config := internal.PKCEAuthProviderConfig{ + ClientID: "test-client-id", + Audience: "test-audience", + TokenEndpoint: "https://test-token-endpoint.com/token", + Scope: "openid email profile", + AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize", + RedirectURLs: []string{ + fmt.Sprintf("http://127.0.0.1:%d/", usedPort1), + fmt.Sprintf("http://127.0.0.1:%d/", availablePort), + }, + UseIDToken: true, + } + + flow, err := NewPKCEAuthorizationFlow(config) + require.NoError(t, err) + require.NotNil(t, flow) + assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort), + "Should skip port in use and select available port") +} From ddcd182859116b7ac30a7934fcc7e40fba818591 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 28 Nov 2025 17:26:22 +0100 Subject: [PATCH 094/120] [client] Sleep detection on macOS (#4859) A macOS-specific sleep detection mechanism using IOKit and CoreFoundation via cgo is introduced, with a fallback implementation for unsupported platforms. A public Service wrapper provides an event-driven API translating system sleep/wake events into gRPC calls. The UI client integrates sleep detection to manage connectivity state based on system sleep status. --- client/internal/sleep/detector_darwin.go | 218 ++++++++++++++++++ .../internal/sleep/detector_notsupported.go | 9 + client/internal/sleep/service.go | 36 +++ client/proto/daemon.pb.go | 2 +- client/ui/client_ui.go | 69 +++++- 5 files changed, 329 insertions(+), 5 deletions(-) create mode 100644 client/internal/sleep/detector_darwin.go create mode 100644 client/internal/sleep/detector_notsupported.go create mode 100644 client/internal/sleep/service.go diff --git a/client/internal/sleep/detector_darwin.go b/client/internal/sleep/detector_darwin.go new file mode 100644 index 000000000..3d6747ed1 --- /dev/null +++ b/client/internal/sleep/detector_darwin.go @@ -0,0 +1,218 @@ +//go:build darwin && !ios + +package sleep + +/* +#cgo LDFLAGS: -framework IOKit -framework CoreFoundation +#include +#include +#include + +extern void sleepCallbackBridge(); +extern void poweredOnCallbackBridge(); +extern void suspendedCallbackBridge(); +extern void resumedCallbackBridge(); + + +// C global variables for IOKit state +static IONotificationPortRef g_notifyPortRef = NULL; +static io_object_t g_notifierObject = 0; +static io_object_t g_generalInterestNotifier = 0; +static io_connect_t g_rootPort = 0; +static CFRunLoopRef g_runLoop = NULL; + +static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) { + switch (messageType) { + case kIOMessageSystemWillSleep: + sleepCallbackBridge(); + IOAllowPowerChange(g_rootPort, (long)messageArgument); + break; + case kIOMessageSystemHasPoweredOn: + poweredOnCallbackBridge(); + break; + case kIOMessageServiceIsSuspended: + suspendedCallbackBridge(); + break; + case kIOMessageServiceIsResumed: + resumedCallbackBridge(); + break; + default: + break; + } +} + +static void registerNotifications() { + g_rootPort = IORegisterForSystemPower( + NULL, + &g_notifyPortRef, + (IOServiceInterestCallback)sleepCallback, + &g_notifierObject + ); + + if (g_rootPort == 0) { + return; + } + + CFRunLoopAddSource(CFRunLoopGetCurrent(), + IONotificationPortGetRunLoopSource(g_notifyPortRef), + kCFRunLoopCommonModes); + + g_runLoop = CFRunLoopGetCurrent(); + CFRunLoopRun(); +} + +static void unregisterNotifications() { + CFRunLoopRemoveSource(g_runLoop, + IONotificationPortGetRunLoopSource(g_notifyPortRef), + kCFRunLoopCommonModes); + + IODeregisterForSystemPower(&g_notifierObject); + IOServiceClose(g_rootPort); + IONotificationPortDestroy(g_notifyPortRef); + CFRunLoopStop(g_runLoop); + + g_notifyPortRef = NULL; + g_notifierObject = 0; + g_rootPort = 0; + g_runLoop = NULL; +} + +*/ +import "C" + +import ( + "context" + "fmt" + "runtime" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +var ( + serviceRegistry = make(map[*Detector]struct{}) + serviceRegistryMu sync.Mutex +) + +//export sleepCallbackBridge +func sleepCallbackBridge() { + log.Info("sleepCallbackBridge event triggered") + + serviceRegistryMu.Lock() + defer serviceRegistryMu.Unlock() + + for svc := range serviceRegistry { + svc.triggerCallback(EventTypeSleep) + } +} + +//export resumedCallbackBridge +func resumedCallbackBridge() { + log.Info("resumedCallbackBridge event triggered") +} + +//export suspendedCallbackBridge +func suspendedCallbackBridge() { + log.Info("suspendedCallbackBridge event triggered") +} + +//export poweredOnCallbackBridge +func poweredOnCallbackBridge() { + log.Info("poweredOnCallbackBridge event triggered") + serviceRegistryMu.Lock() + defer serviceRegistryMu.Unlock() + + for svc := range serviceRegistry { + svc.triggerCallback(EventTypeWakeUp) + } +} + +type Detector struct { + callback func(event EventType) + ctx context.Context + cancel context.CancelFunc +} + +func NewDetector() (*Detector, error) { + return &Detector{}, nil +} + +func (d *Detector) Register(callback func(event EventType)) error { + serviceRegistryMu.Lock() + defer serviceRegistryMu.Unlock() + + if _, exists := serviceRegistry[d]; exists { + return fmt.Errorf("detector service already registered") + } + + d.callback = callback + + d.ctx, d.cancel = context.WithCancel(context.Background()) + + if len(serviceRegistry) > 0 { + serviceRegistry[d] = struct{}{} + return nil + } + + serviceRegistry[d] = struct{}{} + + // CFRunLoop must run on a single fixed OS thread + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + C.registerNotifications() + }() + + log.Info("sleep detection service started on macOS") + return nil +} + +// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down +// and the runloop is stopped and cleaned up. +func (d *Detector) Deregister() error { + serviceRegistryMu.Lock() + defer serviceRegistryMu.Unlock() + _, exists := serviceRegistry[d] + if !exists { + return nil + } + + // cancel and remove this detector + d.cancel() + delete(serviceRegistry, d) + + // If other Detectors still exist, leave IOKit running + if len(serviceRegistry) > 0 { + return nil + } + + log.Info("sleep detection service stopping (deregister)") + + // Deregister IOKit notifications, stop runloop, and free resources + C.unregisterNotifications() + + return nil +} + +func (d *Detector) triggerCallback(event EventType) { + doneChan := make(chan struct{}) + + timeout := time.NewTimer(500 * time.Millisecond) + defer timeout.Stop() + + cb := d.callback + go func(callback func(event EventType)) { + log.Info("sleep detection event fired") + callback(event) + close(doneChan) + }(cb) + + select { + case <-doneChan: + case <-d.ctx.Done(): + case <-timeout.C: + log.Warnf("sleep callback timed out") + } +} diff --git a/client/internal/sleep/detector_notsupported.go b/client/internal/sleep/detector_notsupported.go new file mode 100644 index 000000000..6323bf5d1 --- /dev/null +++ b/client/internal/sleep/detector_notsupported.go @@ -0,0 +1,9 @@ +//go:build !darwin || ios + +package sleep + +import "fmt" + +func NewDetector() (detector, error) { + return nil, fmt.Errorf("sleep not supported on this platform") +} diff --git a/client/internal/sleep/service.go b/client/internal/sleep/service.go new file mode 100644 index 000000000..35fc933c0 --- /dev/null +++ b/client/internal/sleep/service.go @@ -0,0 +1,36 @@ +package sleep + +var ( + EventTypeSleep EventType = 0 + EventTypeWakeUp EventType = 1 +) + +type EventType int + +type detector interface { + Register(callback func(eventType EventType)) error + Deregister() error +} + +type Service struct { + detector detector +} + +func New() (*Service, error) { + d, err := NewDetector() + if err != nil { + return nil, err + } + + return &Service{ + detector: d, + }, nil +} + +func (s *Service) Register(callback func(eventType EventType)) error { + return s.detector.Register(callback) +} + +func (s *Service) Deregister() error { + return s.detector.Deregister() +} diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 7b9ae25f7..6f8255615 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v6.32.1 +// protoc v3.21.12 // source: daemon.proto package proto diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 44643616d..57d0e74a0 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -38,6 +38,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/internal/sleep" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ui/desktop" "github.com/netbirdio/netbird/client/ui/event" @@ -209,10 +210,11 @@ var iconConnectedDot []byte var iconDisconnectedDot []byte type serviceClient struct { - ctx context.Context - cancel context.CancelFunc - addr string - conn proto.DaemonServiceClient + ctx context.Context + cancel context.CancelFunc + addr string + conn proto.DaemonServiceClient + connLock sync.Mutex eventHandler *eventHandler @@ -1098,6 +1100,9 @@ func (s *serviceClient) onTrayReady() { go s.eventManager.Start(s.ctx) go s.eventHandler.listen(s.ctx) + + // Start sleep detection listener + go s.startSleepListener() } func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File { @@ -1134,6 +1139,8 @@ func (s *serviceClient) onTrayExit() { // getSrvClient connection to the service. func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonServiceClient, error) { + s.connLock.Lock() + defer s.connLock.Unlock() if s.conn != nil { return s.conn, nil } @@ -1156,6 +1163,60 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService return s.conn, nil } +// startSleepListener initializes the sleep detection service and listens for sleep events +func (s *serviceClient) startSleepListener() { + sleepService, err := sleep.New() + if err != nil { + log.Warnf("%v", err) + return + } + + if err := sleepService.Register(s.handleSleepEvents); err != nil { + log.Errorf("failed to start sleep detection: %v", err) + return + } + + log.Info("sleep detection service initialized") + + // Cleanup on context cancellation + go func() { + <-s.ctx.Done() + log.Info("stopping sleep event listener") + if err := sleepService.Deregister(); err != nil { + log.Errorf("failed to deregister sleep detection: %v", err) + } + }() +} + +// handleSleepEvents sends a sleep notification to the daemon via gRPC +func (s *serviceClient) handleSleepEvents(event sleep.EventType) { + conn, err := s.getSrvClient(0) + if err != nil { + log.Errorf("failed to get daemon client for sleep notification: %v", err) + return + } + + switch event { + case sleep.EventTypeWakeUp: + log.Infof("handle wakeup event: %v", event) + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + log.Errorf("up service: %v", err) + return + } + return + case sleep.EventTypeSleep: + log.Infof("handle sleep event: %v", event) + _, err = conn.Down(s.ctx, &proto.DownRequest{}) + if err != nil { + log.Errorf("down service: %v", err) + return + } + } + + log.Info("successfully notified daemon about sleep/wakeup event") +} + // setSettingsEnabled enables or disables the settings menu based on the provided state func (s *serviceClient) setSettingsEnabled(enabled bool) { if s.mSettings != nil { From cb83b7c0d3ddbacdd6ff391bdcc405f8b2ad00f0 Mon Sep 17 00:00:00 2001 From: shuuri-labs <61762328+shuuri-labs@users.noreply.github.com> Date: Fri, 28 Nov 2025 21:53:53 +0100 Subject: [PATCH 095/120] [relay] use exposed address for healthcheck TLS validation (#4872) * fix(relay): use exposed address for healthcheck TLS validation Healthcheck was using listen address (0.0.0.0) instead of exposed address (domain name) for certificate validation, causing validation to always fail. Now correctly uses the exposed address where the TLS certificate is valid, matching real client connection behavior. * - store exposedAddress directly in Relay struct instead of parsing on every call - remove unused parseHostPort() function - remove unused ListenAddress() method from ServiceChecker interface - improve error logging with address context * [relay/healthcheck] Remove QUIC health check logic, update WebSocket validation flow Refactored health check logic by removing QUIC-specific connection validation and simplifying logic for WebSocket protocol. Adjusted certificate validation flow and improved handling of exposed addresses. * [relay/healthcheck] Fix certificate validation status during health check --------- Co-authored-by: Maycon Santos --- relay/healthcheck/healthcheck.go | 44 ++++++++++++-------------------- relay/healthcheck/quic.go | 31 ---------------------- relay/healthcheck/ws.go | 12 +++++++-- relay/server/relay.go | 27 ++++++++++++-------- relay/server/server.go | 8 ++---- 5 files changed, 46 insertions(+), 76 deletions(-) delete mode 100644 relay/healthcheck/quic.go diff --git a/relay/healthcheck/healthcheck.go b/relay/healthcheck/healthcheck.go index eedd62394..6463843eb 100644 --- a/relay/healthcheck/healthcheck.go +++ b/relay/healthcheck/healthcheck.go @@ -6,14 +6,13 @@ import ( "errors" "net" "net/http" + "strings" "sync" "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/protocol" - "github.com/netbirdio/netbird/relay/server/listener/quic" - "github.com/netbirdio/netbird/relay/server/listener/ws" ) const ( @@ -27,7 +26,7 @@ const ( type ServiceChecker interface { ListenerProtocols() []protocol.Protocol - ListenAddress() string + ExposedAddress() string } type HealthStatus struct { @@ -135,7 +134,11 @@ func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) { } status.Listeners = listeners - if ok := s.validateCertificate(ctx); !ok { + if !strings.HasPrefix(s.config.ServiceChecker.ExposedAddress(), "rels") { + status.CertificateValid = false + } + + if ok := s.validateConnection(ctx); !ok { status.Status = statusUnhealthy status.CertificateValid = false healthy = false @@ -152,32 +155,18 @@ func (s *Server) validateListeners() ([]protocol.Protocol, bool) { return listeners, true } -func (s *Server) validateCertificate(ctx context.Context) bool { - listenAddress := s.config.ServiceChecker.ListenAddress() - if listenAddress == "" { - log.Warn("listen address is empty") +func (s *Server) validateConnection(ctx context.Context) bool { + exposedAddress := s.config.ServiceChecker.ExposedAddress() + if exposedAddress == "" { + log.Error("exposed address is empty, cannot validate certificate") return false } - dAddr := dialAddress(listenAddress) - - for _, proto := range s.config.ServiceChecker.ListenerProtocols() { - switch proto { - case ws.Proto: - if err := dialWS(ctx, dAddr); err != nil { - log.Errorf("failed to dial WebSocket listener: %v", err) - return false - } - case quic.Proto: - if err := dialQUIC(ctx, dAddr); err != nil { - log.Errorf("failed to dial QUIC listener: %v", err) - return false - } - default: - log.Warnf("unknown protocol for healthcheck: %s", proto) - return false - } + if err := dialWS(ctx, exposedAddress); err != nil { + log.Errorf("failed to dial WebSocket listener at %s: %v", exposedAddress, err) + return false } + return true } @@ -187,8 +176,9 @@ func dialAddress(listenAddress string) string { return listenAddress // fallback, might be invalid for dialing } + // When listening on all interfaces, show localhost for better readability if host == "" || host == "::" || host == "0.0.0.0" { - host = "0.0.0.0" + host = "localhost" } return net.JoinHostPort(host, port) diff --git a/relay/healthcheck/quic.go b/relay/healthcheck/quic.go deleted file mode 100644 index 1582edf7b..000000000 --- a/relay/healthcheck/quic.go +++ /dev/null @@ -1,31 +0,0 @@ -package healthcheck - -import ( - "context" - "crypto/tls" - "fmt" - "time" - - "github.com/quic-go/quic-go" - - tlsnb "github.com/netbirdio/netbird/shared/relay/tls" -) - -func dialQUIC(ctx context.Context, address string) error { - tlsConfig := &tls.Config{ - InsecureSkipVerify: false, // Keep certificate validation enabled - NextProtos: []string{tlsnb.NBalpn}, - } - - conn, err := quic.DialAddr(ctx, address, tlsConfig, &quic.Config{ - MaxIdleTimeout: 30 * time.Second, - KeepAlivePeriod: 10 * time.Second, - EnableDatagrams: true, - }) - if err != nil { - return fmt.Errorf("failed to connect to QUIC server: %w", err) - } - - _ = conn.CloseWithError(0, "availability check complete") - return nil -} diff --git a/relay/healthcheck/ws.go b/relay/healthcheck/ws.go index 49694356c..badd31219 100644 --- a/relay/healthcheck/ws.go +++ b/relay/healthcheck/ws.go @@ -3,6 +3,7 @@ package healthcheck import ( "context" "fmt" + "strings" "github.com/coder/websocket" @@ -10,12 +11,19 @@ import ( ) func dialWS(ctx context.Context, address string) error { - url := fmt.Sprintf("wss://%s%s", address, relay.WebSocketURLPath) + addressSplit := strings.Split(address, "/") + scheme := "ws" + if addressSplit[0] == "rels:" { + scheme = "wss" + } + url := fmt.Sprintf("%s://%s%s", scheme, addressSplit[2], relay.WebSocketURLPath) conn, resp, err := websocket.Dial(ctx, url, nil) if resp != nil { defer func() { - _ = resp.Body.Close() + if resp.Body != nil { + _ = resp.Body.Close() + } }() } diff --git a/relay/server/relay.go b/relay/server/relay.go index d86684937..aab575bf0 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -51,10 +51,11 @@ type Relay struct { metricsCancel context.CancelFunc validator Validator - store *store.Store - notifier *store.PeerNotifier - instanceURL string - preparedMsg *preparedMsg + store *store.Store + notifier *store.PeerNotifier + instanceURL string + exposedAddress string + preparedMsg *preparedMsg closed bool closeMu sync.RWMutex @@ -87,12 +88,13 @@ func NewRelay(config Config) (*Relay, error) { } r := &Relay{ - metrics: m, - metricsCancel: metricsCancel, - validator: config.AuthValidator, - instanceURL: config.instanceURL, - store: store.NewStore(), - notifier: store.NewPeerNotifier(), + metrics: m, + metricsCancel: metricsCancel, + validator: config.AuthValidator, + instanceURL: config.instanceURL, + exposedAddress: config.ExposedAddress, + store: store.NewStore(), + notifier: store.NewPeerNotifier(), } r.preparedMsg, err = newPreparedMsg(r.instanceURL) @@ -178,3 +180,8 @@ func (r *Relay) Shutdown(ctx context.Context) { func (r *Relay) InstanceURL() string { return r.instanceURL } + +// ExposedAddress returns the exposed address (domain:port) where clients connect +func (r *Relay) ExposedAddress() string { + return r.exposedAddress +} diff --git a/relay/server/server.go b/relay/server/server.go index 4c30e7fdc..2c9e658d6 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -28,8 +28,6 @@ type ListenerConfig struct { // It is the gate between the WebSocket listener and the Relay server logic. // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. type Server struct { - listenAddr string - relay *Relay listeners []listener.Listener listenerMux sync.Mutex @@ -62,8 +60,6 @@ func NewServer(config Config) (*Server, error) { // Listen starts the relay server. func (r *Server) Listen(cfg ListenerConfig) error { - r.listenAddr = cfg.Address - wSListener := &ws.Listener{ Address: cfg.Address, TLSConfig: cfg.TLSConfig, @@ -139,6 +135,6 @@ func (r *Server) ListenerProtocols() []protocol.Protocol { return result } -func (r *Server) ListenAddress() string { - return r.listenAddr +func (r *Server) ExposedAddress() string { + return r.relay.ExposedAddress() } From e47d815dd27526a04e54c212161152ba5a2a8dee Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 1 Dec 2025 14:16:03 +0100 Subject: [PATCH 096/120] Fix IsAnotherProcessRunning (#4858) Compare the exact process name rather than searching for a substring of the full path --- client/ui/process/process.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/ui/process/process.go b/client/ui/process/process.go index d0ef54896..28276f416 100644 --- a/client/ui/process/process.go +++ b/client/ui/process/process.go @@ -28,7 +28,8 @@ func IsAnotherProcessRunning() (int32, bool, error) { continue } - if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) { + runningProcessName := strings.ToLower(filepath.Base(runningProcessPath)) + if runningProcessName == processName && isProcessOwnedByCurrentUser(p) { return p.Pid, true, nil } } From 387d43bcc18eb51baf619005f11bb449cb1514f1 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 1 Dec 2025 14:25:52 +0100 Subject: [PATCH 097/120] [client, management] Add OAuth select_account prompt support to PKCE flow (#4880) * Add OAuth select_account prompt support to PKCE flow Extends LoginFlag enum with select_account options to enable multi-account selection during authentication. This allows users to choose which account to use when multiple accounts have active sessions with the identity provider. The new flags are backward compatible - existing LoginFlag values (0=prompt login, 1=max_age=0) retain their original behavior. --- client/internal/auth/pkce_flow.go | 10 ++++--- client/internal/auth/pkce_flow_test.go | 35 ++++++++++++++---------- shared/management/client/common/types.go | 25 +++++++++-------- 3 files changed, 39 insertions(+), 31 deletions(-) diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 39da2a79f..c03376f5b 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/templates" + "github.com/netbirdio/netbird/shared/management/client/common" ) var _ OAuthFlow = &PKCEAuthorizationFlow{} @@ -104,11 +105,12 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn oauth2.SetAuthURLParam("audience", p.providerConfig.Audience), } if !p.providerConfig.DisablePromptLogin { - if p.providerConfig.LoginFlag.IsPromptLogin() { - params = append(params, oauth2.SetAuthURLParam("prompt", "login")) - } - if p.providerConfig.LoginFlag.IsMaxAge0Login() { + switch p.providerConfig.LoginFlag { + case common.LoginFlagPromptLogin: + params = append(params, oauth2.SetAuthURLParam("prompt", "login select_account")) + case common.LoginFlagMaxAge0: params = append(params, oauth2.SetAuthURLParam("max_age", "0")) + params = append(params, oauth2.SetAuthURLParam("prompt", "select_account")) } } if p.providerConfig.LoginHint != "" { diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go index 380a360e5..b5843f104 100644 --- a/client/internal/auth/pkce_flow_test.go +++ b/client/internal/auth/pkce_flow_test.go @@ -15,30 +15,37 @@ import ( func TestPromptLogin(t *testing.T) { const ( - promptLogin = "prompt=login" - maxAge0 = "max_age=0" + promptSelectAccountLogin = "prompt=login+select_account" + promptSelectAccount = "prompt=select_account" + maxAge0 = "max_age=0" ) tt := []struct { name string loginFlag mgm.LoginFlag disablePromptLogin bool - expect string + expectContains []string }{ { - name: "Prompt login", - loginFlag: mgm.LoginFlagPrompt, - expect: promptLogin, + name: "Prompt login with select account", + loginFlag: mgm.LoginFlagPromptLogin, + expectContains: []string{promptSelectAccountLogin}, }, { - name: "Max age 0 login", - loginFlag: mgm.LoginFlagMaxAge0, - expect: maxAge0, + name: "Max age 0 with select account", + loginFlag: mgm.LoginFlagMaxAge0, + expectContains: []string{maxAge0, promptSelectAccount}, }, { name: "Disable prompt login", - loginFlag: mgm.LoginFlagPrompt, + loginFlag: mgm.LoginFlagPromptLogin, disablePromptLogin: true, + expectContains: []string{}, + }, + { + name: "None flag should not add parameters", + loginFlag: mgm.LoginFlagNone, + expectContains: []string{}, }, } @@ -53,6 +60,7 @@ func TestPromptLogin(t *testing.T) { RedirectURLs: []string{"http://127.0.0.1:33992/"}, UseIDToken: true, LoginFlag: tc.loginFlag, + DisablePromptLogin: tc.disablePromptLogin, } pkce, err := NewPKCEAuthorizationFlow(config) if err != nil { @@ -63,11 +71,8 @@ func TestPromptLogin(t *testing.T) { t.Fatalf("Failed to request auth info: %v", err) } - if !tc.disablePromptLogin { - require.Contains(t, authInfo.VerificationURIComplete, tc.expect) - } else { - require.Contains(t, authInfo.VerificationURIComplete, promptLogin) - require.NotContains(t, authInfo.VerificationURIComplete, maxAge0) + for _, expected := range tc.expectContains { + require.Contains(t, authInfo.VerificationURIComplete, expected) } }) } diff --git a/shared/management/client/common/types.go b/shared/management/client/common/types.go index 699617574..550bcde30 100644 --- a/shared/management/client/common/types.go +++ b/shared/management/client/common/types.go @@ -1,19 +1,20 @@ package common -// LoginFlag introduces additional login flags to the PKCE authorization request +// LoginFlag introduces additional login flags to the PKCE authorization request. +// +// # Config Values +// +// | Value | Flag | OAuth Parameters | +// |-------|----------------------|-----------------------------------------| +// | 0 | LoginFlagPromptLogin | prompt=select_account login | +// | 1 | LoginFlagMaxAge0 | max_age=0 & prompt=select_account | type LoginFlag uint8 const ( - // LoginFlagPrompt adds prompt=login to the authorization request - LoginFlagPrompt LoginFlag = iota - // LoginFlagMaxAge0 adds max_age=0 to the authorization request + // LoginFlagPromptLogin adds prompt=select_account login to the authorization request + LoginFlagPromptLogin LoginFlag = iota + // LoginFlagMaxAge0 adds max_age=0 and prompt=select_account to the authorization request LoginFlagMaxAge0 + // LoginFlagNone disables all login flags + LoginFlagNone ) - -func (l LoginFlag) IsPromptLogin() bool { - return l == LoginFlagPrompt -} - -func (l LoginFlag) IsMaxAge0Login() bool { - return l == LoginFlagMaxAge0 -} From 4b77359042c65d0d6cc4b7f19e76d4bf85de137a Mon Sep 17 00:00:00 2001 From: Fahri Shihab <22738442+fahrishih@users.noreply.github.com> Date: Mon, 1 Dec 2025 22:57:42 +0700 Subject: [PATCH 098/120] [management] Groups API with name query parameter (#4831) --- .../http/handlers/groups/groups_handler.go | 23 +++++ .../handlers/groups/groups_handler_test.go | 91 ++++++++++++++++++- shared/management/client/rest/groups.go | 25 +++++ shared/management/http/api/openapi.yml | 10 ++ shared/management/http/api/types.gen.go | 6 ++ 5 files changed, 154 insertions(+), 1 deletion(-) diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index 208a2e828..56ccc9d0b 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -48,6 +48,29 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { } accountID, userID := userAuth.AccountId, userAuth.UserId + // Check if filtering by name + groupName := r.URL.Query().Get("name") + if groupName != "" { + // Get single group by name + group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + // Return as array with single element to maintain API consistency + groupsResponse := []*api.Group{toGroupResponse(accountPeers, group)} + util.WriteJSONObject(r.Context(), w, groupsResponse) + return + } + + // Get all groups groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index b7dd3944a..458a15c11 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -60,12 +60,23 @@ func initGroupTestData(initGroups ...*types.Group) *handler { return group, nil }, + GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + groups := []*types.Group{ + {ID: "id-jwt-group", Name: "From JWT", Issued: types.GroupIssuedJWT}, + {ID: "id-existed", Name: "Existed", Peers: []string{"A", "B"}, Issued: types.GroupIssuedAPI}, + {ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, + } + + groups = append(groups, initGroups...) + + return groups, nil + }, GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { if groupName == "All" { return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil } - return nil, fmt.Errorf("unknown group name") + return nil, status.Errorf(status.NotFound, "unknown group name") }, GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { return maps.Values(TestPeers), nil @@ -287,6 +298,84 @@ func TestWriteGroup(t *testing.T) { } } +func TestGetAllGroups(t *testing.T) { + tt := []struct { + name string + expectedStatus int + expectedBody bool + requestType string + requestPath string + expectedCount int + }{ + { + name: "Get All Groups", + expectedBody: true, + requestType: http.MethodGet, + requestPath: "/api/groups", + expectedStatus: http.StatusOK, + expectedCount: 3, // id-jwt-group, id-existed, id-all + }, + { + name: "Get Group By Name - Existing", + expectedBody: true, + requestType: http.MethodGet, + requestPath: "/api/groups?name=All", + expectedStatus: http.StatusOK, + expectedCount: 1, + }, + { + name: "Get Group By Name - Not Found", + expectedBody: false, + requestType: http.MethodGet, + requestPath: "/api/groups?name=NonExistent", + expectedStatus: http.StatusNotFound, + }, + } + + p := initGroupTestData() + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + }) + + router := mux.NewRouter() + router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + if status := recorder.Code; status != tc.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tc.expectedStatus) + return + } + + if !tc.expectedBody { + return + } + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + var groups []api.Group + if err = json.Unmarshal(content, &groups); err != nil { + t.Fatalf("Response is not in correct json format; %v", err) + } + + assert.Equal(t, tc.expectedCount, len(groups)) + }) + } +} + func TestDeleteGroup(t *testing.T) { tt := []struct { name string diff --git a/shared/management/client/rest/groups.go b/shared/management/client/rest/groups.go index af068e077..7cd9535dd 100644 --- a/shared/management/client/rest/groups.go +++ b/shared/management/client/rest/groups.go @@ -4,10 +4,14 @@ import ( "bytes" "context" "encoding/json" + "errors" "github.com/netbirdio/netbird/shared/management/http/api" ) +// ErrGroupNotFound is returned when a group is not found +var ErrGroupNotFound = errors.New("group not found") + // GroupsAPI APIs for Groups, do not use directly type GroupsAPI struct { c *Client @@ -27,6 +31,27 @@ func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) { return ret, err } +// GetByName get group by name +// See more: https://docs.netbird.io/api/resources/groups#list-all-groups +func (a *GroupsAPI) GetByName(ctx context.Context, groupName string) (*api.Group, error) { + params := map[string]string{"name": groupName} + resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil, params) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.Group](resp) + if err != nil { + return nil, err + } + if len(ret) == 0 { + return nil, ErrGroupNotFound + } + return &ret[0], nil +} + // Get get group info // See more: https://docs.netbird.io/api/resources/groups#retrieve-a-group func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) { diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 4a5454002..2d063a7b5 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -3362,6 +3362,14 @@ paths: security: - BearerAuth: [ ] - TokenAuth: [ ] + parameters: + - in: query + name: name + required: false + schema: + type: string + description: Filter groups by name (exact match) + example: "devs" responses: '200': description: A JSON Array of Groups @@ -3375,6 +3383,8 @@ paths: "$ref": "#/components/responses/bad_request" '401': "$ref": "#/components/responses/requires_authentication" + '404': + "$ref": "#/components/responses/not_found" '403': "$ref": "#/components/responses/forbidden" '500': diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 9611d26d6..d3e425548 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1908,6 +1908,12 @@ type GetApiEventsNetworkTrafficParamsConnectionType string // GetApiEventsNetworkTrafficParamsDirection defines parameters for GetApiEventsNetworkTraffic. type GetApiEventsNetworkTrafficParamsDirection string +// GetApiGroupsParams defines parameters for GetApiGroups. +type GetApiGroupsParams struct { + // Name Filter groups by name (exact match) + Name *string `form:"name,omitempty" json:"name,omitempty"` +} + // GetApiPeersParams defines parameters for GetApiPeers. type GetApiPeersParams struct { // Name Filter peers by name From 52948ccd617eb8e95ac9860865129f9bf1be7af1 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 2 Dec 2025 14:17:59 +0300 Subject: [PATCH 099/120] [management] Add user created activity event (#4893) --- management/server/activity/codes.go | 2 ++ management/server/user.go | 24 +++++++++++++++--------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 5c5989f84..2e3be1ef5 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -179,6 +179,7 @@ const ( PeerIPUpdated Activity = 88 UserApproved Activity = 89 UserRejected Activity = 90 + UserCreated Activity = 91 AccountDeleted Activity = 99999 ) @@ -288,6 +289,7 @@ var activityMap = map[Activity]Code{ PeerIPUpdated: {"Peer IP updated", "peer.ip.update"}, UserApproved: {"User approved", "user.approve"}, UserRejected: {"User rejected", "user.reject"}, + UserCreated: {"User created", "user.create"}, } // StringCode returns a string code of the activity diff --git a/management/server/user.go b/management/server/user.go index 6b8bcbcad..cefc4d1a5 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -596,9 +596,15 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, isNewUser bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() { var eventsToStore []func() + if isNewUser { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.UserCreated, nil) + }) + } + if oldUser.IsBlocked() != newUser.IsBlocked() { if newUser.IsBlocked() { eventsToStore = append(eventsToStore, func() { @@ -661,7 +667,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } - oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists) + oldUser, isNewUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists) if err != nil { return false, nil, nil, nil, err } @@ -716,30 +722,30 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } updateAccountPeers := len(userPeers) > 0 - userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, removedGroups, addedGroups, transaction) + userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroups, addedGroups, transaction) return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil } // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. -func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) { +func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, bool, error) { existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id) if err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if !addIfNotExists { - return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) + return nil, false, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) } update.AccountID = accountID - return update, nil // use all fields from update if addIfNotExists is true + return update, true, nil // use all fields from update if addIfNotExists is true } - return nil, err + return nil, false, err } if existingUser.AccountID != accountID { - return nil, status.Errorf(status.InvalidArgument, "user account ID mismatch") + return nil, false, status.Errorf(status.InvalidArgument, "user account ID mismatch") } - return existingUser, nil + return existingUser, false, nil } func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) { From 7193bd2da7bed767d178c3adaf6f864120f7ca51 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 2 Dec 2025 12:34:28 +0100 Subject: [PATCH 100/120] [management] Refactor network map controller (#4789) --- client/cmd/testutil_test.go | 13 +- client/internal/engine_test.go | 13 +- client/server/server_test.go | 13 +- go.mod | 2 +- go.sum | 4 +- .../network_map/controller/controller.go | 160 ++++++++++------- .../network_map/controller/repository.go | 10 ++ .../controllers/network_map/interface.go | 14 +- .../controllers/network_map/interface_mock.go | 89 ++++++---- .../modules}/peers/ephemeral/interface.go | 5 + .../peers/ephemeral/manager/ephemeral.go | 37 ++-- .../peers/ephemeral/manager/ephemeral_test.go | 115 ++++++++++--- management/internals/modules/peers/manager.go | 162 ++++++++++++++++++ .../modules}/peers/manager_mock.go | 53 ++++++ management/internals/server/boot.go | 30 ++-- management/internals/server/controllers.go | 38 ++-- management/internals/server/modules.go | 25 +-- management/internals/server/server.go | 20 +-- management/internals/shared/grpc/server.go | 114 +++++++----- .../internals/shared/grpc/server_test.go | 8 +- management/internals/shared/grpc/token_mgr.go | 18 +- .../internals/shared/grpc/token_mgr_test.go | 9 +- management/server/account.go | 13 +- management/server/account/manager.go | 2 - management/server/account_test.go | 10 +- management/server/cache/idp.go | 1 + management/server/cache/idp_test.go | 2 +- management/server/cache/store.go | 16 +- management/server/cache/store_test.go | 6 +- management/server/dns_test.go | 4 +- management/server/http/handler.go | 2 +- .../http/handlers/peers/peers_handler.go | 28 +-- .../http/handlers/peers/peers_handler_test.go | 34 +--- .../testing/testing_tools/channel/channel.go | 5 +- management/server/management_proto_test.go | 16 +- management/server/management_test.go | 12 +- management/server/mock_server/account_mock.go | 6 - management/server/nameserver_test.go | 4 +- management/server/peer.go | 36 ++-- management/server/peer_test.go | 11 +- management/server/peers/manager.go | 68 -------- management/server/route_test.go | 4 +- management/server/user.go | 37 ++-- management/server/user_test.go | 29 +++- shared/management/client/client_test.go | 13 +- 45 files changed, 819 insertions(+), 492 deletions(-) rename management/{server => internals/modules}/peers/ephemeral/interface.go (83%) rename management/{server => internals/modules}/peers/ephemeral/manager/ephemeral.go (85%) rename management/{server => internals/modules}/peers/ephemeral/manager/ephemeral_test.go (69%) create mode 100644 management/internals/modules/peers/manager.go rename management/{server => internals/modules}/peers/manager_mock.go (55%) delete mode 100644 management/server/peers/manager.go diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 5477653a2..b9ff35945 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -15,6 +15,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" clientProto "github.com/netbirdio/netbird/client/proto" @@ -24,8 +26,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -116,15 +116,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config) accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + t.Fatal(err) + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9252ce13e..5ab21e3e1 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -30,11 +30,12 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" @@ -54,7 +55,6 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -1628,14 +1628,17 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) - networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, "", err } - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + return nil, "", err + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index f8592bc7a..5f28a2664 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -17,11 +17,12 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -35,7 +36,6 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -316,14 +316,17 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) - networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { return nil, "", err } - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + return nil, "", err + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/go.mod b/go.mod index a72d09dc3..91e587c32 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 011a5f199..98d395ad1 100644 --- a/go.sum +++ b/go.sum @@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63 h1:ecs4GMANgObopiy29zMmz2dIdOTJMwezUbrFy+zfSwE= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63/go.mod h1:JIWpjbCgDvZIt45C9vYpikU2gRXeDWrN7SiyGYd3Qrc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 49bb9cef3..022ea774c 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -19,6 +19,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" @@ -42,6 +43,7 @@ type Controller struct { accountManagerMetrics *telemetry.AccountManagerMetrics peersUpdateManager network_map.PeersUpdateManager settingsManager settings.Manager + EphemeralPeersManager ephemeral.Manager accountUpdateLocks sync.Map sendAccountUpdateLocks sync.Map @@ -70,7 +72,7 @@ type bufferUpdate struct { var _ network_map.Controller = (*Controller)(nil) -func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, config *config.Config) *Controller { +func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller { nMetrics, err := newMetrics(metrics.UpdateChannelMetrics()) if err != nil { log.Fatal(fmt.Errorf("error creating metrics: %w", err)) @@ -99,7 +101,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App dnsDomain: dnsDomain, config: config, - proxyController: proxyController, + proxyController: proxyController, + EphemeralPeersManager: ephemeralPeersManager, holder: types.NewHolder(), expNewNetworkMap: newNetworkMapBuilder, @@ -107,6 +110,31 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App } } +func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) { + peer, err := c.repo.GetPeerByID(ctx, accountID, peerID) + if err != nil { + return nil, fmt.Errorf("failed to get peer %s: %v", peerID, err) + } + + c.EphemeralPeersManager.OnPeerConnected(ctx, peer) + + return c.peersUpdateManager.CreateChannel(ctx, peerID), nil +} + +func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) { + c.peersUpdateManager.CloseChannel(ctx, peerID) + peer, err := c.repo.GetPeerByID(ctx, accountID, peerID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peer %s: %v", peerID, err) + return + } + c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer) +} + +func (c *Controller) CountStreams() int { + return c.peersUpdateManager.CountStreams() +} + func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) var ( @@ -366,38 +394,6 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str return nil } -func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error { - network, err := c.repo.GetAccountNetwork(ctx, accountId) - if err != nil { - return err - } - - peers, err := c.repo.GetAccountPeers(ctx, accountId) - if err != nil { - return err - } - - dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion) - c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{ - Update: &proto.SyncResponse{ - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - NetworkMap: &proto.NetworkMap{ - Serial: network.CurrentSerial(), - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - FirewallRules: []*proto.FirewallRule{}, - FirewallRulesIsEmpty: true, - DNSConfig: &proto.DNSConfig{ - ForwarderPort: dnsFwdPort, - }, - }, - }, - }) - c.peersUpdateManager.CloseChannel(ctx, peerId) - return nil -} - func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if isRequiresApproval { network, err := c.repo.GetAccountNetwork(ctx, accountID) @@ -698,35 +694,83 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t return false, nil } -func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) { - c.UpdatePeerInNetworkMapCache(accountId, peer) - _ = c.bufferSendUpdateAccountPeers(context.Background(), accountId) +func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error { + peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs) + if err != nil { + return fmt.Errorf("failed to get peers by ids: %w", err) + } + + for _, peer := range peers { + c.UpdatePeerInNetworkMapCache(accountID, peer) + } + + err = c.bufferSendUpdateAccountPeers(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err) + } + + return nil } -func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error { - if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return err - } +func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { + for _, peerID := range peerIDs { + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } - err = c.onPeerAddedUpdNetworkMapCache(account, peerID) - if err != nil { - return err + err = c.onPeerAddedUpdNetworkMapCache(account, peerID) + if err != nil { + return err + } } } return c.bufferSendUpdateAccountPeers(ctx, accountID) } -func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error { - if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return err - } - err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) - if err != nil { - return err +func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error { + network, err := c.repo.GetAccountNetwork(ctx, accountID) + if err != nil { + return err + } + + peers, err := c.repo.GetAccountPeers(ctx, accountID) + if err != nil { + return err + } + + dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion) + for _, peerID := range peerIDs { + c.peersUpdateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + NetworkMap: &proto.NetworkMap{ + Serial: network.CurrentSerial(), + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + FirewallRules: []*proto.FirewallRule{}, + FirewallRulesIsEmpty: true, + DNSConfig: &proto.DNSConfig{ + ForwarderPort: dnsFwdPort, + }, + }, + }, + }) + c.peersUpdateManager.CloseChannel(ctx, peerID) + + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) + continue + } + err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) + if err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err) + continue + } } } @@ -778,10 +822,6 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N return networkMap, nil } -func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) { +func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) { c.peersUpdateManager.CloseChannels(ctx, peerIDs) } - -func (c *Controller) IsConnected(peerID string) bool { - return c.peersUpdateManager.HasChannel(peerID) -} diff --git a/management/internals/controllers/network_map/controller/repository.go b/management/internals/controllers/network_map/controller/repository.go index 44144263b..3ed51a5c3 100644 --- a/management/internals/controllers/network_map/controller/repository.go +++ b/management/internals/controllers/network_map/controller/repository.go @@ -12,6 +12,8 @@ type Repository interface { GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) + GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) + GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) } type repository struct { @@ -37,3 +39,11 @@ func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]* func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { return r.store.GetAccountByPeerID(ctx, peerID) } + +func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) { + return r.store.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, peerIDs) +} + +func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) { + return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) +} diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go index 6f893ce79..b1de7d017 100644 --- a/management/internals/controllers/network_map/interface.go +++ b/management/internals/controllers/network_map/interface.go @@ -28,12 +28,12 @@ type Controller interface { GetDNSDomain(settings *types.Settings) string StartWarmup(context.Context) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) + CountStreams() int - DeletePeer(ctx context.Context, accountId string, peerId string) error - - OnPeerUpdated(accountId string, peer *nbpeer.Peer) - OnPeerAdded(ctx context.Context, accountID string, peerID string) error - OnPeerDeleted(ctx context.Context, accountID string, peerID string) error - DisconnectPeers(ctx context.Context, peerIDs []string) - IsConnected(peerID string) bool + OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error + OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error + OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error + DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) + OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error) + OnPeerDisconnected(ctx context.Context, accountID string, peerID string) } diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go index aaa093e47..5a98eefa8 100644 --- a/management/internals/controllers/network_map/interface_mock.go +++ b/management/internals/controllers/network_map/interface_mock.go @@ -57,30 +57,30 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID) } -// DeletePeer mocks base method. -func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error { +// CountStreams mocks base method. +func (m *MockController) CountStreams() int { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId) - ret0, _ := ret[0].(error) + ret := m.ctrl.Call(m, "CountStreams") + ret0, _ := ret[0].(int) return ret0 } -// DeletePeer indicates an expected call of DeletePeer. -func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call { +// CountStreams indicates an expected call of CountStreams. +func (mr *MockControllerMockRecorder) CountStreams() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountStreams", reflect.TypeOf((*MockController)(nil).CountStreams)) } // DisconnectPeers mocks base method. -func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) { +func (m *MockController) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) { m.ctrl.T.Helper() - m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs) + m.ctrl.Call(m, "DisconnectPeers", ctx, accountId, peerIDs) } // DisconnectPeers indicates an expected call of DisconnectPeers. -func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call { +func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, accountId, peerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, accountId, peerIDs) } // GetDNSDomain mocks base method. @@ -130,58 +130,73 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p) } -// IsConnected mocks base method. -func (m *MockController) IsConnected(peerID string) bool { +// OnPeerConnected mocks base method. +func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsConnected", peerID) - ret0, _ := ret[0].(bool) - return ret0 + ret := m.ctrl.Call(m, "OnPeerConnected", ctx, accountID, peerID) + ret0, _ := ret[0].(chan *UpdateMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// IsConnected indicates an expected call of IsConnected. -func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call { +// OnPeerConnected indicates an expected call of OnPeerConnected. +func (mr *MockControllerMockRecorder) OnPeerConnected(ctx, accountID, peerID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerConnected", reflect.TypeOf((*MockController)(nil).OnPeerConnected), ctx, accountID, peerID) } -// OnPeerAdded mocks base method. -func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error { +// OnPeerDisconnected mocks base method. +func (m *MockController) OnPeerDisconnected(ctx context.Context, accountID, peerID string) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID) + m.ctrl.Call(m, "OnPeerDisconnected", ctx, accountID, peerID) +} + +// OnPeerDisconnected indicates an expected call of OnPeerDisconnected. +func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDisconnected", reflect.TypeOf((*MockController)(nil).OnPeerDisconnected), ctx, accountID, peerID) +} + +// OnPeersAdded mocks base method. +func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs) ret0, _ := ret[0].(error) return ret0 } -// OnPeerAdded indicates an expected call of OnPeerAdded. -func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call { +// OnPeersAdded indicates an expected call of OnPeersAdded. +func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs) } -// OnPeerDeleted mocks base method. -func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error { +// OnPeersDeleted mocks base method. +func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID) + ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs) ret0, _ := ret[0].(error) return ret0 } -// OnPeerDeleted indicates an expected call of OnPeerDeleted. -func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call { +// OnPeersDeleted indicates an expected call of OnPeersDeleted. +func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs) } -// OnPeerUpdated mocks base method. -func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) { +// OnPeersUpdated mocks base method. +func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "OnPeerUpdated", accountId, peer) + ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs) + ret0, _ := ret[0].(error) + return ret0 } -// OnPeerUpdated indicates an expected call of OnPeerUpdated. -func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call { +// OnPeersUpdated indicates an expected call of OnPeersUpdated. +func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs) } // StartWarmup mocks base method. diff --git a/management/server/peers/ephemeral/interface.go b/management/internals/modules/peers/ephemeral/interface.go similarity index 83% rename from management/server/peers/ephemeral/interface.go rename to management/internals/modules/peers/ephemeral/interface.go index a1605b3b9..8fe25435c 100644 --- a/management/server/peers/ephemeral/interface.go +++ b/management/internals/modules/peers/ephemeral/interface.go @@ -2,10 +2,15 @@ package ephemeral import ( "context" + "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) +const ( + EphemeralLifeTime = 10 * time.Minute +) + type Manager interface { LoadInitialPeers(ctx context.Context) Stop() diff --git a/management/server/peers/ephemeral/manager/ephemeral.go b/management/internals/modules/peers/ephemeral/manager/ephemeral.go similarity index 85% rename from management/server/peers/ephemeral/manager/ephemeral.go rename to management/internals/modules/peers/ephemeral/manager/ephemeral.go index 062ba69d2..15119045b 100644 --- a/management/server/peers/ephemeral/manager/ephemeral.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral.go @@ -7,14 +7,15 @@ import ( log "github.com/sirupsen/logrus" - nbAccount "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" ) const ( - ephemeralLifeTime = 10 * time.Minute // cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure. cleanupWindow = 1 * time.Minute ) @@ -33,11 +34,11 @@ type ephemeralPeer struct { // todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it // in worst case we will get invalid error message in this manager. -// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted +// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted // automatically. Inactivity means the peer disconnected from the Management server. type EphemeralManager struct { - store store.Store - accountManager nbAccount.Manager + store store.Store + peersManager peers.Manager headPeer *ephemeralPeer tailPeer *ephemeralPeer @@ -49,12 +50,12 @@ type EphemeralManager struct { } // NewEphemeralManager instantiate new EphemeralManager -func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *EphemeralManager { +func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager { return &EphemeralManager{ - store: store, - accountManager: accountManager, + store: store, + peersManager: peersManager, - lifeTime: ephemeralLifeTime, + lifeTime: ephemeral.EphemeralLifeTime, cleanupWindow: cleanupWindow, } } @@ -106,7 +107,7 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee } // OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer -// is inactive it will be deleted after the ephemeralLifeTime period. +// is inactive it will be deleted after the EphemeralLifeTime period. func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) { if !peer.Ephemeral { return @@ -180,20 +181,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { e.peersLock.Unlock() - bufferAccountCall := make(map[string]struct{}) - + peerIDsPerAccount := make(map[string][]string) for id, p := range deletePeers { - log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) - err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator) + peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id) + } + + for accountID, peerIDs := range peerIDsPerAccount { + log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID) + err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) - } else { - bufferAccountCall[p.accountID] = struct{}{} } } - for accountID := range bufferAccountCall { - e.accountManager.BufferUpdateAccountPeers(ctx, accountID) - } } func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) { diff --git a/management/server/peers/ephemeral/manager/ephemeral_test.go b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go similarity index 69% rename from management/server/peers/ephemeral/manager/ephemeral_test.go rename to management/internals/modules/peers/ephemeral/manager/ephemeral_test.go index fc7525c29..9d3ed246a 100644 --- a/management/server/peers/ephemeral/manager/ephemeral_test.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go @@ -7,10 +7,13 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" nbAccount "github.com/netbirdio/netbird/management/server/account" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" @@ -91,17 +94,27 @@ func TestNewManager(t *testing.T) { } store := &MockStore{} - am := MockAccountManager{ - store: store, - } + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) numberOfPeers := 5 numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, &am) + // Expect DeletePeers to be called for ephemeral peers + peersManager.EXPECT(). + DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + for _, peerID := range peerIDs { + delete(store.account.Peers, peerID) + } + return nil + }). + AnyTimes() + + mgr := NewEphemeralManager(store, peersManager) mgr.loadEphemeralPeers(context.Background()) - startTime = startTime.Add(ephemeralLifeTime + 1) + startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1) mgr.cleanup(context.Background()) if len(store.account.Peers) != numberOfPeers { @@ -119,19 +132,29 @@ func TestNewManagerPeerConnected(t *testing.T) { } store := &MockStore{} - am := MockAccountManager{ - store: store, - } + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) numberOfPeers := 5 numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, &am) + // Expect DeletePeers to be called for ephemeral peers (except the connected one) + peersManager.EXPECT(). + DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + for _, peerID := range peerIDs { + delete(store.account.Peers, peerID) + } + return nil + }). + AnyTimes() + + mgr := NewEphemeralManager(store, peersManager) mgr.loadEphemeralPeers(context.Background()) mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) - startTime = startTime.Add(ephemeralLifeTime + 1) + startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1) mgr.cleanup(context.Background()) expected := numberOfPeers + 1 @@ -150,15 +173,25 @@ func TestNewManagerPeerDisconnected(t *testing.T) { } store := &MockStore{} - am := MockAccountManager{ - store: store, - } + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) numberOfPeers := 5 numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, &am) + // Expect DeletePeers to be called for the one disconnected peer + peersManager.EXPECT(). + DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + for _, peerID := range peerIDs { + delete(store.account.Peers, peerID) + } + return nil + }). + AnyTimes() + + mgr := NewEphemeralManager(store, peersManager) mgr.loadEphemeralPeers(context.Background()) for _, v := range store.account.Peers { mgr.OnPeerConnected(context.Background(), v) @@ -166,7 +199,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) { } mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) - startTime = startTime.Add(ephemeralLifeTime + 1) + startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1) mgr.cleanup(context.Background()) expected := numberOfPeers + numberOfEphemeralPeers - 1 @@ -181,25 +214,63 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) { testLifeTime = 1 * time.Second testCleanupWindow = 100 * time.Millisecond ) + + t.Cleanup(func() { + timeNow = time.Now + }) + startTime := time.Now() + timeNow = func() time.Time { + return startTime + } + mockStore := &MockStore{} + account := newAccountWithId(context.Background(), "account", "", "", false) + mockStore.account = account + + wg := &sync.WaitGroup{} + wg.Add(ephemeralPeers) mockAM := &MockAccountManager{ store: mockStore, + wg: wg, } - mockAM.wg = &sync.WaitGroup{} - mockAM.wg.Add(ephemeralPeers) - mgr := NewEphemeralManager(mockStore, mockAM) + + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) + + // Set up expectation that DeletePeers will be called once with all peer IDs + peersManager.EXPECT(). + DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + // Simulate the actual deletion behavior + for _, peerID := range peerIDs { + err := mockAM.DeletePeer(ctx, accountID, peerID, userID) + if err != nil { + return err + } + } + mockAM.BufferUpdateAccountPeers(ctx, accountID) + return nil + }). + Times(1) + + mgr := NewEphemeralManager(mockStore, peersManager) mgr.lifeTime = testLifeTime mgr.cleanupWindow = testCleanupWindow - account := newAccountWithId(context.Background(), "account", "", "", false) - mockStore.account = account + // Add peers and disconnect them at slightly different times (within cleanup window) for i := range ephemeralPeers { p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true} mockStore.account.Peers[p.ID] = p - time.Sleep(testCleanupWindow / ephemeralPeers) mgr.OnPeerDisconnected(context.Background(), p) + startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2)) } - mockAM.wg.Wait() + + // Advance time past the lifetime to trigger cleanup + startTime = startTime.Add(testLifeTime + testCleanupWindow) + + // Wait for all deletions to complete + wg.Wait() + assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime") assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once") assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers") diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go new file mode 100644 index 000000000..e82f19e63 --- /dev/null +++ b/management/internals/modules/peers/manager.go @@ -0,0 +1,162 @@ +package peers + +//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod + +import ( + "context" + "fmt" + "time" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" + "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" +) + +type Manager interface { + GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) + GetPeerAccountID(ctx context.Context, peerID string) (string, error) + GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) + GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) + DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error + SetNetworkMapController(networkMapController network_map.Controller) + SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) + SetAccountManager(accountManager account.Manager) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + integratedPeerValidator integrated_validator.IntegratedValidator + accountManager account.Manager + + networkMapController network_map.Controller +} + +func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + } +} + +func (m *managerImpl) SetNetworkMapController(networkMapController network_map.Controller) { + m.networkMapController = networkMapController +} + +func (m *managerImpl) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) { + m.integratedPeerValidator = integratedPeerValidator +} + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + m.accountManager = accountManager +} + +func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) { + allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) +} + +func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { + allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) + } + + return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") +} + +func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { + return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) +} + +func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { + return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) +} + +func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return err + } + dnsDomain := m.networkMapController.GetDNSDomain(settings) + + for _, peerID := range peerIDs { + var eventsToStore []func() + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + + if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) { + return nil + } + + if err := transaction.RemovePeerFromAllGroups(ctx, peerID); err != nil { + return fmt.Errorf("failed to remove peer %s from groups", peerID) + } + + if err := m.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil { + return err + } + + peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + for _, rule := range peerPolicyRules { + policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID) + if err != nil { + return err + } + + err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID) + if err != nil { + return err + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + }) + } + + if err = transaction.DeletePeer(ctx, accountID, peerID); err != nil { + return err + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) + }) + + return nil + }) + if err != nil { + return err + } + for _, event := range eventsToStore { + event() + } + } + + return nil +} diff --git a/management/server/peers/manager_mock.go b/management/internals/modules/peers/manager_mock.go similarity index 55% rename from management/server/peers/manager_mock.go rename to management/internals/modules/peers/manager_mock.go index 994f8346b..2e3651e88 100644 --- a/management/server/peers/manager_mock.go +++ b/management/internals/modules/peers/manager_mock.go @@ -9,6 +9,9 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + network_map "github.com/netbirdio/netbird/management/internals/controllers/network_map" + account "github.com/netbirdio/netbird/management/server/account" + integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" peer "github.com/netbirdio/netbird/management/server/peer" ) @@ -35,6 +38,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder { return m.recorder } +// DeletePeers mocks base method. +func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePeers", ctx, accountID, peerIDs, userID, checkConnected) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePeers indicates an expected call of DeletePeers. +func (mr *MockManagerMockRecorder) DeletePeers(ctx, accountID, peerIDs, userID, checkConnected interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeers", reflect.TypeOf((*MockManager)(nil).DeletePeers), ctx, accountID, peerIDs, userID, checkConnected) +} + // GetAllPeers mocks base method. func (m *MockManager) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { m.ctrl.T.Helper() @@ -94,3 +111,39 @@ func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs) } + +// SetAccountManager mocks base method. +func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. +func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + +// SetIntegratedPeerValidator mocks base method. +func (m *MockManager) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetIntegratedPeerValidator", integratedPeerValidator) +} + +// SetIntegratedPeerValidator indicates an expected call of SetIntegratedPeerValidator. +func (mr *MockManagerMockRecorder) SetIntegratedPeerValidator(integratedPeerValidator interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIntegratedPeerValidator", reflect.TypeOf((*MockManager)(nil).SetIntegratedPeerValidator), integratedPeerValidator) +} + +// SetNetworkMapController mocks base method. +func (m *MockManager) SetNetworkMapController(networkMapController network_map.Controller) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetNetworkMapController", networkMapController) +} + +// SetNetworkMapController indicates an expected call of SetNetworkMapController. +func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController) +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index eadd16c2d..37788e80e 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -57,7 +57,7 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics { func (s *BaseServer) Store() store.Store { return Create(s, func() store.Store { - store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false) + store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false) if err != nil { log.Fatalf("failed to create store: %v", err) } @@ -73,17 +73,17 @@ func (s *BaseServer) EventStore() activity.Store { log.Fatalf("failed to initialize integration metrics: %v", err) } - eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics) + eventStore, key, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics) if err != nil { log.Fatalf("failed to initialize event store: %v", err) } - if s.config.DataStoreEncryptionKey != key { - log.WithContext(context.Background()).Infof("update config with activity store key") - s.config.DataStoreEncryptionKey = key - err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config) + if s.Config.DataStoreEncryptionKey != key { + log.WithContext(context.Background()).Infof("update Config with activity store key") + s.Config.DataStoreEncryptionKey = key + err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.Config) if err != nil { - log.Fatalf("failed to update config with activity store: %v", err) + log.Fatalf("failed to update Config with activity store: %v", err) } } @@ -103,14 +103,14 @@ func (s *BaseServer) APIHandler() http.Handler { func (s *BaseServer) GRPCServer() *grpc.Server { return Create(s, func() *grpc.Server { - trustedPeers := s.config.ReverseProxy.TrustedPeers + trustedPeers := s.Config.ReverseProxy.TrustedPeers defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")} if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) { log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.") trustedPeers = defaultTrustedPeers } - trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies - trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount + trustedHTTPProxies := s.Config.ReverseProxy.TrustedHTTPProxies + trustedProxiesCount := s.Config.ReverseProxy.TrustedHTTPProxiesCount if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 { log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " + "This is not recommended way to extract X-Forwarded-For. Consider using one of these options.") @@ -128,15 +128,15 @@ func (s *BaseServer) GRPCServer() *grpc.Server { grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor), } - if s.config.HttpConfig.LetsEncryptDomain != "" { - certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain) + if s.Config.HttpConfig.LetsEncryptDomain != "" { + certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) if err != nil { log.Fatalf("failed to create certificate manager: %v", err) } transportCredentials := credentials.NewTLS(certManager.TLSConfig()) gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials)) - } else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" { - tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey) + } else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" { + tlsConfig, err := loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey) if err != nil { log.Fatalf("cannot load TLS credentials: %v", err) } @@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { } gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController()) + srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController()) if err != nil { log.Fatalf("failed to create management server: %v", err) } diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 38ec6fde6..3442c7646 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -9,17 +9,17 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" ) func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager { - return Create(s, func() *update_channel.PeersUpdateManager { + return Create(s, func() network_map.PeersUpdateManager { return update_channel.NewPeersUpdateManager(s.Metrics()) }) } @@ -44,33 +44,37 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller { }) } -func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager { - return Create(s, func() *grpc.TimeBasedAuthSecretsManager { - return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager()) +func (s *BaseServer) SecretsManager() grpc.SecretsManager { + return Create(s, func() grpc.SecretsManager { + secretsManager, err := grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.Config.TURNConfig, s.Config.Relay, s.SettingsManager(), s.GroupsManager()) + if err != nil { + log.Fatalf("failed to create secrets manager: %v", err) + } + return secretsManager }) } func (s *BaseServer) AuthManager() auth.Manager { return Create(s, func() auth.Manager { return auth.NewManager(s.Store(), - s.config.HttpConfig.AuthIssuer, - s.config.HttpConfig.AuthAudience, - s.config.HttpConfig.AuthKeysLocation, - s.config.HttpConfig.AuthUserIDClaim, - s.config.GetAuthAudiences(), - s.config.HttpConfig.IdpSignKeyRefreshEnabled) + s.Config.HttpConfig.AuthIssuer, + s.Config.HttpConfig.AuthAudience, + s.Config.HttpConfig.AuthKeysLocation, + s.Config.HttpConfig.AuthUserIDClaim, + s.Config.GetAuthAudiences(), + s.Config.HttpConfig.IdpSignKeyRefreshEnabled) }) } func (s *BaseServer) EphemeralManager() ephemeral.Manager { return Create(s, func() ephemeral.Manager { - return manager.NewEphemeralManager(s.Store(), s.AccountManager()) + return manager.NewEphemeralManager(s.Store(), s.PeersManager()) }) } func (s *BaseServer) NetworkMapController() network_map.Controller { - return Create(s, func() *nmapcontroller.Controller { - return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController(), s.config) + return Create(s, func() network_map.Controller { + return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.DNSDomain(), s.ProxyController(), s.EphemeralManager(), s.Config) }) } @@ -79,3 +83,7 @@ func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { return server.NewAccountRequestBuffer(context.Background(), s.Store()) }) } + +func (s *BaseServer) DNSDomain() string { + return s.dnsDomain +} diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 18a8427be..91ce50a79 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/geolocation" @@ -14,7 +15,7 @@ import ( "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/users" @@ -22,12 +23,12 @@ import ( func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { return Create(s, func() geolocation.Geolocation { - geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate) + geo, err := geolocation.NewGeolocation(context.Background(), s.Config.Datadir, !s.disableGeoliteUpdate) if err != nil { log.Fatalf("could not initialize geolocation service: %v", err) } - log.Infof("geolocation service has been initialized from %s", s.config.Datadir) + log.Infof("geolocation service has been initialized from %s", s.Config.Datadir) return geo }) @@ -60,20 +61,22 @@ func (s *BaseServer) SettingsManager() settings.Manager { func (s *BaseServer) PeersManager() peers.Manager { return Create(s, func() peers.Manager { - return peers.NewManager(s.Store(), s.PermissionsManager()) + manager := peers.NewManager(s.Store(), s.PermissionsManager()) + s.AfterInit(func(s *BaseServer) { + manager.SetNetworkMapController(s.NetworkMapController()) + manager.SetIntegratedPeerValidator(s.IntegratedValidator()) + manager.SetAccountManager(s.AccountManager()) + }) + return manager }) } func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy) if err != nil { log.Fatalf("failed to create account manager: %v", err) } - - s.AfterInit(func(s *BaseServer) { - accountManager.SetEphemeralManager(s.EphemeralManager()) - }) return accountManager }) } @@ -82,8 +85,8 @@ func (s *BaseServer) IdpManager() idp.Manager { return Create(s, func() idp.Manager { var idpManager idp.Manager var err error - if s.config.IdpManagerConfig != nil { - idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics()) + if s.Config.IdpManagerConfig != nil { + idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics()) if err != nil { log.Fatalf("failed to create IDP manager: %v", err) } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index ab1c2ebe7..a1b144dac 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -41,10 +41,10 @@ type Server interface { } // Server holds the HTTP BaseServer instance. -// Add any additional fields you need, such as database connections, config, etc. +// Add any additional fields you need, such as database connections, Config, etc. type BaseServer struct { - // config holds the server configuration - config *nbconfig.Config + // Config holds the server configuration + Config *nbconfig.Config // container of dependencies, each dependency is identified by a unique string. container map[string]any // AfterInit is a function that will be called after the server is initialized @@ -70,7 +70,7 @@ type BaseServer struct { // NewServer initializes and configures a new Server instance func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer { return &BaseServer{ - config: config, + Config: config, container: make(map[string]any), dnsDomain: dnsDomain, mgmtSingleAccModeDomain: mgmtSingleAccModeDomain, @@ -103,14 +103,14 @@ func (s *BaseServer) Start(ctx context.Context) error { var tlsConfig *tls.Config tlsEnabled := false - if s.config.HttpConfig.LetsEncryptDomain != "" { - s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain) + if s.Config.HttpConfig.LetsEncryptDomain != "" { + s.certManager, err = encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) if err != nil { return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err) } tlsEnabled = true - } else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" { - tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey) + } else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" { + tlsConfig, err = loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey) if err != nil { log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err) return err @@ -126,8 +126,8 @@ func (s *BaseServer) Start(ctx context.Context) error { if !s.disableMetrics { idpManager := "disabled" - if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" { - idpManager = s.config.IdpManagerConfig.ManagerType + if s.Config.IdpManagerConfig != nil && s.Config.IdpManagerConfig.ManagerType != "" { + idpManager = s.Config.IdpManagerConfig.ManagerType } metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager) go metricsWorker.Run(srvCtx) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 0aadadf84..62dc215d8 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -24,7 +24,6 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/store" @@ -55,15 +54,12 @@ const ( type Server struct { accountManager account.Manager settingsManager settings.Manager - wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peersUpdateManager network_map.PeersUpdateManager - config *nbconfig.Config - secretsManager SecretsManager - appMetrics telemetry.AppMetrics - ephemeralManager ephemeral.Manager - peerLocks sync.Map - authManager auth.Manager + config *nbconfig.Config + secretsManager SecretsManager + appMetrics telemetry.AppMetrics + peerLocks sync.Map + authManager auth.Manager logBlockedPeers bool blockPeersWithSameConfig bool @@ -82,23 +78,16 @@ func NewServer( config *nbconfig.Config, accountManager account.Manager, settingsManager settings.Manager, - peersUpdateManager network_map.PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, - ephemeralManager ephemeral.Manager, authManager auth.Manager, integratedPeerValidator integrated_validator.IntegratedValidator, networkMapController network_map.Controller, ) (*Server, error) { - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - return nil, err - } - if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams - err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { - return int64(peersUpdateManager.CountStreams()) + err := appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { + return int64(networkMapController.CountStreams()) }) if err != nil { return nil, err @@ -120,16 +109,12 @@ func NewServer( } return &Server{ - wgKey: key, - // peerKey -> event channel - peersUpdateManager: peersUpdateManager, accountManager: accountManager, settingsManager: settingsManager, config: config, secretsManager: secretsManager, authManager: authManager, appMetrics: appMetrics, - ephemeralManager: ephemeralManager, logBlockedPeers: logBlockedPeers, blockPeersWithSameConfig: blockPeersWithSameConfig, integratedPeerValidator: integratedPeerValidator, @@ -163,8 +148,14 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser nanos := int32(now.Nanosecond()) expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos} + key, err := s.secretsManager.GetWGKey() + if err != nil { + log.WithContext(ctx).Errorf("failed to get wireguard key: %v", err) + return nil, errors.New("failed to get wireguard key") + } + return &proto.ServerKeyResponse{ - Key: s.wgKey.PublicKey().String(), + Key: key.PublicKey().String(), ExpiresAt: expiresAt, }, nil } @@ -269,9 +260,13 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S return err } - updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) - - s.ephemeralManager.OnPeerConnected(ctx, peer) + updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID) + if err != nil { + log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) + s.cancelPeerRoutines(ctx, accountID, peer) + return err + } s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) @@ -323,13 +318,19 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) + key, err := s.secretsManager.GetWGKey() + if err != nil { + s.cancelPeerRoutines(ctx, accountID, peer) + return status.Errorf(codes.Internal, "failed processing update message") + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update) if err != nil { s.cancelPeerRoutines(ctx, accountID, peer) return status.Errorf(codes.Internal, "failed processing update message") } err = srv.SendMsg(&proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }) if err != nil { @@ -348,9 +349,8 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer if err != nil { log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err) } - s.peersUpdateManager.CloseChannel(ctx, peer.ID) + s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID) s.secretsManager.CancelRefresh(peer.ID) - s.ephemeralManager.OnPeerDisconnected(ctx, peer) log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) } @@ -504,7 +504,12 @@ func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey) } - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return wgtypes.Key{}, status.Errorf(codes.Internal, "failed processing request") + } + + err = encryption.DecryptMessage(peerKey, key, req.Body, parsed) if err != nil { return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message") } @@ -601,12 +606,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart)) - // if the login request contains setup key then it is a registration request - if loginReq.GetSetupKey() != "" { - s.ephemeralManager.OnPeerDisconnected(ctx, peer) - log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart)) - } - loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) if err != nil { log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err) @@ -615,14 +614,20 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart)) - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) + key, err := s.secretsManager.GetWGKey() + if err != nil { + log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err) + return nil, status.Errorf(codes.Internal, "failed logging in peer") + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, key, loginResp) if err != nil { log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) return nil, status.Errorf(codes.Internal, "failed logging in peer") } return &proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }, nil } @@ -715,14 +720,19 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return status.Errorf(codes.Internal, "failed getting server key") + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, key, plainResp) if err != nil { return status.Errorf(codes.Internal, "error handling request") } sendStart := time.Now() err = srv.Send(&proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }) log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart)) @@ -752,7 +762,12 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr return nil, status.Error(codes.InvalidArgument, errMSG) } - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{}) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get server key") + } + + err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.DeviceAuthorizationFlowRequest{}) if err != nil { errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) log.WithContext(ctx).Warn(errMSG) @@ -782,13 +797,13 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr }, } - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) + encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp) if err != nil { return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information") } return &proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }, nil } @@ -810,7 +825,12 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp return nil, status.Error(codes.InvalidArgument, errMSG) } - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{}) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get server key") + } + + err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.PKCEAuthorizationFlowRequest{}) if err != nil { errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) log.WithContext(ctx).Warn(errMSG) @@ -838,13 +858,13 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow) - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) + encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp) if err != nil { return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information") } return &proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }, nil } diff --git a/management/internals/shared/grpc/server_test.go b/management/internals/shared/grpc/server_test.go index 9867b38e3..d3a12e986 100644 --- a/management/internals/shared/grpc/server_test.go +++ b/management/internals/shared/grpc/server_test.go @@ -73,15 +73,17 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { mgmtServer := &Server{ - wgKey: testingServerKey, + secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey}, config: &config.Config{ DeviceAuthorizationFlow: testCase.inputFlow, }, } message := &mgmtProto.DeviceAuthorizationFlowRequest{} + key, err := mgmtServer.secretsManager.GetWGKey() + require.NoError(t, err, "should be able to get server key") - encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message) + encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message) require.NoError(t, err, "should be able to encrypt message") resp, err := mgmtServer.GetDeviceAuthorizationFlow( @@ -95,7 +97,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { if testCase.expectedComparisonFunc != nil { flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{} - err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp) + err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp) require.NoError(t, err, "should be able to decrypt") testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG) diff --git a/management/internals/shared/grpc/token_mgr.go b/management/internals/shared/grpc/token_mgr.go index e9770db41..0f893ae3a 100644 --- a/management/internals/shared/grpc/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -10,6 +10,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" "github.com/netbirdio/netbird/management/internals/controllers/network_map" @@ -29,6 +30,7 @@ type SecretsManager interface { GenerateRelayToken() (*Token, error) SetupRefresh(ctx context.Context, accountID, peerKey string) CancelRefresh(peerKey string) + GetWGKey() (wgtypes.Key, error) } // TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server @@ -43,11 +45,17 @@ type TimeBasedAuthSecretsManager struct { groupsManager groups.Manager turnCancelMap map[string]chan struct{} relayCancelMap map[string]chan struct{} + wgKey wgtypes.Key } type Token auth.Token -func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { +func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) (*TimeBasedAuthSecretsManager, error) { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + mgr := &TimeBasedAuthSecretsManager{ updateManager: updateManager, turnCfg: turnCfg, @@ -56,6 +64,7 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager relayCancelMap: make(map[string]chan struct{}), settingsManager: settingsManager, groupsManager: groupsManager, + wgKey: key, } if turnCfg != nil { @@ -81,7 +90,12 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager } } - return mgr + return mgr, nil +} + +// GetWGKey returns WireGuard private key used to generate peer keys +func (m *TimeBasedAuthSecretsManager) GetWGKey() (wgtypes.Key, error) { + return m.wgKey, nil } // GenerateTurnToken generates new time-based secret credentials for TURN diff --git a/management/internals/shared/grpc/token_mgr_test.go b/management/internals/shared/grpc/token_mgr_test.go index 06d28d05b..98eb66fb5 100644 --- a/management/internals/shared/grpc/token_mgr_test.go +++ b/management/internals/shared/grpc/token_mgr_test.go @@ -46,12 +46,13 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() - tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ + tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ CredentialsTTL: ttl, Secret: secret, Turns: []*config.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager, groupsManager) + require.NoError(t, err) turnCredentials, err := tested.GenerateTurnToken() require.NoError(t, err) @@ -98,12 +99,13 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes() groupsManager := groups.NewManagerMock() - tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ + tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ CredentialsTTL: ttl, Secret: secret, Turns: []*config.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager, groupsManager) + require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -201,12 +203,13 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() - tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ + tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ CredentialsTTL: ttl, Secret: secret, Turns: []*config.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager, groupsManager) + require.NoError(t, err) tested.SetupRefresh(context.Background(), "someAccountID", peer) if _, ok := tested.turnCancelMap[peer]; !ok { diff --git a/management/server/account.go b/management/server/account.go index 716d5ab5d..dac040db0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -37,7 +37,6 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -77,7 +76,6 @@ type DefaultAccountManager struct { ctx context.Context eventStore activity.Store geo geolocation.Geolocation - ephemeralManager ephemeral.Manager requestBuffer *AccountRequestBuffer @@ -238,7 +236,7 @@ func BuildManager( log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter) } - cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) + cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) if err != nil { return nil, fmt.Errorf("getting cache store: %s", err) } @@ -263,10 +261,6 @@ func BuildManager( return am, nil } -func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) { - am.ephemeralManager = em -} - func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager { return am.externalCacheManager } @@ -2076,7 +2070,10 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us if err != nil { return err } - am.networkMapController.OnPeerUpdated(peer.AccountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID}) + if err != nil { + return fmt.Errorf("notify network map controller of peer update: %w", err) + } } return nil } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 9b3902d87..b5921ec7a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -13,7 +13,6 @@ import ( nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -124,5 +123,4 @@ type Manager interface { UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) - SetEphemeralManager(em ephemeral.Manager) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 340e8db18..8569f1b2f 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -25,6 +25,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" @@ -2959,8 +2961,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) - manager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, nil, err } @@ -3371,7 +3373,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { t.Run("memory cache", func(t *testing.T) { t.Run("should always return true", func(t *testing.T) { - cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) require.NoError(t, err) cold, err := manager.isCacheCold(context.Background(), cacheStore) @@ -3386,7 +3388,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { t.Cleanup(cleanup) t.Setenv(cache.RedisStoreEnvVar, redisURL) - cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) require.NoError(t, err) t.Run("should return true when no account exists", func(t *testing.T) { diff --git a/management/server/cache/idp.go b/management/server/cache/idp.go index 1b31ff82a..19dfc0f38 100644 --- a/management/server/cache/idp.go +++ b/management/server/cache/idp.go @@ -18,6 +18,7 @@ const ( DefaultIDPCacheExpirationMax = 7 * 24 * time.Hour // 7 days DefaultIDPCacheExpirationMin = 3 * 24 * time.Hour // 3 days DefaultIDPCacheCleanupInterval = 30 * time.Minute + DefaultIDPCacheOpenConn = 100 ) // UserDataCache is an interface that wraps the basic Get, Set and Delete methods for idp.UserData objects. diff --git a/management/server/cache/idp_test.go b/management/server/cache/idp_test.go index 3fcfbb11a..0e8061e94 100644 --- a/management/server/cache/idp_test.go +++ b/management/server/cache/idp_test.go @@ -33,7 +33,7 @@ func TestNewIDPCacheManagers(t *testing.T) { t.Cleanup(cleanup) t.Setenv(cache.RedisStoreEnvVar, redisURL) } - cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval) + cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval, cache.DefaultIDPCacheOpenConn) if err != nil { t.Fatalf("couldn't create cache store: %s", err) } diff --git a/management/server/cache/store.go b/management/server/cache/store.go index 1c141a180..54b0242de 100644 --- a/management/server/cache/store.go +++ b/management/server/cache/store.go @@ -3,6 +3,7 @@ package cache import ( "context" "fmt" + "math" "os" "time" @@ -20,24 +21,27 @@ const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS" // NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar // to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store. -func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration) (store.StoreInterface, error) { +func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) { redisAddr := os.Getenv(RedisStoreEnvVar) if redisAddr != "" { - return getRedisStore(ctx, redisAddr) + return getRedisStore(ctx, redisAddr, maxConn) } goc := gocache.New(maxTimeout, cleanupInterval) return gocache_store.NewGoCache(goc), nil } -func getRedisStore(ctx context.Context, redisEnvAddr string) (store.StoreInterface, error) { +func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) { options, err := redis.ParseURL(redisEnvAddr) if err != nil { return nil, fmt.Errorf("parsing redis cache url: %s", err) } - options.MaxIdleConns = 6 - options.MinIdleConns = 3 - options.MaxActiveConns = 100 + options.MaxIdleConns = int(math.Ceil(float64(maxConn) * 0.5)) // 50% of max conns + options.MinIdleConns = int(math.Ceil(float64(maxConn) * 0.1)) // 10% of max conns + options.MaxActiveConns = maxConn + options.ConnMaxIdleTime = 30 * time.Minute + options.ConnMaxLifetime = 0 + options.PoolTimeout = 10 * time.Second redisClient := redis.NewClient(options) subCtx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() diff --git a/management/server/cache/store_test.go b/management/server/cache/store_test.go index f49dd6bbd..1b64fd70d 100644 --- a/management/server/cache/store_test.go +++ b/management/server/cache/store_test.go @@ -15,7 +15,7 @@ import ( ) func TestMemoryStore(t *testing.T) { - memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) if err != nil { t.Fatalf("couldn't create memory store: %s", err) } @@ -42,7 +42,7 @@ func TestMemoryStore(t *testing.T) { func TestRedisStoreConnectionFailure(t *testing.T) { t.Setenv(cache.RedisStoreEnvVar, "redis://127.0.0.1:6379") - _, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond) + _, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond, 100) if err == nil { t.Fatal("getting redis cache store should return error") } @@ -65,7 +65,7 @@ func TestRedisStoreConnectionSuccess(t *testing.T) { } t.Setenv(cache.RedisStoreEnvVar, redisURL) - redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) if err != nil { t.Fatalf("couldn't create redis store: %s", err) } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 99b09566a..b5e3f2b99 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -12,6 +12,8 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" @@ -223,7 +225,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index c1a8c5885..7cf0b5765 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" + nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" nbgroups "github.com/netbirdio/netbird/management/server/groups" @@ -39,7 +40,6 @@ import ( nbnetworks "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - nbpeers "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/telemetry" ) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index c4c5ae165..f531f0cdb 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -45,19 +45,6 @@ func NewHandler(accountManager account.Manager, networkMapController network_map } } -func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { - peerToReturn := peer.Copy() - if peer.Status.Connected { - // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected - // This may happen after server restart when not all peers are yet connected - if !h.networkMapController.IsConnected(peer.ID) { - peerToReturn.Status.Connected = false - } - } - - return peerToReturn, nil -} - func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) if err != nil { @@ -65,11 +52,6 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, return } - peerToReturn, err := h.checkPeerStatus(peer) - if err != nil { - util.WriteError(ctx, err, w) - return - } settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) if err != nil { util.WriteError(ctx, err, w) @@ -91,7 +73,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, _, valid := validPeers[peer.ID] reason := invalidPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason)) + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { @@ -237,13 +219,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers)) respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { - peerToReturn, err := h.checkPeerStatus(peer) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) + respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0)) } validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index ddf2e2a70..55e779ff0 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -109,14 +109,6 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { GetDNSDomain(gomock.Any()). Return("domain"). AnyTimes() - networkMapController.EXPECT(). - IsConnected(noUpdateChannelTestPeerID). - Return(false). - AnyTimes() - networkMapController.EXPECT(). - IsConnected(gomock.Any()). - Return(true). - AnyTimes() return &Handler{ accountManager: &mock_server.MockAccountManager{ @@ -269,14 +261,6 @@ func TestGetPeers(t *testing.T) { expectedArray: false, expectedPeer: peer, }, - { - name: "GetPeer with no update channel", - requestType: http.MethodGet, - requestPath: "/api/peers/" + peer1.ID, - expectedStatus: http.StatusOK, - expectedArray: false, - expectedPeer: expectedPeer1, - }, { name: "PutPeer", requestType: http.MethodPut, @@ -336,8 +320,6 @@ func TestGetPeers(t *testing.T) { for _, peer := range respBody { if peer.Id == testPeerID { got = peer - } else { - assert.Equal(t, peer.Connected, false) } } @@ -351,14 +333,14 @@ func TestGetPeers(t *testing.T) { t.Log(got) - assert.Equal(t, got.Name, tc.expectedPeer.Name) - assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion) - assert.Equal(t, got.Ip, tc.expectedPeer.IP.String()) - assert.Equal(t, got.Os, "OS core") - assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled) - assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled) - assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected) - assert.Equal(t, got.SerialNumber, tc.expectedPeer.Meta.SystemSerialNumber) + assert.Equal(t, tc.expectedPeer.Name, got.Name) + assert.Equal(t, tc.expectedPeer.Meta.WtVersion, got.Version) + assert.Equal(t, tc.expectedPeer.IP.String(), got.Ip) + assert.Equal(t, "OS core", got.Os) + assert.Equal(t, tc.expectedPeer.LoginExpirationEnabled, got.LoginExpirationEnabled) + assert.Equal(t, tc.expectedPeer.SSHEnabled, got.SshEnabled) + assert.Equal(t, tc.expectedPeer.Status.Connected, got.Connected) + assert.Equal(t, tc.expectedPeer.Meta.SystemSerialNumber, got.SerialNumber) }) } } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index e292a7d6c..e8513feb5 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -15,6 +15,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server" @@ -28,7 +30,6 @@ import ( "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -72,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee ctx := context.Background() requestBuffer := server.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) if err != nil { t.Fatalf("Failed to create manager: %v", err) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 42311d944..42f192c0a 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -24,13 +24,14 @@ import ( "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -363,7 +364,9 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)) + + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config) accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) @@ -372,10 +375,13 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config return nil, nil, "", cleanup, err } - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + cleanup() + return nil, nil, "", cleanup, err + } - ephemeralMgr := manager.NewEphemeralManager(store, accountManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}, networkMapController) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController) if err != nil { return nil, nil, "", cleanup, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 2350b225b..648201d4e 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -22,13 +22,14 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -205,7 +206,7 @@ func startServer( ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(ctx, str) - networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config) accountManager, err := server.BuildManager( context.Background(), @@ -228,15 +229,16 @@ func startServer( } groupsManager := groups.NewManager(str, permissionsManager, accountManager) - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + t.Fatalf("failed creating secrets manager: %v", err) + } mgmtServer, err := nbgrpc.NewServer( config, accountManager, settingsMockManager, - updateManager, secretsManager, nil, - &manager.EphemeralManager{}, nil, server.MockIntegratedValidator{}, networkMapController, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 0178e51f5..928098dbe 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -15,7 +15,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -976,11 +975,6 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth a return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } -// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface -func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) { - // Mock implementation - does nothing -} - func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { if am.AllowSyncFunc != nil { return am.AllowSyncFunc(key, hash) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index a4574c978..e3dd8b0b8 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -13,6 +13,8 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" @@ -792,7 +794,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/peer.go b/management/server/peer.go index cd9fbe4c8..f2de05f15 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -136,7 +136,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if expired { - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return fmt.Errorf("notify network map controller of peer update: %w", err) + } } return nil @@ -309,7 +312,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return nil, fmt.Errorf("notify network map controller of peer update: %w", err) + } return peer, nil } @@ -365,13 +371,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID) - if err != nil { - log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err) - } - - if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peerID); err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) + if err := am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil { + log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err) } return nil @@ -583,11 +584,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return fmt.Errorf("failed adding peer to All group: %w", err) } - if temporary { - // we are running the on disconnect handler so that it is considered not connected as we are adding the peer manually - am.ephemeralManager.OnPeerDisconnected(ctx, newPeer) - } - if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { @@ -645,7 +641,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - if err := am.networkMapController.OnPeerAdded(ctx, accountID, newPeer.ID); err != nil { + if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil { log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) } @@ -729,7 +725,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err) + } } return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) @@ -857,7 +856,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err) + } } p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 2d09f5200..752563299 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -28,6 +28,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" @@ -1058,6 +1060,7 @@ func testUpdateAccountPeers(t *testing.T) { for _, channel := range peerChannels { update := <-channel + assert.Nil(t, update.Update.NetbirdConfig) assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers)) assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules)) } @@ -1290,7 +1293,7 @@ func Test_RegisterPeerByUser(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1375,7 +1378,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1528,7 +1531,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1608,7 +1611,7 @@ func Test_LoginPeer(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go deleted file mode 100644 index cb135f4ac..000000000 --- a/management/server/peers/manager.go +++ /dev/null @@ -1,68 +0,0 @@ -package peers - -//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod - -import ( - "context" - "fmt" - - "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/shared/management/status" -) - -type Manager interface { - GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) - GetPeerAccountID(ctx context.Context, peerID string) (string, error) - GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) - GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) -} - -type managerImpl struct { - store store.Store - permissionsManager permissions.Manager -} - -func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { - return &managerImpl{ - store: store, - permissionsManager: permissionsManager, - } -} - -func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) { - allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return nil, status.NewPermissionDeniedError() - } - - return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) -} - -func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { - allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) - } - - return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") -} - -func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { - return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) -} - -func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { - return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) -} diff --git a/management/server/route_test.go b/management/server/route_test.go index 5c8b636bc..a413d545b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -16,6 +16,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" @@ -1291,7 +1293,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { diff --git a/management/server/user.go b/management/server/user.go index cefc4d1a5..ca02f91e6 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -263,15 +263,11 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return err } - updateAccountPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) + _, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) if err != nil { return err } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) - } - return nil } @@ -998,14 +994,17 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou peer.UserID, peer.ID, accountID, activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) + } - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs) + if err != nil { + return fmt.Errorf("notify network map controller of peer update: %w", err) } if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) - am.networkMapController.DisconnectPeers(ctx, peerIDs) + am.networkMapController.DisconnectPeers(ctx, accountID, peerIDs) } return nil } @@ -1051,7 +1050,6 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } var allErrors error - var updateAccountPeers bool for _, targetUserID := range targetUserIDs { if initiatorUserID == targetUserID { @@ -1082,19 +1080,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - userHadPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) + _, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) if err != nil { allErrors = errors.Join(allErrors, err) continue } - - if userHadPeers { - updateAccountPeers = true - } - } - - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) } return allErrors @@ -1152,15 +1142,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return false, err } + var peerIDs []string for _, peer := range userPeers { - err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID) - if err != nil { - log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err) - } - - if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peer.ID); err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peer.ID, err) - } + peerIDs = append(peerIDs, peer.ID) + } + if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil { + log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err) } for _, addPeerRemovedEvent := range addPeerRemovedEvents { diff --git a/management/server/user_test.go b/management/server/user_test.go index 5ce15621e..0d778cfa2 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -8,8 +8,10 @@ import ( "time" "github.com/google/go-cmp/cmp" + "go.uber.org/mock/gomock" "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -547,7 +549,7 @@ func TestUser_InviteNewUser(t *testing.T) { permissionsManager: permissionsManager, } - cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) + cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) require.NoError(t, err) am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cs) @@ -739,11 +741,18 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + ctrl := gomock.NewController(t) + networkMapControllerMock := network_map.NewMockController(ctrl) + networkMapControllerMock.EXPECT(). + OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil) + permissionsManager := permissions.NewManager(store) am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, - permissionsManager: permissionsManager, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsManager, + networkMapController: networkMapControllerMock, } testCases := []struct { @@ -848,12 +857,20 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + ctrl := gomock.NewController(t) + networkMapControllerMock := network_map.NewMockController(ctrl) + networkMapControllerMock.EXPECT(). + OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil). + AnyTimes() + permissionsManager := permissions.NewManager(store) am := DefaultAccountManager{ Store: store, eventStore: &activity.InMemoryEventStore{}, integratedPeerValidator: MockIntegratedValidator{}, permissionsManager: permissionsManager, + networkMapController: networkMapControllerMock, } testCases := []struct { @@ -1056,7 +1073,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { permissionsManager: permissionsManager, } - cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) + cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) assert.NoError(t, err) am.externalCacheManager = nbcache.NewUserDataCache(cacheStore) am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore) @@ -1412,7 +1429,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { t.Run("deleting user with no linked peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index 9e08317f6..9fbe70948 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -21,6 +21,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/client/system" @@ -31,8 +33,6 @@ import ( "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/management/server/peers" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -117,7 +117,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManger), config) accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) @@ -125,8 +125,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { groupsManager := groups.NewManagerMock() - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}, networkMapController) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + t.Fatal(err) + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) } From 10e9cf8c62ec7acef474cf96fa65283f2e20941a Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 2 Dec 2025 14:13:01 +0100 Subject: [PATCH 101/120] [management] update management integrations (#4895) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 91e587c32..870118c88 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 98d395ad1..96a303f79 100644 --- a/go.sum +++ b/go.sum @@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63 h1:ecs4GMANgObopiy29zMmz2dIdOTJMwezUbrFy+zfSwE= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63/go.mod h1:JIWpjbCgDvZIt45C9vYpikU2gRXeDWrN7SiyGYd3Qrc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba h1:pD6eygRJ5EYAlgzeNskPU3WqszMz6/HhPuc6/Bc/580= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= From a293f760af653afb0afe7635680c0ff5a7527128 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 2 Dec 2025 16:30:15 +0100 Subject: [PATCH 102/120] [client] Add conditional peer removal logic during shutdown (#4897) --- client/internal/engine.go | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 0ff1006cd..84bb8beea 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -292,6 +292,12 @@ func (e *Engine) Stop() error { } log.Info("Network monitor: stopped") + if os.Getenv("NB_REMOVE_BEFORE_DNS") == "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { + log.Info("removing peers before dns") + if err := e.removeAllPeers(); err != nil { + return fmt.Errorf("failed to remove all peers: %s", err) + } + } if err := e.stopSSHServer(); err != nil { log.Warnf("failed to stop SSH server: %v", err) } @@ -310,6 +316,13 @@ func (e *Engine) Stop() error { e.stopDNSForwarder() + if os.Getenv("NB_REMOVE_BEFORE_ROUTES") == "true" && os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" { + log.Info("removing peers before routes") + if err := e.removeAllPeers(); err != nil { + return fmt.Errorf("failed to remove all peers: %s", err) + } + } + if e.routeManager != nil { e.routeManager.Stop(e.stateManager) } @@ -317,13 +330,16 @@ func (e *Engine) Stop() error { if e.srWatcher != nil { e.srWatcher.Close() } - + log.Info("cleaning up status recorder states") e.statusRecorder.ReplaceOfflinePeers([]peer.State{}) e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{}) - if err := e.removeAllPeers(); err != nil { - return fmt.Errorf("failed to remove all peers: %s", err) + if os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { + log.Info("removing peers after dns and routes") + if err := e.removeAllPeers(); err != nil { + return fmt.Errorf("failed to remove all peers: %s", err) + } } if e.cancel != nil { From a232cf614c33a58b9b7c3c522b855273696889c8 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 2 Dec 2025 18:31:59 +0100 Subject: [PATCH 103/120] [management] record pat usage metrics (#4888) --- management/server/http/handler.go | 1 + .../server/http/middleware/auth_middleware.go | 17 ++++ .../http/middleware/auth_middleware_test.go | 7 ++ .../http/middleware/pat_usage_tracker.go | 87 +++++++++++++++++++ 4 files changed, 112 insertions(+) create mode 100644 management/server/http/middleware/pat_usage_tracker.go diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 7cf0b5765..b7c6c113c 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -105,6 +105,7 @@ func NewAPIHandler( accountManager.SyncUserJWTGroups, accountManager.GetUserFromUserAuth, rateLimitingConfig, + appMetrics.GetMeter(), ) corsMiddleware := cors.AllowAll() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 9439165a4..ffd7e0b93 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -9,6 +9,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" @@ -31,6 +32,7 @@ type AuthMiddleware struct { getUserFromUserAuth GetUserFromUserAuthFunc syncUserJWTGroups SyncUserJWTGroupsFunc rateLimiter *APIRateLimiter + patUsageTracker *PATUsageTracker } // NewAuthMiddleware instance constructor @@ -40,18 +42,29 @@ func NewAuthMiddleware( syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, rateLimiterConfig *RateLimiterConfig, + meter metric.Meter, ) *AuthMiddleware { var rateLimiter *APIRateLimiter if rateLimiterConfig != nil { rateLimiter = NewAPIRateLimiter(rateLimiterConfig) } + var patUsageTracker *PATUsageTracker + if meter != nil { + var err error + patUsageTracker, err = NewPATUsageTracker(context.Background(), meter) + if err != nil { + log.Errorf("Failed to create PAT usage tracker: %s", err) + } + } + return &AuthMiddleware{ authManager: authManager, ensureAccount: ensureAccount, syncUserJWTGroups: syncUserJWTGroups, getUserFromUserAuth: getUserFromUserAuth, rateLimiter: rateLimiter, + patUsageTracker: patUsageTracker, } } @@ -158,6 +171,10 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] return r, fmt.Errorf("error extracting token: %w", err) } + if m.patUsageTracker != nil { + m.patUsageTracker.IncrementUsage(token) + } + if m.rateLimiter != nil { if !m.rateLimiter.Allow(token) { return r, status.Errorf(status.TooManyRequests, "too many requests") diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 7badc03e4..ba4d16796 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -208,6 +208,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { return &types.User{}, nil }, nil, + nil, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -266,6 +267,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -317,6 +319,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -359,6 +362,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -402,6 +406,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -465,6 +470,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -581,6 +587,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { return &types.User{}, nil }, nil, + nil, ) for _, tc := range tt { diff --git a/management/server/http/middleware/pat_usage_tracker.go b/management/server/http/middleware/pat_usage_tracker.go new file mode 100644 index 000000000..331c288e7 --- /dev/null +++ b/management/server/http/middleware/pat_usage_tracker.go @@ -0,0 +1,87 @@ +package middleware + +import ( + "context" + "maps" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" +) + +// PATUsageTracker tracks PAT usage metrics +type PATUsageTracker struct { + usageCounters map[string]int64 + mu sync.Mutex + stopChan chan struct{} + ctx context.Context + histogram metric.Int64Histogram +} + +// NewPATUsageTracker creates a new PAT usage tracker with metrics +func NewPATUsageTracker(ctx context.Context, meter metric.Meter) (*PATUsageTracker, error) { + histogram, err := meter.Int64Histogram( + "management.pat.usage_distribution", + metric.WithUnit("1"), + metric.WithDescription("Distribution of PAT token usage counts per minute"), + ) + if err != nil { + return nil, err + } + + tracker := &PATUsageTracker{ + usageCounters: make(map[string]int64), + stopChan: make(chan struct{}), + ctx: ctx, + histogram: histogram, + } + + go tracker.reportLoop() + + return tracker, nil +} + +// IncrementUsage increments the usage counter for a given token +func (t *PATUsageTracker) IncrementUsage(token string) { + t.mu.Lock() + defer t.mu.Unlock() + t.usageCounters[token]++ +} + +// reportLoop reports the usage buckets every minute +func (t *PATUsageTracker) reportLoop() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + t.reportUsageBuckets() + case <-t.stopChan: + return + } + } +} + +// reportUsageBuckets reports all token usage counts and resets counters +func (t *PATUsageTracker) reportUsageBuckets() { + t.mu.Lock() + snapshot := maps.Clone(t.usageCounters) + + clear(t.usageCounters) + t.mu.Unlock() + + totalTokens := len(snapshot) + if totalTokens > 0 { + for _, count := range snapshot { + t.histogram.Record(t.ctx, count) + } + log.Debugf("PAT usage in last minute: %d unique tokens used", totalTokens) + } +} + +// Stop stops the reporting goroutine +func (t *PATUsageTracker) Stop() { + close(t.stopChan) +} From e87b4ace11988e9dec38d5f8ccf5c4334a9fac05 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 3 Dec 2025 11:53:39 +0100 Subject: [PATCH 104/120] [client] Add sleep state tracking to handle wakeup/sleep events reliably (#4894) Adds a new NotifyOSLifecycle RPC and server handler to centralize OS sleep/wake handling, introduces Server.sleepTriggeredDown for coordination, updates client UI to call the new RPC, and adjusts the internal sleep event enum zero-value semantics. --- client/internal/sleep/service.go | 5 +- client/proto/daemon.pb.go | 1019 +++++++++++++++++------------- client/proto/daemon.proto | 19 +- client/proto/daemon_grpc.pb.go | 40 +- client/server/lifecycle.go | 77 +++ client/server/lifecycle_test.go | 219 +++++++ client/server/server.go | 3 + client/ui/client_ui.go | 26 +- 8 files changed, 954 insertions(+), 454 deletions(-) create mode 100644 client/server/lifecycle.go create mode 100644 client/server/lifecycle_test.go diff --git a/client/internal/sleep/service.go b/client/internal/sleep/service.go index 35fc933c0..196a33f52 100644 --- a/client/internal/sleep/service.go +++ b/client/internal/sleep/service.go @@ -1,8 +1,9 @@ package sleep var ( - EventTypeSleep EventType = 0 - EventTypeWakeUp EventType = 1 + EventTypeUnknown EventType = 0 + EventTypeSleep EventType = 1 + EventTypeWakeUp EventType = 2 ) type EventType int diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 6f8255615..28e8b2d4e 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v3.21.12 +// protoc v6.32.1 // source: daemon.proto package proto @@ -88,6 +88,56 @@ func (LogLevel) EnumDescriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{0} } +// avoid collision with loglevel enum +type OSLifecycleRequest_CycleType int32 + +const ( + OSLifecycleRequest_UNKNOWN OSLifecycleRequest_CycleType = 0 + OSLifecycleRequest_SLEEP OSLifecycleRequest_CycleType = 1 + OSLifecycleRequest_WAKEUP OSLifecycleRequest_CycleType = 2 +) + +// Enum value maps for OSLifecycleRequest_CycleType. +var ( + OSLifecycleRequest_CycleType_name = map[int32]string{ + 0: "UNKNOWN", + 1: "SLEEP", + 2: "WAKEUP", + } + OSLifecycleRequest_CycleType_value = map[string]int32{ + "UNKNOWN": 0, + "SLEEP": 1, + "WAKEUP": 2, + } +) + +func (x OSLifecycleRequest_CycleType) Enum() *OSLifecycleRequest_CycleType { + p := new(OSLifecycleRequest_CycleType) + *p = x + return p +} + +func (x OSLifecycleRequest_CycleType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor { + return file_daemon_proto_enumTypes[1].Descriptor() +} + +func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType { + return &file_daemon_proto_enumTypes[1] +} + +func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use OSLifecycleRequest_CycleType.Descriptor instead. +func (OSLifecycleRequest_CycleType) EnumDescriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{1, 0} +} + type SystemEvent_Severity int32 const ( @@ -124,11 +174,11 @@ func (x SystemEvent_Severity) String() string { } func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[1].Descriptor() + return file_daemon_proto_enumTypes[2].Descriptor() } func (SystemEvent_Severity) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[1] + return &file_daemon_proto_enumTypes[2] } func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { @@ -137,7 +187,7 @@ func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Severity.Descriptor instead. func (SystemEvent_Severity) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{51, 0} + return file_daemon_proto_rawDescGZIP(), []int{53, 0} } type SystemEvent_Category int32 @@ -179,11 +229,11 @@ func (x SystemEvent_Category) String() string { } func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor { - return file_daemon_proto_enumTypes[2].Descriptor() + return file_daemon_proto_enumTypes[3].Descriptor() } func (SystemEvent_Category) Type() protoreflect.EnumType { - return &file_daemon_proto_enumTypes[2] + return &file_daemon_proto_enumTypes[3] } func (x SystemEvent_Category) Number() protoreflect.EnumNumber { @@ -192,7 +242,7 @@ func (x SystemEvent_Category) Number() protoreflect.EnumNumber { // Deprecated: Use SystemEvent_Category.Descriptor instead. func (SystemEvent_Category) EnumDescriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{51, 1} + return file_daemon_proto_rawDescGZIP(), []int{53, 1} } type EmptyRequest struct { @@ -231,6 +281,86 @@ func (*EmptyRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{0} } +type OSLifecycleRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Type OSLifecycleRequest_CycleType `protobuf:"varint,1,opt,name=type,proto3,enum=daemon.OSLifecycleRequest_CycleType" json:"type,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *OSLifecycleRequest) Reset() { + *x = OSLifecycleRequest{} + mi := &file_daemon_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *OSLifecycleRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OSLifecycleRequest) ProtoMessage() {} + +func (x *OSLifecycleRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OSLifecycleRequest.ProtoReflect.Descriptor instead. +func (*OSLifecycleRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{1} +} + +func (x *OSLifecycleRequest) GetType() OSLifecycleRequest_CycleType { + if x != nil { + return x.Type + } + return OSLifecycleRequest_UNKNOWN +} + +type OSLifecycleResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *OSLifecycleResponse) Reset() { + *x = OSLifecycleResponse{} + mi := &file_daemon_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *OSLifecycleResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OSLifecycleResponse) ProtoMessage() {} + +func (x *OSLifecycleResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OSLifecycleResponse.ProtoReflect.Descriptor instead. +func (*OSLifecycleResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{2} +} + type LoginRequest struct { state protoimpl.MessageState `protogen:"open.v1"` // setupKey netbird setup key. @@ -293,7 +423,7 @@ type LoginRequest struct { func (x *LoginRequest) Reset() { *x = LoginRequest{} - mi := &file_daemon_proto_msgTypes[1] + mi := &file_daemon_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -305,7 +435,7 @@ func (x *LoginRequest) String() string { func (*LoginRequest) ProtoMessage() {} func (x *LoginRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[1] + mi := &file_daemon_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -318,7 +448,7 @@ func (x *LoginRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LoginRequest.ProtoReflect.Descriptor instead. func (*LoginRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{1} + return file_daemon_proto_rawDescGZIP(), []int{3} } func (x *LoginRequest) GetSetupKey() string { @@ -607,7 +737,7 @@ type LoginResponse struct { func (x *LoginResponse) Reset() { *x = LoginResponse{} - mi := &file_daemon_proto_msgTypes[2] + mi := &file_daemon_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -619,7 +749,7 @@ func (x *LoginResponse) String() string { func (*LoginResponse) ProtoMessage() {} func (x *LoginResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[2] + mi := &file_daemon_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -632,7 +762,7 @@ func (x *LoginResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LoginResponse.ProtoReflect.Descriptor instead. func (*LoginResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{2} + return file_daemon_proto_rawDescGZIP(), []int{4} } func (x *LoginResponse) GetNeedsSSOLogin() bool { @@ -673,7 +803,7 @@ type WaitSSOLoginRequest struct { func (x *WaitSSOLoginRequest) Reset() { *x = WaitSSOLoginRequest{} - mi := &file_daemon_proto_msgTypes[3] + mi := &file_daemon_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -685,7 +815,7 @@ func (x *WaitSSOLoginRequest) String() string { func (*WaitSSOLoginRequest) ProtoMessage() {} func (x *WaitSSOLoginRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[3] + mi := &file_daemon_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -698,7 +828,7 @@ func (x *WaitSSOLoginRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitSSOLoginRequest.ProtoReflect.Descriptor instead. func (*WaitSSOLoginRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{3} + return file_daemon_proto_rawDescGZIP(), []int{5} } func (x *WaitSSOLoginRequest) GetUserCode() string { @@ -724,7 +854,7 @@ type WaitSSOLoginResponse struct { func (x *WaitSSOLoginResponse) Reset() { *x = WaitSSOLoginResponse{} - mi := &file_daemon_proto_msgTypes[4] + mi := &file_daemon_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -736,7 +866,7 @@ func (x *WaitSSOLoginResponse) String() string { func (*WaitSSOLoginResponse) ProtoMessage() {} func (x *WaitSSOLoginResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[4] + mi := &file_daemon_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -749,7 +879,7 @@ func (x *WaitSSOLoginResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitSSOLoginResponse.ProtoReflect.Descriptor instead. func (*WaitSSOLoginResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{4} + return file_daemon_proto_rawDescGZIP(), []int{6} } func (x *WaitSSOLoginResponse) GetEmail() string { @@ -769,7 +899,7 @@ type UpRequest struct { func (x *UpRequest) Reset() { *x = UpRequest{} - mi := &file_daemon_proto_msgTypes[5] + mi := &file_daemon_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -781,7 +911,7 @@ func (x *UpRequest) String() string { func (*UpRequest) ProtoMessage() {} func (x *UpRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[5] + mi := &file_daemon_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -794,7 +924,7 @@ func (x *UpRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use UpRequest.ProtoReflect.Descriptor instead. func (*UpRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{5} + return file_daemon_proto_rawDescGZIP(), []int{7} } func (x *UpRequest) GetProfileName() string { @@ -819,7 +949,7 @@ type UpResponse struct { func (x *UpResponse) Reset() { *x = UpResponse{} - mi := &file_daemon_proto_msgTypes[6] + mi := &file_daemon_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -831,7 +961,7 @@ func (x *UpResponse) String() string { func (*UpResponse) ProtoMessage() {} func (x *UpResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[6] + mi := &file_daemon_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -844,7 +974,7 @@ func (x *UpResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use UpResponse.ProtoReflect.Descriptor instead. func (*UpResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{6} + return file_daemon_proto_rawDescGZIP(), []int{8} } type StatusRequest struct { @@ -859,7 +989,7 @@ type StatusRequest struct { func (x *StatusRequest) Reset() { *x = StatusRequest{} - mi := &file_daemon_proto_msgTypes[7] + mi := &file_daemon_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -871,7 +1001,7 @@ func (x *StatusRequest) String() string { func (*StatusRequest) ProtoMessage() {} func (x *StatusRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[7] + mi := &file_daemon_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -884,7 +1014,7 @@ func (x *StatusRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use StatusRequest.ProtoReflect.Descriptor instead. func (*StatusRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{7} + return file_daemon_proto_rawDescGZIP(), []int{9} } func (x *StatusRequest) GetGetFullPeerStatus() bool { @@ -921,7 +1051,7 @@ type StatusResponse struct { func (x *StatusResponse) Reset() { *x = StatusResponse{} - mi := &file_daemon_proto_msgTypes[8] + mi := &file_daemon_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -933,7 +1063,7 @@ func (x *StatusResponse) String() string { func (*StatusResponse) ProtoMessage() {} func (x *StatusResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[8] + mi := &file_daemon_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -946,7 +1076,7 @@ func (x *StatusResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use StatusResponse.ProtoReflect.Descriptor instead. func (*StatusResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{8} + return file_daemon_proto_rawDescGZIP(), []int{10} } func (x *StatusResponse) GetStatus() string { @@ -978,7 +1108,7 @@ type DownRequest struct { func (x *DownRequest) Reset() { *x = DownRequest{} - mi := &file_daemon_proto_msgTypes[9] + mi := &file_daemon_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -990,7 +1120,7 @@ func (x *DownRequest) String() string { func (*DownRequest) ProtoMessage() {} func (x *DownRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[9] + mi := &file_daemon_proto_msgTypes[11] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1003,7 +1133,7 @@ func (x *DownRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DownRequest.ProtoReflect.Descriptor instead. func (*DownRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{9} + return file_daemon_proto_rawDescGZIP(), []int{11} } type DownResponse struct { @@ -1014,7 +1144,7 @@ type DownResponse struct { func (x *DownResponse) Reset() { *x = DownResponse{} - mi := &file_daemon_proto_msgTypes[10] + mi := &file_daemon_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1026,7 +1156,7 @@ func (x *DownResponse) String() string { func (*DownResponse) ProtoMessage() {} func (x *DownResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[10] + mi := &file_daemon_proto_msgTypes[12] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1039,7 +1169,7 @@ func (x *DownResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DownResponse.ProtoReflect.Descriptor instead. func (*DownResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{10} + return file_daemon_proto_rawDescGZIP(), []int{12} } type GetConfigRequest struct { @@ -1052,7 +1182,7 @@ type GetConfigRequest struct { func (x *GetConfigRequest) Reset() { *x = GetConfigRequest{} - mi := &file_daemon_proto_msgTypes[11] + mi := &file_daemon_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1064,7 +1194,7 @@ func (x *GetConfigRequest) String() string { func (*GetConfigRequest) ProtoMessage() {} func (x *GetConfigRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[11] + mi := &file_daemon_proto_msgTypes[13] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1077,7 +1207,7 @@ func (x *GetConfigRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetConfigRequest.ProtoReflect.Descriptor instead. func (*GetConfigRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{11} + return file_daemon_proto_rawDescGZIP(), []int{13} } func (x *GetConfigRequest) GetProfileName() string { @@ -1133,7 +1263,7 @@ type GetConfigResponse struct { func (x *GetConfigResponse) Reset() { *x = GetConfigResponse{} - mi := &file_daemon_proto_msgTypes[12] + mi := &file_daemon_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1145,7 +1275,7 @@ func (x *GetConfigResponse) String() string { func (*GetConfigResponse) ProtoMessage() {} func (x *GetConfigResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[12] + mi := &file_daemon_proto_msgTypes[14] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1158,7 +1288,7 @@ func (x *GetConfigResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetConfigResponse.ProtoReflect.Descriptor instead. func (*GetConfigResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{12} + return file_daemon_proto_rawDescGZIP(), []int{14} } func (x *GetConfigResponse) GetManagementUrl() string { @@ -1370,7 +1500,7 @@ type PeerState struct { func (x *PeerState) Reset() { *x = PeerState{} - mi := &file_daemon_proto_msgTypes[13] + mi := &file_daemon_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1382,7 +1512,7 @@ func (x *PeerState) String() string { func (*PeerState) ProtoMessage() {} func (x *PeerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[13] + mi := &file_daemon_proto_msgTypes[15] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1395,7 +1525,7 @@ func (x *PeerState) ProtoReflect() protoreflect.Message { // Deprecated: Use PeerState.ProtoReflect.Descriptor instead. func (*PeerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{13} + return file_daemon_proto_rawDescGZIP(), []int{15} } func (x *PeerState) GetIP() string { @@ -1540,7 +1670,7 @@ type LocalPeerState struct { func (x *LocalPeerState) Reset() { *x = LocalPeerState{} - mi := &file_daemon_proto_msgTypes[14] + mi := &file_daemon_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1552,7 +1682,7 @@ func (x *LocalPeerState) String() string { func (*LocalPeerState) ProtoMessage() {} func (x *LocalPeerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[14] + mi := &file_daemon_proto_msgTypes[16] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1565,7 +1695,7 @@ func (x *LocalPeerState) ProtoReflect() protoreflect.Message { // Deprecated: Use LocalPeerState.ProtoReflect.Descriptor instead. func (*LocalPeerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{14} + return file_daemon_proto_rawDescGZIP(), []int{16} } func (x *LocalPeerState) GetIP() string { @@ -1629,7 +1759,7 @@ type SignalState struct { func (x *SignalState) Reset() { *x = SignalState{} - mi := &file_daemon_proto_msgTypes[15] + mi := &file_daemon_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1641,7 +1771,7 @@ func (x *SignalState) String() string { func (*SignalState) ProtoMessage() {} func (x *SignalState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[15] + mi := &file_daemon_proto_msgTypes[17] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1654,7 +1784,7 @@ func (x *SignalState) ProtoReflect() protoreflect.Message { // Deprecated: Use SignalState.ProtoReflect.Descriptor instead. func (*SignalState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{15} + return file_daemon_proto_rawDescGZIP(), []int{17} } func (x *SignalState) GetURL() string { @@ -1690,7 +1820,7 @@ type ManagementState struct { func (x *ManagementState) Reset() { *x = ManagementState{} - mi := &file_daemon_proto_msgTypes[16] + mi := &file_daemon_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1702,7 +1832,7 @@ func (x *ManagementState) String() string { func (*ManagementState) ProtoMessage() {} func (x *ManagementState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[16] + mi := &file_daemon_proto_msgTypes[18] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1715,7 +1845,7 @@ func (x *ManagementState) ProtoReflect() protoreflect.Message { // Deprecated: Use ManagementState.ProtoReflect.Descriptor instead. func (*ManagementState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{16} + return file_daemon_proto_rawDescGZIP(), []int{18} } func (x *ManagementState) GetURL() string { @@ -1751,7 +1881,7 @@ type RelayState struct { func (x *RelayState) Reset() { *x = RelayState{} - mi := &file_daemon_proto_msgTypes[17] + mi := &file_daemon_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1763,7 +1893,7 @@ func (x *RelayState) String() string { func (*RelayState) ProtoMessage() {} func (x *RelayState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[17] + mi := &file_daemon_proto_msgTypes[19] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1776,7 +1906,7 @@ func (x *RelayState) ProtoReflect() protoreflect.Message { // Deprecated: Use RelayState.ProtoReflect.Descriptor instead. func (*RelayState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{17} + return file_daemon_proto_rawDescGZIP(), []int{19} } func (x *RelayState) GetURI() string { @@ -1812,7 +1942,7 @@ type NSGroupState struct { func (x *NSGroupState) Reset() { *x = NSGroupState{} - mi := &file_daemon_proto_msgTypes[18] + mi := &file_daemon_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1824,7 +1954,7 @@ func (x *NSGroupState) String() string { func (*NSGroupState) ProtoMessage() {} func (x *NSGroupState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[18] + mi := &file_daemon_proto_msgTypes[20] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1837,7 +1967,7 @@ func (x *NSGroupState) ProtoReflect() protoreflect.Message { // Deprecated: Use NSGroupState.ProtoReflect.Descriptor instead. func (*NSGroupState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{18} + return file_daemon_proto_rawDescGZIP(), []int{20} } func (x *NSGroupState) GetServers() []string { @@ -1881,7 +2011,7 @@ type SSHSessionInfo struct { func (x *SSHSessionInfo) Reset() { *x = SSHSessionInfo{} - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1893,7 +2023,7 @@ func (x *SSHSessionInfo) String() string { func (*SSHSessionInfo) ProtoMessage() {} func (x *SSHSessionInfo) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[19] + mi := &file_daemon_proto_msgTypes[21] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1906,7 +2036,7 @@ func (x *SSHSessionInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHSessionInfo.ProtoReflect.Descriptor instead. func (*SSHSessionInfo) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{19} + return file_daemon_proto_rawDescGZIP(), []int{21} } func (x *SSHSessionInfo) GetUsername() string { @@ -1948,7 +2078,7 @@ type SSHServerState struct { func (x *SSHServerState) Reset() { *x = SSHServerState{} - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1960,7 +2090,7 @@ func (x *SSHServerState) String() string { func (*SSHServerState) ProtoMessage() {} func (x *SSHServerState) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[20] + mi := &file_daemon_proto_msgTypes[22] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1973,7 +2103,7 @@ func (x *SSHServerState) ProtoReflect() protoreflect.Message { // Deprecated: Use SSHServerState.ProtoReflect.Descriptor instead. func (*SSHServerState) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{20} + return file_daemon_proto_rawDescGZIP(), []int{22} } func (x *SSHServerState) GetEnabled() bool { @@ -2009,7 +2139,7 @@ type FullStatus struct { func (x *FullStatus) Reset() { *x = FullStatus{} - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2021,7 +2151,7 @@ func (x *FullStatus) String() string { func (*FullStatus) ProtoMessage() {} func (x *FullStatus) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[21] + mi := &file_daemon_proto_msgTypes[23] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2034,7 +2164,7 @@ func (x *FullStatus) ProtoReflect() protoreflect.Message { // Deprecated: Use FullStatus.ProtoReflect.Descriptor instead. func (*FullStatus) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{21} + return file_daemon_proto_rawDescGZIP(), []int{23} } func (x *FullStatus) GetManagementState() *ManagementState { @@ -2116,7 +2246,7 @@ type ListNetworksRequest struct { func (x *ListNetworksRequest) Reset() { *x = ListNetworksRequest{} - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2128,7 +2258,7 @@ func (x *ListNetworksRequest) String() string { func (*ListNetworksRequest) ProtoMessage() {} func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[22] + mi := &file_daemon_proto_msgTypes[24] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2141,7 +2271,7 @@ func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksRequest.ProtoReflect.Descriptor instead. func (*ListNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{22} + return file_daemon_proto_rawDescGZIP(), []int{24} } type ListNetworksResponse struct { @@ -2153,7 +2283,7 @@ type ListNetworksResponse struct { func (x *ListNetworksResponse) Reset() { *x = ListNetworksResponse{} - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2165,7 +2295,7 @@ func (x *ListNetworksResponse) String() string { func (*ListNetworksResponse) ProtoMessage() {} func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[23] + mi := &file_daemon_proto_msgTypes[25] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2178,7 +2308,7 @@ func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNetworksResponse.ProtoReflect.Descriptor instead. func (*ListNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{23} + return file_daemon_proto_rawDescGZIP(), []int{25} } func (x *ListNetworksResponse) GetRoutes() []*Network { @@ -2199,7 +2329,7 @@ type SelectNetworksRequest struct { func (x *SelectNetworksRequest) Reset() { *x = SelectNetworksRequest{} - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2211,7 +2341,7 @@ func (x *SelectNetworksRequest) String() string { func (*SelectNetworksRequest) ProtoMessage() {} func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[24] + mi := &file_daemon_proto_msgTypes[26] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2224,7 +2354,7 @@ func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksRequest.ProtoReflect.Descriptor instead. func (*SelectNetworksRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{24} + return file_daemon_proto_rawDescGZIP(), []int{26} } func (x *SelectNetworksRequest) GetNetworkIDs() []string { @@ -2256,7 +2386,7 @@ type SelectNetworksResponse struct { func (x *SelectNetworksResponse) Reset() { *x = SelectNetworksResponse{} - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2268,7 +2398,7 @@ func (x *SelectNetworksResponse) String() string { func (*SelectNetworksResponse) ProtoMessage() {} func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[25] + mi := &file_daemon_proto_msgTypes[27] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2281,7 +2411,7 @@ func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SelectNetworksResponse.ProtoReflect.Descriptor instead. func (*SelectNetworksResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{25} + return file_daemon_proto_rawDescGZIP(), []int{27} } type IPList struct { @@ -2293,7 +2423,7 @@ type IPList struct { func (x *IPList) Reset() { *x = IPList{} - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2305,7 +2435,7 @@ func (x *IPList) String() string { func (*IPList) ProtoMessage() {} func (x *IPList) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[26] + mi := &file_daemon_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2318,7 +2448,7 @@ func (x *IPList) ProtoReflect() protoreflect.Message { // Deprecated: Use IPList.ProtoReflect.Descriptor instead. func (*IPList) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{26} + return file_daemon_proto_rawDescGZIP(), []int{28} } func (x *IPList) GetIps() []string { @@ -2341,7 +2471,7 @@ type Network struct { func (x *Network) Reset() { *x = Network{} - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2353,7 +2483,7 @@ func (x *Network) String() string { func (*Network) ProtoMessage() {} func (x *Network) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[27] + mi := &file_daemon_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2366,7 +2496,7 @@ func (x *Network) ProtoReflect() protoreflect.Message { // Deprecated: Use Network.ProtoReflect.Descriptor instead. func (*Network) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{27} + return file_daemon_proto_rawDescGZIP(), []int{29} } func (x *Network) GetID() string { @@ -2418,7 +2548,7 @@ type PortInfo struct { func (x *PortInfo) Reset() { *x = PortInfo{} - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2430,7 +2560,7 @@ func (x *PortInfo) String() string { func (*PortInfo) ProtoMessage() {} func (x *PortInfo) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[28] + mi := &file_daemon_proto_msgTypes[30] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2443,7 +2573,7 @@ func (x *PortInfo) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. func (*PortInfo) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{28} + return file_daemon_proto_rawDescGZIP(), []int{30} } func (x *PortInfo) GetPortSelection() isPortInfo_PortSelection { @@ -2500,7 +2630,7 @@ type ForwardingRule struct { func (x *ForwardingRule) Reset() { *x = ForwardingRule{} - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2512,7 +2642,7 @@ func (x *ForwardingRule) String() string { func (*ForwardingRule) ProtoMessage() {} func (x *ForwardingRule) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[29] + mi := &file_daemon_proto_msgTypes[31] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2525,7 +2655,7 @@ func (x *ForwardingRule) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRule.ProtoReflect.Descriptor instead. func (*ForwardingRule) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{29} + return file_daemon_proto_rawDescGZIP(), []int{31} } func (x *ForwardingRule) GetProtocol() string { @@ -2572,7 +2702,7 @@ type ForwardingRulesResponse struct { func (x *ForwardingRulesResponse) Reset() { *x = ForwardingRulesResponse{} - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2584,7 +2714,7 @@ func (x *ForwardingRulesResponse) String() string { func (*ForwardingRulesResponse) ProtoMessage() {} func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[30] + mi := &file_daemon_proto_msgTypes[32] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2597,7 +2727,7 @@ func (x *ForwardingRulesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ForwardingRulesResponse.ProtoReflect.Descriptor instead. func (*ForwardingRulesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{30} + return file_daemon_proto_rawDescGZIP(), []int{32} } func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule { @@ -2621,7 +2751,7 @@ type DebugBundleRequest struct { func (x *DebugBundleRequest) Reset() { *x = DebugBundleRequest{} - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2633,7 +2763,7 @@ func (x *DebugBundleRequest) String() string { func (*DebugBundleRequest) ProtoMessage() {} func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[33] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2646,7 +2776,7 @@ func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead. func (*DebugBundleRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{31} + return file_daemon_proto_rawDescGZIP(), []int{33} } func (x *DebugBundleRequest) GetAnonymize() bool { @@ -2695,7 +2825,7 @@ type DebugBundleResponse struct { func (x *DebugBundleResponse) Reset() { *x = DebugBundleResponse{} - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2707,7 +2837,7 @@ func (x *DebugBundleResponse) String() string { func (*DebugBundleResponse) ProtoMessage() {} func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[34] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2720,7 +2850,7 @@ func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead. func (*DebugBundleResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{32} + return file_daemon_proto_rawDescGZIP(), []int{34} } func (x *DebugBundleResponse) GetPath() string { @@ -2752,7 +2882,7 @@ type GetLogLevelRequest struct { func (x *GetLogLevelRequest) Reset() { *x = GetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2764,7 +2894,7 @@ func (x *GetLogLevelRequest) String() string { func (*GetLogLevelRequest) ProtoMessage() {} func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[33] + mi := &file_daemon_proto_msgTypes[35] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2777,7 +2907,7 @@ func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelRequest.ProtoReflect.Descriptor instead. func (*GetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{33} + return file_daemon_proto_rawDescGZIP(), []int{35} } type GetLogLevelResponse struct { @@ -2789,7 +2919,7 @@ type GetLogLevelResponse struct { func (x *GetLogLevelResponse) Reset() { *x = GetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2801,7 +2931,7 @@ func (x *GetLogLevelResponse) String() string { func (*GetLogLevelResponse) ProtoMessage() {} func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[34] + mi := &file_daemon_proto_msgTypes[36] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2814,7 +2944,7 @@ func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetLogLevelResponse.ProtoReflect.Descriptor instead. func (*GetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{34} + return file_daemon_proto_rawDescGZIP(), []int{36} } func (x *GetLogLevelResponse) GetLevel() LogLevel { @@ -2833,7 +2963,7 @@ type SetLogLevelRequest struct { func (x *SetLogLevelRequest) Reset() { *x = SetLogLevelRequest{} - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2845,7 +2975,7 @@ func (x *SetLogLevelRequest) String() string { func (*SetLogLevelRequest) ProtoMessage() {} func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[35] + mi := &file_daemon_proto_msgTypes[37] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2858,7 +2988,7 @@ func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead. func (*SetLogLevelRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{35} + return file_daemon_proto_rawDescGZIP(), []int{37} } func (x *SetLogLevelRequest) GetLevel() LogLevel { @@ -2876,7 +3006,7 @@ type SetLogLevelResponse struct { func (x *SetLogLevelResponse) Reset() { *x = SetLogLevelResponse{} - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2888,7 +3018,7 @@ func (x *SetLogLevelResponse) String() string { func (*SetLogLevelResponse) ProtoMessage() {} func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[36] + mi := &file_daemon_proto_msgTypes[38] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2901,7 +3031,7 @@ func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead. func (*SetLogLevelResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{36} + return file_daemon_proto_rawDescGZIP(), []int{38} } // State represents a daemon state entry @@ -2914,7 +3044,7 @@ type State struct { func (x *State) Reset() { *x = State{} - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2926,7 +3056,7 @@ func (x *State) String() string { func (*State) ProtoMessage() {} func (x *State) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[37] + mi := &file_daemon_proto_msgTypes[39] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2939,7 +3069,7 @@ func (x *State) ProtoReflect() protoreflect.Message { // Deprecated: Use State.ProtoReflect.Descriptor instead. func (*State) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{37} + return file_daemon_proto_rawDescGZIP(), []int{39} } func (x *State) GetName() string { @@ -2958,7 +3088,7 @@ type ListStatesRequest struct { func (x *ListStatesRequest) Reset() { *x = ListStatesRequest{} - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2970,7 +3100,7 @@ func (x *ListStatesRequest) String() string { func (*ListStatesRequest) ProtoMessage() {} func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[38] + mi := &file_daemon_proto_msgTypes[40] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2983,7 +3113,7 @@ func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead. func (*ListStatesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{38} + return file_daemon_proto_rawDescGZIP(), []int{40} } // ListStatesResponse contains a list of states @@ -2996,7 +3126,7 @@ type ListStatesResponse struct { func (x *ListStatesResponse) Reset() { *x = ListStatesResponse{} - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3008,7 +3138,7 @@ func (x *ListStatesResponse) String() string { func (*ListStatesResponse) ProtoMessage() {} func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[39] + mi := &file_daemon_proto_msgTypes[41] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3021,7 +3151,7 @@ func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead. func (*ListStatesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{39} + return file_daemon_proto_rawDescGZIP(), []int{41} } func (x *ListStatesResponse) GetStates() []*State { @@ -3042,7 +3172,7 @@ type CleanStateRequest struct { func (x *CleanStateRequest) Reset() { *x = CleanStateRequest{} - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3054,7 +3184,7 @@ func (x *CleanStateRequest) String() string { func (*CleanStateRequest) ProtoMessage() {} func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[40] + mi := &file_daemon_proto_msgTypes[42] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3067,7 +3197,7 @@ func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead. func (*CleanStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{40} + return file_daemon_proto_rawDescGZIP(), []int{42} } func (x *CleanStateRequest) GetStateName() string { @@ -3094,7 +3224,7 @@ type CleanStateResponse struct { func (x *CleanStateResponse) Reset() { *x = CleanStateResponse{} - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[43] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3106,7 +3236,7 @@ func (x *CleanStateResponse) String() string { func (*CleanStateResponse) ProtoMessage() {} func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[41] + mi := &file_daemon_proto_msgTypes[43] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3119,7 +3249,7 @@ func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead. func (*CleanStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{41} + return file_daemon_proto_rawDescGZIP(), []int{43} } func (x *CleanStateResponse) GetCleanedStates() int32 { @@ -3140,7 +3270,7 @@ type DeleteStateRequest struct { func (x *DeleteStateRequest) Reset() { *x = DeleteStateRequest{} - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3152,7 +3282,7 @@ func (x *DeleteStateRequest) String() string { func (*DeleteStateRequest) ProtoMessage() {} func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[42] + mi := &file_daemon_proto_msgTypes[44] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3165,7 +3295,7 @@ func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead. func (*DeleteStateRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{42} + return file_daemon_proto_rawDescGZIP(), []int{44} } func (x *DeleteStateRequest) GetStateName() string { @@ -3192,7 +3322,7 @@ type DeleteStateResponse struct { func (x *DeleteStateResponse) Reset() { *x = DeleteStateResponse{} - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[45] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3204,7 +3334,7 @@ func (x *DeleteStateResponse) String() string { func (*DeleteStateResponse) ProtoMessage() {} func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[43] + mi := &file_daemon_proto_msgTypes[45] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3217,7 +3347,7 @@ func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead. func (*DeleteStateResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{43} + return file_daemon_proto_rawDescGZIP(), []int{45} } func (x *DeleteStateResponse) GetDeletedStates() int32 { @@ -3236,7 +3366,7 @@ type SetSyncResponsePersistenceRequest struct { func (x *SetSyncResponsePersistenceRequest) Reset() { *x = SetSyncResponsePersistenceRequest{} - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[46] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3248,7 +3378,7 @@ func (x *SetSyncResponsePersistenceRequest) String() string { func (*SetSyncResponsePersistenceRequest) ProtoMessage() {} func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[44] + mi := &file_daemon_proto_msgTypes[46] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3261,7 +3391,7 @@ func (x *SetSyncResponsePersistenceRequest) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceRequest.ProtoReflect.Descriptor instead. func (*SetSyncResponsePersistenceRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{44} + return file_daemon_proto_rawDescGZIP(), []int{46} } func (x *SetSyncResponsePersistenceRequest) GetEnabled() bool { @@ -3279,7 +3409,7 @@ type SetSyncResponsePersistenceResponse struct { func (x *SetSyncResponsePersistenceResponse) Reset() { *x = SetSyncResponsePersistenceResponse{} - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[47] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3291,7 +3421,7 @@ func (x *SetSyncResponsePersistenceResponse) String() string { func (*SetSyncResponsePersistenceResponse) ProtoMessage() {} func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[45] + mi := &file_daemon_proto_msgTypes[47] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3304,7 +3434,7 @@ func (x *SetSyncResponsePersistenceResponse) ProtoReflect() protoreflect.Message // Deprecated: Use SetSyncResponsePersistenceResponse.ProtoReflect.Descriptor instead. func (*SetSyncResponsePersistenceResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{45} + return file_daemon_proto_rawDescGZIP(), []int{47} } type TCPFlags struct { @@ -3321,7 +3451,7 @@ type TCPFlags struct { func (x *TCPFlags) Reset() { *x = TCPFlags{} - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[48] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3333,7 +3463,7 @@ func (x *TCPFlags) String() string { func (*TCPFlags) ProtoMessage() {} func (x *TCPFlags) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[46] + mi := &file_daemon_proto_msgTypes[48] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3346,7 +3476,7 @@ func (x *TCPFlags) ProtoReflect() protoreflect.Message { // Deprecated: Use TCPFlags.ProtoReflect.Descriptor instead. func (*TCPFlags) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{46} + return file_daemon_proto_rawDescGZIP(), []int{48} } func (x *TCPFlags) GetSyn() bool { @@ -3408,7 +3538,7 @@ type TracePacketRequest struct { func (x *TracePacketRequest) Reset() { *x = TracePacketRequest{} - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[49] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3420,7 +3550,7 @@ func (x *TracePacketRequest) String() string { func (*TracePacketRequest) ProtoMessage() {} func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[47] + mi := &file_daemon_proto_msgTypes[49] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3433,7 +3563,7 @@ func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketRequest.ProtoReflect.Descriptor instead. func (*TracePacketRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{47} + return file_daemon_proto_rawDescGZIP(), []int{49} } func (x *TracePacketRequest) GetSourceIp() string { @@ -3511,7 +3641,7 @@ type TraceStage struct { func (x *TraceStage) Reset() { *x = TraceStage{} - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[50] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3523,7 +3653,7 @@ func (x *TraceStage) String() string { func (*TraceStage) ProtoMessage() {} func (x *TraceStage) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[48] + mi := &file_daemon_proto_msgTypes[50] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3536,7 +3666,7 @@ func (x *TraceStage) ProtoReflect() protoreflect.Message { // Deprecated: Use TraceStage.ProtoReflect.Descriptor instead. func (*TraceStage) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{48} + return file_daemon_proto_rawDescGZIP(), []int{50} } func (x *TraceStage) GetName() string { @@ -3577,7 +3707,7 @@ type TracePacketResponse struct { func (x *TracePacketResponse) Reset() { *x = TracePacketResponse{} - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[51] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3589,7 +3719,7 @@ func (x *TracePacketResponse) String() string { func (*TracePacketResponse) ProtoMessage() {} func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[49] + mi := &file_daemon_proto_msgTypes[51] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3602,7 +3732,7 @@ func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use TracePacketResponse.ProtoReflect.Descriptor instead. func (*TracePacketResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{49} + return file_daemon_proto_rawDescGZIP(), []int{51} } func (x *TracePacketResponse) GetStages() []*TraceStage { @@ -3627,7 +3757,7 @@ type SubscribeRequest struct { func (x *SubscribeRequest) Reset() { *x = SubscribeRequest{} - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[52] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3639,7 +3769,7 @@ func (x *SubscribeRequest) String() string { func (*SubscribeRequest) ProtoMessage() {} func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[50] + mi := &file_daemon_proto_msgTypes[52] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3652,7 +3782,7 @@ func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SubscribeRequest.ProtoReflect.Descriptor instead. func (*SubscribeRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{50} + return file_daemon_proto_rawDescGZIP(), []int{52} } type SystemEvent struct { @@ -3670,7 +3800,7 @@ type SystemEvent struct { func (x *SystemEvent) Reset() { *x = SystemEvent{} - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[53] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3682,7 +3812,7 @@ func (x *SystemEvent) String() string { func (*SystemEvent) ProtoMessage() {} func (x *SystemEvent) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[51] + mi := &file_daemon_proto_msgTypes[53] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3695,7 +3825,7 @@ func (x *SystemEvent) ProtoReflect() protoreflect.Message { // Deprecated: Use SystemEvent.ProtoReflect.Descriptor instead. func (*SystemEvent) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{51} + return file_daemon_proto_rawDescGZIP(), []int{53} } func (x *SystemEvent) GetId() string { @@ -3755,7 +3885,7 @@ type GetEventsRequest struct { func (x *GetEventsRequest) Reset() { *x = GetEventsRequest{} - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[54] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3767,7 +3897,7 @@ func (x *GetEventsRequest) String() string { func (*GetEventsRequest) ProtoMessage() {} func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[52] + mi := &file_daemon_proto_msgTypes[54] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3780,7 +3910,7 @@ func (x *GetEventsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsRequest.ProtoReflect.Descriptor instead. func (*GetEventsRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{52} + return file_daemon_proto_rawDescGZIP(), []int{54} } type GetEventsResponse struct { @@ -3792,7 +3922,7 @@ type GetEventsResponse struct { func (x *GetEventsResponse) Reset() { *x = GetEventsResponse{} - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[55] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3804,7 +3934,7 @@ func (x *GetEventsResponse) String() string { func (*GetEventsResponse) ProtoMessage() {} func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[53] + mi := &file_daemon_proto_msgTypes[55] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3817,7 +3947,7 @@ func (x *GetEventsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetEventsResponse.ProtoReflect.Descriptor instead. func (*GetEventsResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{53} + return file_daemon_proto_rawDescGZIP(), []int{55} } func (x *GetEventsResponse) GetEvents() []*SystemEvent { @@ -3837,7 +3967,7 @@ type SwitchProfileRequest struct { func (x *SwitchProfileRequest) Reset() { *x = SwitchProfileRequest{} - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[56] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3849,7 +3979,7 @@ func (x *SwitchProfileRequest) String() string { func (*SwitchProfileRequest) ProtoMessage() {} func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[54] + mi := &file_daemon_proto_msgTypes[56] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3862,7 +3992,7 @@ func (x *SwitchProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileRequest.ProtoReflect.Descriptor instead. func (*SwitchProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{54} + return file_daemon_proto_rawDescGZIP(), []int{56} } func (x *SwitchProfileRequest) GetProfileName() string { @@ -3887,7 +4017,7 @@ type SwitchProfileResponse struct { func (x *SwitchProfileResponse) Reset() { *x = SwitchProfileResponse{} - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[57] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3899,7 +4029,7 @@ func (x *SwitchProfileResponse) String() string { func (*SwitchProfileResponse) ProtoMessage() {} func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[55] + mi := &file_daemon_proto_msgTypes[57] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3912,7 +4042,7 @@ func (x *SwitchProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SwitchProfileResponse.ProtoReflect.Descriptor instead. func (*SwitchProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{55} + return file_daemon_proto_rawDescGZIP(), []int{57} } type SetConfigRequest struct { @@ -3960,7 +4090,7 @@ type SetConfigRequest struct { func (x *SetConfigRequest) Reset() { *x = SetConfigRequest{} - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[58] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3972,7 +4102,7 @@ func (x *SetConfigRequest) String() string { func (*SetConfigRequest) ProtoMessage() {} func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[56] + mi := &file_daemon_proto_msgTypes[58] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3985,7 +4115,7 @@ func (x *SetConfigRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigRequest.ProtoReflect.Descriptor instead. func (*SetConfigRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{56} + return file_daemon_proto_rawDescGZIP(), []int{58} } func (x *SetConfigRequest) GetUsername() string { @@ -4234,7 +4364,7 @@ type SetConfigResponse struct { func (x *SetConfigResponse) Reset() { *x = SetConfigResponse{} - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[59] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4246,7 +4376,7 @@ func (x *SetConfigResponse) String() string { func (*SetConfigResponse) ProtoMessage() {} func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[57] + mi := &file_daemon_proto_msgTypes[59] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4259,7 +4389,7 @@ func (x *SetConfigResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetConfigResponse.ProtoReflect.Descriptor instead. func (*SetConfigResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{57} + return file_daemon_proto_rawDescGZIP(), []int{59} } type AddProfileRequest struct { @@ -4272,7 +4402,7 @@ type AddProfileRequest struct { func (x *AddProfileRequest) Reset() { *x = AddProfileRequest{} - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[60] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4284,7 +4414,7 @@ func (x *AddProfileRequest) String() string { func (*AddProfileRequest) ProtoMessage() {} func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[58] + mi := &file_daemon_proto_msgTypes[60] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4297,7 +4427,7 @@ func (x *AddProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileRequest.ProtoReflect.Descriptor instead. func (*AddProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{58} + return file_daemon_proto_rawDescGZIP(), []int{60} } func (x *AddProfileRequest) GetUsername() string { @@ -4322,7 +4452,7 @@ type AddProfileResponse struct { func (x *AddProfileResponse) Reset() { *x = AddProfileResponse{} - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[61] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4334,7 +4464,7 @@ func (x *AddProfileResponse) String() string { func (*AddProfileResponse) ProtoMessage() {} func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[59] + mi := &file_daemon_proto_msgTypes[61] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4347,7 +4477,7 @@ func (x *AddProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use AddProfileResponse.ProtoReflect.Descriptor instead. func (*AddProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{59} + return file_daemon_proto_rawDescGZIP(), []int{61} } type RemoveProfileRequest struct { @@ -4360,7 +4490,7 @@ type RemoveProfileRequest struct { func (x *RemoveProfileRequest) Reset() { *x = RemoveProfileRequest{} - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[62] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4372,7 +4502,7 @@ func (x *RemoveProfileRequest) String() string { func (*RemoveProfileRequest) ProtoMessage() {} func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[60] + mi := &file_daemon_proto_msgTypes[62] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4385,7 +4515,7 @@ func (x *RemoveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileRequest.ProtoReflect.Descriptor instead. func (*RemoveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{60} + return file_daemon_proto_rawDescGZIP(), []int{62} } func (x *RemoveProfileRequest) GetUsername() string { @@ -4410,7 +4540,7 @@ type RemoveProfileResponse struct { func (x *RemoveProfileResponse) Reset() { *x = RemoveProfileResponse{} - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[63] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4422,7 +4552,7 @@ func (x *RemoveProfileResponse) String() string { func (*RemoveProfileResponse) ProtoMessage() {} func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[61] + mi := &file_daemon_proto_msgTypes[63] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4435,7 +4565,7 @@ func (x *RemoveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RemoveProfileResponse.ProtoReflect.Descriptor instead. func (*RemoveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{61} + return file_daemon_proto_rawDescGZIP(), []int{63} } type ListProfilesRequest struct { @@ -4447,7 +4577,7 @@ type ListProfilesRequest struct { func (x *ListProfilesRequest) Reset() { *x = ListProfilesRequest{} - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[64] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4459,7 +4589,7 @@ func (x *ListProfilesRequest) String() string { func (*ListProfilesRequest) ProtoMessage() {} func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[62] + mi := &file_daemon_proto_msgTypes[64] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4472,7 +4602,7 @@ func (x *ListProfilesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesRequest.ProtoReflect.Descriptor instead. func (*ListProfilesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{62} + return file_daemon_proto_rawDescGZIP(), []int{64} } func (x *ListProfilesRequest) GetUsername() string { @@ -4491,7 +4621,7 @@ type ListProfilesResponse struct { func (x *ListProfilesResponse) Reset() { *x = ListProfilesResponse{} - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[65] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4503,7 +4633,7 @@ func (x *ListProfilesResponse) String() string { func (*ListProfilesResponse) ProtoMessage() {} func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[63] + mi := &file_daemon_proto_msgTypes[65] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4516,7 +4646,7 @@ func (x *ListProfilesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListProfilesResponse.ProtoReflect.Descriptor instead. func (*ListProfilesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{63} + return file_daemon_proto_rawDescGZIP(), []int{65} } func (x *ListProfilesResponse) GetProfiles() []*Profile { @@ -4536,7 +4666,7 @@ type Profile struct { func (x *Profile) Reset() { *x = Profile{} - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[66] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4548,7 +4678,7 @@ func (x *Profile) String() string { func (*Profile) ProtoMessage() {} func (x *Profile) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[64] + mi := &file_daemon_proto_msgTypes[66] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4561,7 +4691,7 @@ func (x *Profile) ProtoReflect() protoreflect.Message { // Deprecated: Use Profile.ProtoReflect.Descriptor instead. func (*Profile) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{64} + return file_daemon_proto_rawDescGZIP(), []int{66} } func (x *Profile) GetName() string { @@ -4586,7 +4716,7 @@ type GetActiveProfileRequest struct { func (x *GetActiveProfileRequest) Reset() { *x = GetActiveProfileRequest{} - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[67] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4598,7 +4728,7 @@ func (x *GetActiveProfileRequest) String() string { func (*GetActiveProfileRequest) ProtoMessage() {} func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[65] + mi := &file_daemon_proto_msgTypes[67] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4611,7 +4741,7 @@ func (x *GetActiveProfileRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileRequest.ProtoReflect.Descriptor instead. func (*GetActiveProfileRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{65} + return file_daemon_proto_rawDescGZIP(), []int{67} } type GetActiveProfileResponse struct { @@ -4624,7 +4754,7 @@ type GetActiveProfileResponse struct { func (x *GetActiveProfileResponse) Reset() { *x = GetActiveProfileResponse{} - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[68] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4636,7 +4766,7 @@ func (x *GetActiveProfileResponse) String() string { func (*GetActiveProfileResponse) ProtoMessage() {} func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[66] + mi := &file_daemon_proto_msgTypes[68] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4649,7 +4779,7 @@ func (x *GetActiveProfileResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetActiveProfileResponse.ProtoReflect.Descriptor instead. func (*GetActiveProfileResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{66} + return file_daemon_proto_rawDescGZIP(), []int{68} } func (x *GetActiveProfileResponse) GetProfileName() string { @@ -4676,7 +4806,7 @@ type LogoutRequest struct { func (x *LogoutRequest) Reset() { *x = LogoutRequest{} - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[69] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4688,7 +4818,7 @@ func (x *LogoutRequest) String() string { func (*LogoutRequest) ProtoMessage() {} func (x *LogoutRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[67] + mi := &file_daemon_proto_msgTypes[69] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4701,7 +4831,7 @@ func (x *LogoutRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutRequest.ProtoReflect.Descriptor instead. func (*LogoutRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{67} + return file_daemon_proto_rawDescGZIP(), []int{69} } func (x *LogoutRequest) GetProfileName() string { @@ -4726,7 +4856,7 @@ type LogoutResponse struct { func (x *LogoutResponse) Reset() { *x = LogoutResponse{} - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[70] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4738,7 +4868,7 @@ func (x *LogoutResponse) String() string { func (*LogoutResponse) ProtoMessage() {} func (x *LogoutResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[68] + mi := &file_daemon_proto_msgTypes[70] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4751,7 +4881,7 @@ func (x *LogoutResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use LogoutResponse.ProtoReflect.Descriptor instead. func (*LogoutResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{68} + return file_daemon_proto_rawDescGZIP(), []int{70} } type GetFeaturesRequest struct { @@ -4762,7 +4892,7 @@ type GetFeaturesRequest struct { func (x *GetFeaturesRequest) Reset() { *x = GetFeaturesRequest{} - mi := &file_daemon_proto_msgTypes[69] + mi := &file_daemon_proto_msgTypes[71] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4774,7 +4904,7 @@ func (x *GetFeaturesRequest) String() string { func (*GetFeaturesRequest) ProtoMessage() {} func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[69] + mi := &file_daemon_proto_msgTypes[71] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4787,7 +4917,7 @@ func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesRequest.ProtoReflect.Descriptor instead. func (*GetFeaturesRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{69} + return file_daemon_proto_rawDescGZIP(), []int{71} } type GetFeaturesResponse struct { @@ -4800,7 +4930,7 @@ type GetFeaturesResponse struct { func (x *GetFeaturesResponse) Reset() { *x = GetFeaturesResponse{} - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[72] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4812,7 +4942,7 @@ func (x *GetFeaturesResponse) String() string { func (*GetFeaturesResponse) ProtoMessage() {} func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[70] + mi := &file_daemon_proto_msgTypes[72] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4825,7 +4955,7 @@ func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeaturesResponse.ProtoReflect.Descriptor instead. func (*GetFeaturesResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{70} + return file_daemon_proto_rawDescGZIP(), []int{72} } func (x *GetFeaturesResponse) GetDisableProfiles() bool { @@ -4853,7 +4983,7 @@ type GetPeerSSHHostKeyRequest struct { func (x *GetPeerSSHHostKeyRequest) Reset() { *x = GetPeerSSHHostKeyRequest{} - mi := &file_daemon_proto_msgTypes[71] + mi := &file_daemon_proto_msgTypes[73] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4865,7 +4995,7 @@ func (x *GetPeerSSHHostKeyRequest) String() string { func (*GetPeerSSHHostKeyRequest) ProtoMessage() {} func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[71] + mi := &file_daemon_proto_msgTypes[73] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4878,7 +5008,7 @@ func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead. func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{71} + return file_daemon_proto_rawDescGZIP(), []int{73} } func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string { @@ -4905,7 +5035,7 @@ type GetPeerSSHHostKeyResponse struct { func (x *GetPeerSSHHostKeyResponse) Reset() { *x = GetPeerSSHHostKeyResponse{} - mi := &file_daemon_proto_msgTypes[72] + mi := &file_daemon_proto_msgTypes[74] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4917,7 +5047,7 @@ func (x *GetPeerSSHHostKeyResponse) String() string { func (*GetPeerSSHHostKeyResponse) ProtoMessage() {} func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[72] + mi := &file_daemon_proto_msgTypes[74] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4930,7 +5060,7 @@ func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead. func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{72} + return file_daemon_proto_rawDescGZIP(), []int{74} } func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte { @@ -4972,7 +5102,7 @@ type RequestJWTAuthRequest struct { func (x *RequestJWTAuthRequest) Reset() { *x = RequestJWTAuthRequest{} - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[75] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -4984,7 +5114,7 @@ func (x *RequestJWTAuthRequest) String() string { func (*RequestJWTAuthRequest) ProtoMessage() {} func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[73] + mi := &file_daemon_proto_msgTypes[75] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -4997,7 +5127,7 @@ func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead. func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{73} + return file_daemon_proto_rawDescGZIP(), []int{75} } func (x *RequestJWTAuthRequest) GetHint() string { @@ -5030,7 +5160,7 @@ type RequestJWTAuthResponse struct { func (x *RequestJWTAuthResponse) Reset() { *x = RequestJWTAuthResponse{} - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[76] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5042,7 +5172,7 @@ func (x *RequestJWTAuthResponse) String() string { func (*RequestJWTAuthResponse) ProtoMessage() {} func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[74] + mi := &file_daemon_proto_msgTypes[76] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5055,7 +5185,7 @@ func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead. func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{74} + return file_daemon_proto_rawDescGZIP(), []int{76} } func (x *RequestJWTAuthResponse) GetVerificationURI() string { @@ -5120,7 +5250,7 @@ type WaitJWTTokenRequest struct { func (x *WaitJWTTokenRequest) Reset() { *x = WaitJWTTokenRequest{} - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[77] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5132,7 +5262,7 @@ func (x *WaitJWTTokenRequest) String() string { func (*WaitJWTTokenRequest) ProtoMessage() {} func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[75] + mi := &file_daemon_proto_msgTypes[77] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5145,7 +5275,7 @@ func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead. func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{75} + return file_daemon_proto_rawDescGZIP(), []int{77} } func (x *WaitJWTTokenRequest) GetDeviceCode() string { @@ -5177,7 +5307,7 @@ type WaitJWTTokenResponse struct { func (x *WaitJWTTokenResponse) Reset() { *x = WaitJWTTokenResponse{} - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[78] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5189,7 +5319,7 @@ func (x *WaitJWTTokenResponse) String() string { func (*WaitJWTTokenResponse) ProtoMessage() {} func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[76] + mi := &file_daemon_proto_msgTypes[78] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5202,7 +5332,7 @@ func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead. func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{76} + return file_daemon_proto_rawDescGZIP(), []int{78} } func (x *WaitJWTTokenResponse) GetToken() string { @@ -5236,7 +5366,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[80] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5248,7 +5378,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[78] + mi := &file_daemon_proto_msgTypes[80] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5261,7 +5391,7 @@ func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { // Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. func (*PortInfo_Range) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{28, 0} + return file_daemon_proto_rawDescGZIP(), []int{30, 0} } func (x *PortInfo_Range) GetStart() uint32 { @@ -5283,7 +5413,15 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xb6\x12\n" + + "\fEmptyRequest\"\x7f\n" + + "\x12OSLifecycleRequest\x128\n" + + "\x04type\x18\x01 \x01(\x0e2$.daemon.OSLifecycleRequest.CycleTypeR\x04type\"/\n" + + "\tCycleType\x12\v\n" + + "\aUNKNOWN\x10\x00\x12\t\n" + + "\x05SLEEP\x10\x01\x12\n" + + "\n" + + "\x06WAKEUP\x10\x02\"\x15\n" + + "\x13OSLifecycleResponse\"\xb6\x12\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -5764,7 +5902,7 @@ const file_daemon_proto_rawDesc = "" + "\x04WARN\x10\x04\x12\b\n" + "\x04INFO\x10\x05\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" + - "\x05TRACE\x10\a2\x8b\x12\n" + + "\x05TRACE\x10\a2\xdb\x12\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -5799,7 +5937,8 @@ const file_daemon_proto_rawDesc = "" + "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" + "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" + "\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" + - "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00B\bZ\x06/protob\x06proto3" + "\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" + + "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00B\bZ\x06/protob\x06proto3" var ( file_daemon_proto_rawDescOnce sync.Once @@ -5813,196 +5952,202 @@ func file_daemon_proto_rawDescGZIP() []byte { return file_daemon_proto_rawDescData } -var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 80) +var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 82) var file_daemon_proto_goTypes = []any{ (LogLevel)(0), // 0: daemon.LogLevel - (SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity - (SystemEvent_Category)(0), // 2: daemon.SystemEvent.Category - (*EmptyRequest)(nil), // 3: daemon.EmptyRequest - (*LoginRequest)(nil), // 4: daemon.LoginRequest - (*LoginResponse)(nil), // 5: daemon.LoginResponse - (*WaitSSOLoginRequest)(nil), // 6: daemon.WaitSSOLoginRequest - (*WaitSSOLoginResponse)(nil), // 7: daemon.WaitSSOLoginResponse - (*UpRequest)(nil), // 8: daemon.UpRequest - (*UpResponse)(nil), // 9: daemon.UpResponse - (*StatusRequest)(nil), // 10: daemon.StatusRequest - (*StatusResponse)(nil), // 11: daemon.StatusResponse - (*DownRequest)(nil), // 12: daemon.DownRequest - (*DownResponse)(nil), // 13: daemon.DownResponse - (*GetConfigRequest)(nil), // 14: daemon.GetConfigRequest - (*GetConfigResponse)(nil), // 15: daemon.GetConfigResponse - (*PeerState)(nil), // 16: daemon.PeerState - (*LocalPeerState)(nil), // 17: daemon.LocalPeerState - (*SignalState)(nil), // 18: daemon.SignalState - (*ManagementState)(nil), // 19: daemon.ManagementState - (*RelayState)(nil), // 20: daemon.RelayState - (*NSGroupState)(nil), // 21: daemon.NSGroupState - (*SSHSessionInfo)(nil), // 22: daemon.SSHSessionInfo - (*SSHServerState)(nil), // 23: daemon.SSHServerState - (*FullStatus)(nil), // 24: daemon.FullStatus - (*ListNetworksRequest)(nil), // 25: daemon.ListNetworksRequest - (*ListNetworksResponse)(nil), // 26: daemon.ListNetworksResponse - (*SelectNetworksRequest)(nil), // 27: daemon.SelectNetworksRequest - (*SelectNetworksResponse)(nil), // 28: daemon.SelectNetworksResponse - (*IPList)(nil), // 29: daemon.IPList - (*Network)(nil), // 30: daemon.Network - (*PortInfo)(nil), // 31: daemon.PortInfo - (*ForwardingRule)(nil), // 32: daemon.ForwardingRule - (*ForwardingRulesResponse)(nil), // 33: daemon.ForwardingRulesResponse - (*DebugBundleRequest)(nil), // 34: daemon.DebugBundleRequest - (*DebugBundleResponse)(nil), // 35: daemon.DebugBundleResponse - (*GetLogLevelRequest)(nil), // 36: daemon.GetLogLevelRequest - (*GetLogLevelResponse)(nil), // 37: daemon.GetLogLevelResponse - (*SetLogLevelRequest)(nil), // 38: daemon.SetLogLevelRequest - (*SetLogLevelResponse)(nil), // 39: daemon.SetLogLevelResponse - (*State)(nil), // 40: daemon.State - (*ListStatesRequest)(nil), // 41: daemon.ListStatesRequest - (*ListStatesResponse)(nil), // 42: daemon.ListStatesResponse - (*CleanStateRequest)(nil), // 43: daemon.CleanStateRequest - (*CleanStateResponse)(nil), // 44: daemon.CleanStateResponse - (*DeleteStateRequest)(nil), // 45: daemon.DeleteStateRequest - (*DeleteStateResponse)(nil), // 46: daemon.DeleteStateResponse - (*SetSyncResponsePersistenceRequest)(nil), // 47: daemon.SetSyncResponsePersistenceRequest - (*SetSyncResponsePersistenceResponse)(nil), // 48: daemon.SetSyncResponsePersistenceResponse - (*TCPFlags)(nil), // 49: daemon.TCPFlags - (*TracePacketRequest)(nil), // 50: daemon.TracePacketRequest - (*TraceStage)(nil), // 51: daemon.TraceStage - (*TracePacketResponse)(nil), // 52: daemon.TracePacketResponse - (*SubscribeRequest)(nil), // 53: daemon.SubscribeRequest - (*SystemEvent)(nil), // 54: daemon.SystemEvent - (*GetEventsRequest)(nil), // 55: daemon.GetEventsRequest - (*GetEventsResponse)(nil), // 56: daemon.GetEventsResponse - (*SwitchProfileRequest)(nil), // 57: daemon.SwitchProfileRequest - (*SwitchProfileResponse)(nil), // 58: daemon.SwitchProfileResponse - (*SetConfigRequest)(nil), // 59: daemon.SetConfigRequest - (*SetConfigResponse)(nil), // 60: daemon.SetConfigResponse - (*AddProfileRequest)(nil), // 61: daemon.AddProfileRequest - (*AddProfileResponse)(nil), // 62: daemon.AddProfileResponse - (*RemoveProfileRequest)(nil), // 63: daemon.RemoveProfileRequest - (*RemoveProfileResponse)(nil), // 64: daemon.RemoveProfileResponse - (*ListProfilesRequest)(nil), // 65: daemon.ListProfilesRequest - (*ListProfilesResponse)(nil), // 66: daemon.ListProfilesResponse - (*Profile)(nil), // 67: daemon.Profile - (*GetActiveProfileRequest)(nil), // 68: daemon.GetActiveProfileRequest - (*GetActiveProfileResponse)(nil), // 69: daemon.GetActiveProfileResponse - (*LogoutRequest)(nil), // 70: daemon.LogoutRequest - (*LogoutResponse)(nil), // 71: daemon.LogoutResponse - (*GetFeaturesRequest)(nil), // 72: daemon.GetFeaturesRequest - (*GetFeaturesResponse)(nil), // 73: daemon.GetFeaturesResponse - (*GetPeerSSHHostKeyRequest)(nil), // 74: daemon.GetPeerSSHHostKeyRequest - (*GetPeerSSHHostKeyResponse)(nil), // 75: daemon.GetPeerSSHHostKeyResponse - (*RequestJWTAuthRequest)(nil), // 76: daemon.RequestJWTAuthRequest - (*RequestJWTAuthResponse)(nil), // 77: daemon.RequestJWTAuthResponse - (*WaitJWTTokenRequest)(nil), // 78: daemon.WaitJWTTokenRequest - (*WaitJWTTokenResponse)(nil), // 79: daemon.WaitJWTTokenResponse - nil, // 80: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 81: daemon.PortInfo.Range - nil, // 82: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 83: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 84: google.protobuf.Timestamp + (OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType + (SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity + (SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category + (*EmptyRequest)(nil), // 4: daemon.EmptyRequest + (*OSLifecycleRequest)(nil), // 5: daemon.OSLifecycleRequest + (*OSLifecycleResponse)(nil), // 6: daemon.OSLifecycleResponse + (*LoginRequest)(nil), // 7: daemon.LoginRequest + (*LoginResponse)(nil), // 8: daemon.LoginResponse + (*WaitSSOLoginRequest)(nil), // 9: daemon.WaitSSOLoginRequest + (*WaitSSOLoginResponse)(nil), // 10: daemon.WaitSSOLoginResponse + (*UpRequest)(nil), // 11: daemon.UpRequest + (*UpResponse)(nil), // 12: daemon.UpResponse + (*StatusRequest)(nil), // 13: daemon.StatusRequest + (*StatusResponse)(nil), // 14: daemon.StatusResponse + (*DownRequest)(nil), // 15: daemon.DownRequest + (*DownResponse)(nil), // 16: daemon.DownResponse + (*GetConfigRequest)(nil), // 17: daemon.GetConfigRequest + (*GetConfigResponse)(nil), // 18: daemon.GetConfigResponse + (*PeerState)(nil), // 19: daemon.PeerState + (*LocalPeerState)(nil), // 20: daemon.LocalPeerState + (*SignalState)(nil), // 21: daemon.SignalState + (*ManagementState)(nil), // 22: daemon.ManagementState + (*RelayState)(nil), // 23: daemon.RelayState + (*NSGroupState)(nil), // 24: daemon.NSGroupState + (*SSHSessionInfo)(nil), // 25: daemon.SSHSessionInfo + (*SSHServerState)(nil), // 26: daemon.SSHServerState + (*FullStatus)(nil), // 27: daemon.FullStatus + (*ListNetworksRequest)(nil), // 28: daemon.ListNetworksRequest + (*ListNetworksResponse)(nil), // 29: daemon.ListNetworksResponse + (*SelectNetworksRequest)(nil), // 30: daemon.SelectNetworksRequest + (*SelectNetworksResponse)(nil), // 31: daemon.SelectNetworksResponse + (*IPList)(nil), // 32: daemon.IPList + (*Network)(nil), // 33: daemon.Network + (*PortInfo)(nil), // 34: daemon.PortInfo + (*ForwardingRule)(nil), // 35: daemon.ForwardingRule + (*ForwardingRulesResponse)(nil), // 36: daemon.ForwardingRulesResponse + (*DebugBundleRequest)(nil), // 37: daemon.DebugBundleRequest + (*DebugBundleResponse)(nil), // 38: daemon.DebugBundleResponse + (*GetLogLevelRequest)(nil), // 39: daemon.GetLogLevelRequest + (*GetLogLevelResponse)(nil), // 40: daemon.GetLogLevelResponse + (*SetLogLevelRequest)(nil), // 41: daemon.SetLogLevelRequest + (*SetLogLevelResponse)(nil), // 42: daemon.SetLogLevelResponse + (*State)(nil), // 43: daemon.State + (*ListStatesRequest)(nil), // 44: daemon.ListStatesRequest + (*ListStatesResponse)(nil), // 45: daemon.ListStatesResponse + (*CleanStateRequest)(nil), // 46: daemon.CleanStateRequest + (*CleanStateResponse)(nil), // 47: daemon.CleanStateResponse + (*DeleteStateRequest)(nil), // 48: daemon.DeleteStateRequest + (*DeleteStateResponse)(nil), // 49: daemon.DeleteStateResponse + (*SetSyncResponsePersistenceRequest)(nil), // 50: daemon.SetSyncResponsePersistenceRequest + (*SetSyncResponsePersistenceResponse)(nil), // 51: daemon.SetSyncResponsePersistenceResponse + (*TCPFlags)(nil), // 52: daemon.TCPFlags + (*TracePacketRequest)(nil), // 53: daemon.TracePacketRequest + (*TraceStage)(nil), // 54: daemon.TraceStage + (*TracePacketResponse)(nil), // 55: daemon.TracePacketResponse + (*SubscribeRequest)(nil), // 56: daemon.SubscribeRequest + (*SystemEvent)(nil), // 57: daemon.SystemEvent + (*GetEventsRequest)(nil), // 58: daemon.GetEventsRequest + (*GetEventsResponse)(nil), // 59: daemon.GetEventsResponse + (*SwitchProfileRequest)(nil), // 60: daemon.SwitchProfileRequest + (*SwitchProfileResponse)(nil), // 61: daemon.SwitchProfileResponse + (*SetConfigRequest)(nil), // 62: daemon.SetConfigRequest + (*SetConfigResponse)(nil), // 63: daemon.SetConfigResponse + (*AddProfileRequest)(nil), // 64: daemon.AddProfileRequest + (*AddProfileResponse)(nil), // 65: daemon.AddProfileResponse + (*RemoveProfileRequest)(nil), // 66: daemon.RemoveProfileRequest + (*RemoveProfileResponse)(nil), // 67: daemon.RemoveProfileResponse + (*ListProfilesRequest)(nil), // 68: daemon.ListProfilesRequest + (*ListProfilesResponse)(nil), // 69: daemon.ListProfilesResponse + (*Profile)(nil), // 70: daemon.Profile + (*GetActiveProfileRequest)(nil), // 71: daemon.GetActiveProfileRequest + (*GetActiveProfileResponse)(nil), // 72: daemon.GetActiveProfileResponse + (*LogoutRequest)(nil), // 73: daemon.LogoutRequest + (*LogoutResponse)(nil), // 74: daemon.LogoutResponse + (*GetFeaturesRequest)(nil), // 75: daemon.GetFeaturesRequest + (*GetFeaturesResponse)(nil), // 76: daemon.GetFeaturesResponse + (*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest + (*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse + (*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest + (*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse + (*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest + (*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse + nil, // 83: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 84: daemon.PortInfo.Range + nil, // 85: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 86: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 87: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 83, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 24, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 84, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 84, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 83, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration - 22, // 5: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo - 19, // 6: daemon.FullStatus.managementState:type_name -> daemon.ManagementState - 18, // 7: daemon.FullStatus.signalState:type_name -> daemon.SignalState - 17, // 8: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState - 16, // 9: daemon.FullStatus.peers:type_name -> daemon.PeerState - 20, // 10: daemon.FullStatus.relays:type_name -> daemon.RelayState - 21, // 11: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 54, // 12: daemon.FullStatus.events:type_name -> daemon.SystemEvent - 23, // 13: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState - 30, // 14: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 80, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 81, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range - 31, // 17: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo - 31, // 18: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo - 32, // 19: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule - 0, // 20: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel - 0, // 21: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 40, // 22: daemon.ListStatesResponse.states:type_name -> daemon.State - 49, // 23: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags - 51, // 24: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage - 1, // 25: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity - 2, // 26: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 84, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 82, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry - 54, // 29: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 83, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 67, // 31: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile - 29, // 32: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 4, // 33: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 6, // 34: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 8, // 35: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 10, // 36: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 12, // 37: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 14, // 38: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 25, // 39: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 27, // 40: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 27, // 41: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 3, // 42: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest - 34, // 43: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 36, // 44: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 38, // 45: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 41, // 46: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 43, // 47: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 45, // 48: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 47, // 49: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest - 50, // 50: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest - 53, // 51: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest - 55, // 52: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest - 57, // 53: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest - 59, // 54: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest - 61, // 55: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest - 63, // 56: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest - 65, // 57: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest - 68, // 58: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest - 70, // 59: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest - 72, // 60: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest - 74, // 61: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest - 76, // 62: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest - 78, // 63: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest - 5, // 64: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 7, // 65: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 9, // 66: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 11, // 67: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 13, // 68: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 15, // 69: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 26, // 70: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 28, // 71: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 28, // 72: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 33, // 73: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 35, // 74: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 37, // 75: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 39, // 76: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 42, // 77: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 44, // 78: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 46, // 79: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 48, // 80: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 52, // 81: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 54, // 82: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 56, // 83: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 58, // 84: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 60, // 85: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 62, // 86: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 64, // 87: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 66, // 88: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 69, // 89: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 71, // 90: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 73, // 91: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 75, // 92: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 77, // 93: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 79, // 94: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 64, // [64:95] is the sub-list for method output_type - 33, // [33:64] is the sub-list for method input_type - 33, // [33:33] is the sub-list for extension type_name - 33, // [33:33] is the sub-list for extension extendee - 0, // [0:33] is the sub-list for field type_name + 1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType + 86, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus + 87, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 87, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 86, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo + 22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState + 21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState + 20, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState + 19, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState + 23, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState + 24, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent + 26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState + 33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 83, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 84, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo + 34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo + 35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule + 0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel + 0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel + 43, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State + 52, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags + 54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage + 2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity + 3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category + 87, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 85, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent + 86, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile + 32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 9, // 35: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 11, // 36: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 13, // 37: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 15, // 38: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 17, // 39: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 28, // 40: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 30, // 41: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 30, // 42: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 4, // 43: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest + 37, // 44: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 39, // 45: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 41, // 46: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 44, // 47: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 46, // 48: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 48, // 49: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 50, // 50: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest + 53, // 51: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 56, // 52: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest + 58, // 53: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest + 60, // 54: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest + 62, // 55: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest + 64, // 56: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest + 66, // 57: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest + 68, // 58: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest + 71, // 59: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest + 73, // 60: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest + 75, // 61: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest + 77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest + 79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest + 81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest + 5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest + 8, // 66: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 10, // 67: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 12, // 68: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 14, // 69: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 16, // 70: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 18, // 71: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 29, // 72: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 31, // 73: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 31, // 74: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 36, // 75: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 38, // 76: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 40, // 77: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 42, // 78: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 45, // 79: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 47, // 80: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 49, // 81: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 51, // 82: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 55, // 83: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 57, // 84: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 59, // 85: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 61, // 86: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 63, // 87: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 65, // 88: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 67, // 89: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 69, // 90: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 72, // 91: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 74, // 92: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 76, // 93: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 78, // 94: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 80, // 95: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 82, // 96: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 6, // 97: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse + 66, // [66:98] is the sub-list for method output_type + 34, // [34:66] is the sub-list for method input_type + 34, // [34:34] is the sub-list for extension type_name + 34, // [34:34] is the sub-list for extension extendee + 0, // [0:34] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -6010,26 +6155,26 @@ func file_daemon_proto_init() { if File_daemon_proto != nil { return } - file_daemon_proto_msgTypes[1].OneofWrappers = []any{} - file_daemon_proto_msgTypes[5].OneofWrappers = []any{} + file_daemon_proto_msgTypes[3].OneofWrappers = []any{} file_daemon_proto_msgTypes[7].OneofWrappers = []any{} - file_daemon_proto_msgTypes[28].OneofWrappers = []any{ + file_daemon_proto_msgTypes[9].OneofWrappers = []any{} + file_daemon_proto_msgTypes[30].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), } - file_daemon_proto_msgTypes[47].OneofWrappers = []any{} - file_daemon_proto_msgTypes[48].OneofWrappers = []any{} - file_daemon_proto_msgTypes[54].OneofWrappers = []any{} + file_daemon_proto_msgTypes[49].OneofWrappers = []any{} + file_daemon_proto_msgTypes[50].OneofWrappers = []any{} file_daemon_proto_msgTypes[56].OneofWrappers = []any{} - file_daemon_proto_msgTypes[67].OneofWrappers = []any{} - file_daemon_proto_msgTypes[73].OneofWrappers = []any{} + file_daemon_proto_msgTypes[58].OneofWrappers = []any{} + file_daemon_proto_msgTypes[69].OneofWrappers = []any{} + file_daemon_proto_msgTypes[75].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), - NumEnums: 3, - NumMessages: 80, + NumEnums: 4, + NumMessages: 82, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index bf8553706..3dfd3da8d 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -24,7 +24,7 @@ service DaemonService { // Status of the service. rpc Status(StatusRequest) returns (StatusResponse) {} - // Down engine work in the daemon. + // Down stops engine work in the daemon. rpc Down(DownRequest) returns (DownResponse) {} // GetConfig of the daemon. @@ -93,9 +93,26 @@ service DaemonService { // WaitJWTToken waits for JWT authentication completion rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {} + + rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {} } + +message OSLifecycleRequest { + // avoid collision with loglevel enum + enum CycleType { + UNKNOWN = 0; + SLEEP = 1; + WAKEUP = 2; + } + + CycleType type = 1; +} + +message OSLifecycleResponse {} + + message LoginRequest { // setupKey netbird setup key. string setupKey = 1; diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index b2bf716b2..6b01309b7 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -27,7 +27,7 @@ type DaemonServiceClient interface { Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error) // Status of the service. Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error) - // Down engine work in the daemon. + // Down stops engine work in the daemon. Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) // GetConfig of the daemon. GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) @@ -70,6 +70,7 @@ type DaemonServiceClient interface { RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) // WaitJWTToken waits for JWT authentication completion WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) + NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) } type daemonServiceClient struct { @@ -382,6 +383,15 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken return out, nil } +func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) { + out := new(OSLifecycleResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -395,7 +405,7 @@ type DaemonServiceServer interface { Up(context.Context, *UpRequest) (*UpResponse, error) // Status of the service. Status(context.Context, *StatusRequest) (*StatusResponse, error) - // Down engine work in the daemon. + // Down stops engine work in the daemon. Down(context.Context, *DownRequest) (*DownResponse, error) // GetConfig of the daemon. GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) @@ -438,6 +448,7 @@ type DaemonServiceServer interface { RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) // WaitJWTToken waits for JWT authentication completion WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) + NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -538,6 +549,9 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented") } +func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -1112,6 +1126,24 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d return interceptor(ctx, in, info, handler) } +func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(OSLifecycleRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/NotifyOSLifecycle", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, req.(*OSLifecycleRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -1239,6 +1271,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "WaitJWTToken", Handler: _DaemonService_WaitJWTToken_Handler, }, + { + MethodName: "NotifyOSLifecycle", + Handler: _DaemonService_NotifyOSLifecycle_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/client/server/lifecycle.go b/client/server/lifecycle.go new file mode 100644 index 000000000..3722c027d --- /dev/null +++ b/client/server/lifecycle.go @@ -0,0 +1,77 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/proto" +) + +// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type. +func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) { + switch req.GetType() { + case proto.OSLifecycleRequest_WAKEUP: + return s.handleWakeUp(callerCtx) + case proto.OSLifecycleRequest_SLEEP: + return s.handleSleep(callerCtx) + default: + log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType()) + } + return &proto.OSLifecycleResponse{}, nil +} + +// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep. +// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails. +func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) { + if !s.sleepTriggeredDown.Load() { + log.Info("skipping up because wasn't sleep down") + return &proto.OSLifecycleResponse{}, nil + } + + // avoid other wakeup runs if sleep didn't make the computer sleep + s.sleepTriggeredDown.Store(false) + + log.Info("running up after wake up") + _, err := s.Up(callerCtx, &proto.UpRequest{}) + if err != nil { + log.Errorf("running up failed: %v", err) + return &proto.OSLifecycleResponse{}, err + } + + log.Info("running up command executed successfully") + return &proto.OSLifecycleResponse{}, nil +} + +// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state. +func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) { + s.mutex.Lock() + + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + s.mutex.Unlock() + return &proto.OSLifecycleResponse{}, err + } + + if status != internal.StatusConnecting && status != internal.StatusConnected { + log.Infof("skipping setting the agent down because status is %s", status) + s.mutex.Unlock() + return &proto.OSLifecycleResponse{}, nil + } + s.mutex.Unlock() + + log.Info("running down after system started sleeping") + + _, err = s.Down(callerCtx, &proto.DownRequest{}) + if err != nil { + log.Errorf("running down failed: %v", err) + return &proto.OSLifecycleResponse{}, err + } + + s.sleepTriggeredDown.Store(true) + + log.Info("running down executed successfully") + return &proto.OSLifecycleResponse{}, nil +} diff --git a/client/server/lifecycle_test.go b/client/server/lifecycle_test.go new file mode 100644 index 000000000..a604c60af --- /dev/null +++ b/client/server/lifecycle_test.go @@ -0,0 +1,219 @@ +package server + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/proto" +) + +func newTestServer() *Server { + ctx := internal.CtxInitState(context.Background()) + return &Server{ + rootCtx: ctx, + statusRecorder: peer.NewRecorder(""), + } +} + +func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) { + s := newTestServer() + + // sleepTriggeredDown is false by default + assert.False(t, s.sleepTriggeredDown.Load()) + + resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false") +} + +func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) { + s := newTestServer() + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusIdle) + + resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_SLEEP, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle") +} + +func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) { + s := newTestServer() + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusNeedsLogin) + + resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_SLEEP, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin") +} + +func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) { + s := newTestServer() + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusConnecting) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.actCancel = cancel + + resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_SLEEP, + }) + + require.NoError(t, err) + assert.NotNil(t, resp, "handleSleep returns not nil response on success") + assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting") +} + +func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) { + s := newTestServer() + + state := internal.CtxGetState(s.rootCtx) + state.Set(internal.StatusConnected) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.actCancel = cancel + + resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_SLEEP, + }) + + require.NoError(t, err) + assert.NotNil(t, resp, "handleSleep returns not nil response on success") + assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected") +} + +func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) { + s := newTestServer() + + // Manually set the flag to simulate prior sleep down + s.sleepTriggeredDown.Store(true) + + // WakeUp will try to call Up which fails without proper setup, but flag should reset first + _, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + + assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt") +} + +func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) { + s := newTestServer() + + // First wakeup without prior sleep - should be no-op + resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load()) + + // Simulate prior sleep + s.sleepTriggeredDown.Store(true) + + // First wakeup after sleep - should reset flag + _, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + assert.False(t, s.sleepTriggeredDown.Load()) + + // Second wakeup - should be no-op + resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{ + Type: proto.OSLifecycleRequest_WAKEUP, + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load()) +} + +func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) { + s := newTestServer() + + resp, err := s.handleWakeUp(context.Background()) + + require.NoError(t, err) + require.NotNil(t, resp) +} + +func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) { + s := newTestServer() + s.sleepTriggeredDown.Store(true) + + // Even if Up fails, flag should be reset + _, _ = s.handleWakeUp(context.Background()) + + assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up") +} + +func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) { + tests := []struct { + name string + status internal.StatusType + }{ + {"Idle", internal.StatusIdle}, + {"NeedsLogin", internal.StatusNeedsLogin}, + {"LoginFailed", internal.StatusLoginFailed}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := newTestServer() + state := internal.CtxGetState(s.rootCtx) + state.Set(tt.status) + + resp, err := s.handleSleep(context.Background()) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, s.sleepTriggeredDown.Load()) + }) + } +} + +func TestHandleSleep_ProceedsForActiveStates(t *testing.T) { + tests := []struct { + name string + status internal.StatusType + }{ + {"Connecting", internal.StatusConnecting}, + {"Connected", internal.StatusConnected}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := newTestServer() + state := internal.CtxGetState(s.rootCtx) + state.Set(tt.status) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.actCancel = cancel + + resp, err := s.handleSleep(ctx) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.True(t, s.sleepTriggeredDown.Load()) + }) + } +} diff --git a/client/server/server.go b/client/server/server.go index 49000c092..cf79ec5ed 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -85,6 +85,9 @@ type Server struct { profilesDisabled bool updateSettingsDisabled bool + // sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down + sleepTriggeredDown atomic.Bool + jwtCache *jwtCache } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 57d0e74a0..8f99608e7 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -1196,25 +1196,27 @@ func (s *serviceClient) handleSleepEvents(event sleep.EventType) { return } + req := &proto.OSLifecycleRequest{} + switch event { case sleep.EventTypeWakeUp: log.Infof("handle wakeup event: %v", event) - _, err = conn.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("up service: %v", err) - return - } - return + req.Type = proto.OSLifecycleRequest_WAKEUP case sleep.EventTypeSleep: log.Infof("handle sleep event: %v", event) - _, err = conn.Down(s.ctx, &proto.DownRequest{}) - if err != nil { - log.Errorf("down service: %v", err) - return - } + req.Type = proto.OSLifecycleRequest_SLEEP + default: + log.Infof("unknown event: %v", event) + return } - log.Info("successfully notified daemon about sleep/wakeup event") + _, err = conn.NotifyOSLifecycle(s.ctx, req) + if err != nil { + log.Errorf("failed to notify daemon about os lifecycle notification: %v", err) + return + } + + log.Info("successfully notified daemon about os lifecycle") } // setSettingsEnabled enables or disables the settings menu based on the provided state From 27dd97c9c40ae889fd1fc936ddc7cdb67d290300 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 3 Dec 2025 14:45:59 +0300 Subject: [PATCH 105/120] [management] Add support to disable geolocation service (#4901) --- management/internals/server/modules.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 91ce50a79..af9ca5f2d 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -2,6 +2,7 @@ package server import ( "context" + "os" log "github.com/sirupsen/logrus" @@ -21,7 +22,16 @@ import ( "github.com/netbirdio/netbird/management/server/users" ) +const ( + geolocationDisabledKey = "NB_DISABLE_GEOLOCATION" +) + func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { + if os.Getenv(geolocationDisabledKey) == "true" { + log.Info("geolocation service is disabled, skipping initialization") + return nil + } + return Create(s, func() geolocation.Geolocation { geo, err := geolocation.NewGeolocation(context.Background(), s.Config.Datadir, !s.disableGeoliteUpdate) if err != nil { From d2e48d4f5e3dba008393b3cee233ed9877ee1f9a Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 3 Dec 2025 18:42:53 +0100 Subject: [PATCH 106/120] [relay] Use instanceURL instead of Exposed address. (#4905) Replaces string-based exposed address handling with URL-based InstanceURL() (type url.URL) across relay/server and relay/healthcheck; adds SchemeREL/SchemeRELS constants; updates getInstanceURL to return *url.URL with scheme and TLS validation; adjusts WS dialing and health-check logic to use URL fields. --- relay/cmd/root.go | 3 ++- relay/healthcheck/healthcheck.go | 18 +++++++---------- relay/healthcheck/ws.go | 12 +++++------ relay/server/relay.go | 16 ++++++--------- relay/server/server.go | 12 ++++------- relay/server/url.go | 22 +++++++++++++++------ relay/server/{relay_test.go => url_test.go} | 9 ++++++--- 7 files changed, 47 insertions(+), 45 deletions(-) rename relay/server/{relay_test.go => url_test.go} (78%) diff --git a/relay/cmd/root.go b/relay/cmd/root.go index eb2cdebf8..e7dadcfdf 100644 --- a/relay/cmd/root.go +++ b/relay/cmd/root.go @@ -160,7 +160,8 @@ func execute(cmd *cobra.Command, args []string) error { log.Debugf("failed to create relay server: %v", err) return fmt.Errorf("failed to create relay server: %v", err) } - log.Infof("server will be available on: %s", srv.InstanceURL()) + instanceURL := srv.InstanceURL() + log.Infof("server will be available on: %s", instanceURL.String()) wg.Add(1) go func() { defer wg.Done() diff --git a/relay/healthcheck/healthcheck.go b/relay/healthcheck/healthcheck.go index 6463843eb..b54d4b33b 100644 --- a/relay/healthcheck/healthcheck.go +++ b/relay/healthcheck/healthcheck.go @@ -6,13 +6,14 @@ import ( "errors" "net" "net/http" - "strings" + "net/url" "sync" "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/protocol" + "github.com/netbirdio/netbird/relay/server" ) const ( @@ -26,7 +27,7 @@ const ( type ServiceChecker interface { ListenerProtocols() []protocol.Protocol - ExposedAddress() string + InstanceURL() url.URL } type HealthStatus struct { @@ -134,7 +135,7 @@ func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) { } status.Listeners = listeners - if !strings.HasPrefix(s.config.ServiceChecker.ExposedAddress(), "rels") { + if s.config.ServiceChecker.InstanceURL().Scheme != server.SchemeRELS { status.CertificateValid = false } @@ -156,14 +157,9 @@ func (s *Server) validateListeners() ([]protocol.Protocol, bool) { } func (s *Server) validateConnection(ctx context.Context) bool { - exposedAddress := s.config.ServiceChecker.ExposedAddress() - if exposedAddress == "" { - log.Error("exposed address is empty, cannot validate certificate") - return false - } - - if err := dialWS(ctx, exposedAddress); err != nil { - log.Errorf("failed to dial WebSocket listener at %s: %v", exposedAddress, err) + addr := s.config.ServiceChecker.InstanceURL() + if err := dialWS(ctx, addr); err != nil { + log.Errorf("failed to dial WebSocket listener at %s: %v", addr.String(), err) return false } diff --git a/relay/healthcheck/ws.go b/relay/healthcheck/ws.go index badd31219..db61ed802 100644 --- a/relay/healthcheck/ws.go +++ b/relay/healthcheck/ws.go @@ -3,22 +3,22 @@ package healthcheck import ( "context" "fmt" - "strings" + "net/url" "github.com/coder/websocket" + "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/shared/relay" ) -func dialWS(ctx context.Context, address string) error { - addressSplit := strings.Split(address, "/") +func dialWS(ctx context.Context, address url.URL) error { scheme := "ws" - if addressSplit[0] == "rels:" { + if address.Scheme == server.SchemeRELS { scheme = "wss" } - url := fmt.Sprintf("%s://%s%s", scheme, addressSplit[2], relay.WebSocketURLPath) + wsURL := fmt.Sprintf("%s://%s%s", scheme, address.Host, relay.WebSocketURLPath) - conn, resp, err := websocket.Dial(ctx, url, nil) + conn, resp, err := websocket.Dial(ctx, wsURL, nil) if resp != nil { defer func() { if resp.Body != nil { diff --git a/relay/server/relay.go b/relay/server/relay.go index aab575bf0..c1cfa13fd 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/url" "sync" "time" @@ -22,7 +23,7 @@ type Config struct { TLSSupport bool AuthValidator Validator - instanceURL string + instanceURL url.URL } func (c *Config) validate() error { @@ -37,7 +38,7 @@ func (c *Config) validate() error { if err != nil { return fmt.Errorf("invalid url: %v", err) } - c.instanceURL = instanceURL + c.instanceURL = *instanceURL if c.AuthValidator == nil { return fmt.Errorf("auth validator is required") @@ -53,7 +54,7 @@ type Relay struct { store *store.Store notifier *store.PeerNotifier - instanceURL string + instanceURL url.URL exposedAddress string preparedMsg *preparedMsg @@ -97,7 +98,7 @@ func NewRelay(config Config) (*Relay, error) { notifier: store.NewPeerNotifier(), } - r.preparedMsg, err = newPreparedMsg(r.instanceURL) + r.preparedMsg, err = newPreparedMsg(r.instanceURL.String()) if err != nil { metricsCancel() return nil, fmt.Errorf("prepare message: %v", err) @@ -177,11 +178,6 @@ func (r *Relay) Shutdown(ctx context.Context) { } // InstanceURL returns the instance URL of the relay server -func (r *Relay) InstanceURL() string { +func (r *Relay) InstanceURL() url.URL { return r.instanceURL } - -// ExposedAddress returns the exposed address (domain:port) where clients connect -func (r *Relay) ExposedAddress() string { - return r.exposedAddress -} diff --git a/relay/server/server.go b/relay/server/server.go index 2c9e658d6..8e4333064 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "crypto/tls" + "net/url" "sync" "github.com/hashicorp/go-multierror" @@ -39,7 +40,7 @@ type Server struct { // // config: A Config struct containing the necessary configuration: // - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used. -// - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required. +// - InstanceURL: The public address (in domain:port format) used as the server's instance URL. Required. // - TLSSupport: A boolean indicating whether TLS is enabled for the server. // - AuthValidator: A Validator used to authenticate peers. Required. // @@ -119,11 +120,6 @@ func (r *Server) Shutdown(ctx context.Context) error { return nberrors.FormatErrorOrNil(multiErr) } -// InstanceURL returns the instance URL of the relay server. -func (r *Server) InstanceURL() string { - return r.relay.instanceURL -} - func (r *Server) ListenerProtocols() []protocol.Protocol { result := make([]protocol.Protocol, 0) @@ -135,6 +131,6 @@ func (r *Server) ListenerProtocols() []protocol.Protocol { return result } -func (r *Server) ExposedAddress() string { - return r.relay.ExposedAddress() +func (r *Server) InstanceURL() url.URL { + return r.relay.InstanceURL() } diff --git a/relay/server/url.go b/relay/server/url.go index 9cbf44642..aeae1c068 100644 --- a/relay/server/url.go +++ b/relay/server/url.go @@ -6,9 +6,14 @@ import ( "strings" ) +const ( + SchemeREL = "rel" + SchemeRELS = "rels" +) + // getInstanceURL checks if user supplied a URL scheme otherwise adds to the // provided address according to TLS definition and parses the address before returning it -func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { +func getInstanceURL(exposedAddress string, tlsSupported bool) (*url.URL, error) { addr := exposedAddress split := strings.Split(exposedAddress, "://") switch { @@ -17,17 +22,22 @@ func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { case len(split) == 1 && !tlsSupported: addr = "rel://" + exposedAddress case len(split) > 2: - return "", fmt.Errorf("invalid exposed address: %s", exposedAddress) + return nil, fmt.Errorf("invalid exposed address: %s", exposedAddress) } parsedURL, err := url.ParseRequestURI(addr) if err != nil { - return "", fmt.Errorf("invalid exposed address: %v", err) + return nil, fmt.Errorf("invalid exposed address: %v", err) } - if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" { - return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme) + if parsedURL.Scheme != SchemeREL && parsedURL.Scheme != SchemeRELS { + return nil, fmt.Errorf("invalid scheme: %s", parsedURL.Scheme) } - return parsedURL.String(), nil + // Validate scheme matches TLS configuration + if tlsSupported && parsedURL.Scheme == SchemeREL { + return nil, fmt.Errorf("non-TLS scheme '%s' provided but TLS is supported", SchemeREL) + } + + return parsedURL, nil } diff --git a/relay/server/relay_test.go b/relay/server/url_test.go similarity index 78% rename from relay/server/relay_test.go rename to relay/server/url_test.go index 062039ab9..ca455f45a 100644 --- a/relay/server/relay_test.go +++ b/relay/server/url_test.go @@ -13,7 +13,7 @@ func TestGetInstanceURL(t *testing.T) { {"Valid address with TLS", "example.com", true, "rels://example.com", false}, {"Valid address without TLS", "example.com", false, "rel://example.com", false}, {"Valid address with scheme", "rel://example.com", false, "rel://example.com", false}, - {"Valid address with non TLS scheme and TLS true", "rel://example.com", true, "rel://example.com", false}, + {"Invalid address with non TLS scheme and TLS true", "rel://example.com", true, "", true}, {"Valid address with TLS scheme", "rels://example.com", true, "rels://example.com", false}, {"Valid address with TLS scheme and TLS false", "rels://example.com", false, "rels://example.com", false}, {"Valid address with TLS scheme and custom port", "rels://example.com:9300", true, "rels://example.com:9300", false}, @@ -28,8 +28,11 @@ func TestGetInstanceURL(t *testing.T) { if (err != nil) != tt.expectError { t.Errorf("expected error: %v, got: %v", tt.expectError, err) } - if url != tt.expectedURL { - t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url) + if !tt.expectError && url != nil && url.String() != tt.expectedURL { + t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url.String()) + } + if tt.expectError && url != nil { + t.Errorf("expected nil URL on error, got: %s", url.String()) } }) } From 031ab1117800fec893f4c433cbd1c3999882e6b4 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Thu, 4 Dec 2025 16:57:29 +0300 Subject: [PATCH 107/120] [client] Remove select account prompt (#4912) Signed-off-by: bcmmbaga --- client/internal/auth/pkce_flow.go | 3 +-- client/internal/auth/pkce_flow_test.go | 13 ++++++------- shared/management/client/common/types.go | 8 ++++---- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index c03376f5b..cc43c8648 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -107,10 +107,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn if !p.providerConfig.DisablePromptLogin { switch p.providerConfig.LoginFlag { case common.LoginFlagPromptLogin: - params = append(params, oauth2.SetAuthURLParam("prompt", "login select_account")) + params = append(params, oauth2.SetAuthURLParam("prompt", "login")) case common.LoginFlagMaxAge0: params = append(params, oauth2.SetAuthURLParam("max_age", "0")) - params = append(params, oauth2.SetAuthURLParam("prompt", "select_account")) } } if p.providerConfig.LoginHint != "" { diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go index b5843f104..b77a17eaa 100644 --- a/client/internal/auth/pkce_flow_test.go +++ b/client/internal/auth/pkce_flow_test.go @@ -15,9 +15,8 @@ import ( func TestPromptLogin(t *testing.T) { const ( - promptSelectAccountLogin = "prompt=login+select_account" - promptSelectAccount = "prompt=select_account" - maxAge0 = "max_age=0" + promptLogin = "prompt=login" + maxAge0 = "max_age=0" ) tt := []struct { @@ -27,14 +26,14 @@ func TestPromptLogin(t *testing.T) { expectContains []string }{ { - name: "Prompt login with select account", + name: "Prompt login", loginFlag: mgm.LoginFlagPromptLogin, - expectContains: []string{promptSelectAccountLogin}, + expectContains: []string{promptLogin}, }, { - name: "Max age 0 with select account", + name: "Max age 0", loginFlag: mgm.LoginFlagMaxAge0, - expectContains: []string{maxAge0, promptSelectAccount}, + expectContains: []string{maxAge0}, }, { name: "Disable prompt login", diff --git a/shared/management/client/common/types.go b/shared/management/client/common/types.go index 550bcde30..451578358 100644 --- a/shared/management/client/common/types.go +++ b/shared/management/client/common/types.go @@ -6,14 +6,14 @@ package common // // | Value | Flag | OAuth Parameters | // |-------|----------------------|-----------------------------------------| -// | 0 | LoginFlagPromptLogin | prompt=select_account login | -// | 1 | LoginFlagMaxAge0 | max_age=0 & prompt=select_account | +// | 0 | LoginFlagPromptLogin | prompt=login | +// | 1 | LoginFlagMaxAge0 | max_age=0 | type LoginFlag uint8 const ( - // LoginFlagPromptLogin adds prompt=select_account login to the authorization request + // LoginFlagPromptLogin adds prompt=login to the authorization request LoginFlagPromptLogin LoginFlag = iota - // LoginFlagMaxAge0 adds max_age=0 and prompt=select_account to the authorization request + // LoginFlagMaxAge0 adds max_age=0 to the authorization request LoginFlagMaxAge0 // LoginFlagNone disables all login flags LoginFlagNone From 9bdc4908fb7a32001cebaa03ac754f3aa82a5e4c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 4 Dec 2025 19:16:38 +0100 Subject: [PATCH 108/120] [client] Passthrough all non-NetBird chains to prevent them from dropping NetBird traffic (#4899) --- client/firewall/nftables/router_linux.go | 279 ++++++++++++++++++----- 1 file changed, 225 insertions(+), 54 deletions(-) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index e4debc179..7f95992da 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -27,7 +27,11 @@ import ( ) const ( - tableNat = "nat" + tableNat = "nat" + tableMangle = "mangle" + tableRaw = "raw" + tableSecurity = "security" + chainNameNatPrerouting = "PREROUTING" chainNameRoutingFw = "netbird-rt-fwd" chainNameRoutingNat = "netbird-rt-postrouting" @@ -91,7 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou var err error r.filterTable, err = r.loadFilterTable() if err != nil { - log.Warnf("failed to load filter table, skipping accept rules: %v", err) + log.Debugf("ip filter table not found: %v", err) } return r, nil @@ -183,6 +187,33 @@ func (r *router) loadFilterTable() (*nftables.Table, error) { return nil, errFilterTableNotFound } +func hookName(hook *nftables.ChainHook) string { + if hook == nil { + return "unknown" + } + switch *hook { + case *nftables.ChainHookForward: + return chainNameForward + case *nftables.ChainHookInput: + return chainNameInput + default: + return fmt.Sprintf("hook(%d)", *hook) + } +} + +func familyName(family nftables.TableFamily) string { + switch family { + case nftables.TableFamilyIPv4: + return "ip" + case nftables.TableFamilyIPv6: + return "ip6" + case nftables.TableFamilyINet: + return "inet" + default: + return fmt.Sprintf("family(%d)", family) + } +} + func (r *router) createContainers() error { r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingFw, @@ -930,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error { // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. // This method also adds INPUT chain rules to allow traffic to the local interface. func (r *router) acceptForwardRules() error { + var merr *multierror.Error + + if err := r.acceptFilterTableRules(); err != nil { + merr = multierror.Append(merr, err) + } + + if err := r.acceptExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) acceptFilterTableRules() error { if r.filterTable == nil { - log.Debugf("table 'filter' not found for forward rules, skipping accept rules") return nil } @@ -944,11 +988,11 @@ func (r *router) acceptForwardRules() error { // Try iptables first and fallback to nftables if iptables is not available ipt, err := iptables.New() if err != nil { - // filter table exists but iptables is not + // iptables is not available but the filter table exists log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" - return r.acceptFilterRulesNftables() + return r.acceptFilterRulesNftables(r.filterTable) } return r.acceptFilterRulesIptables(ipt) @@ -959,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { for _, rule := range r.getAcceptForwardRules() { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err)) } else { log.Debugf("added iptables forward rule: %v", rule) } @@ -967,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { inputRule := r.getAcceptInputRule() if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err)) } else { log.Debugf("added iptables input rule: %v", inputRule) } @@ -987,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string { return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"} } -func (r *router) acceptFilterRulesNftables() error { +// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables. +// This is used when iptables is not available. +func (r *router) acceptFilterRulesNftables(table *nftables.Table) error { intf := ifname(r.wgIface.Name()) + forwardChain := &nftables.Chain{ + Name: chainNameForward, + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + } + r.insertForwardAcceptRules(forwardChain, intf) + + inputChain := &nftables.Chain{ + Name: chainNameInput, + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + } + r.insertInputAcceptRule(inputChain, intf) + + return r.conn.Flush() +} + +// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables). +// It dynamically finds chains at call time to handle chains that may have been created after startup. +func (r *router) acceptExternalChainsRules() error { + chains := r.findExternalChains() + if len(chains) == 0 { + return nil + } + + intf := ifname(r.wgIface.Name()) + + for _, chain := range chains { + if chain.Hooknum == nil { + log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name) + continue + } + + log.Debugf("adding accept rules to external %s chain: %s %s/%s", + hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name) + + switch *chain.Hooknum { + case *nftables.ChainHookForward: + r.insertForwardAcceptRules(chain, intf) + case *nftables.ChainHookInput: + r.insertInputAcceptRule(chain, intf) + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush external chain rules: %w", err) + } + + return nil +} + +func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) { iifRule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: chainNameForward, - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, + Table: chain.Table, + Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -1021,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error { Data: intf, }, } - oifRule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: chainNameForward, - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, + Table: chain.Table, + Chain: chain, Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } r.conn.InsertRule(oifRule) +} +func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) { inputRule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: chainNameInput, - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookInput, - Priority: nftables.ChainPriorityFilter, - }, + Table: chain.Table, + Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -1058,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error { UserData: []byte(userDataAcceptInputRule), } r.conn.InsertRule(inputRule) - - return r.conn.Flush() } func (r *router) removeAcceptFilterRules() error { + var merr *multierror.Error + + if err := r.removeFilterTableRules(); err != nil { + merr = multierror.Append(merr, err) + } + + if err := r.removeExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) removeFilterTableRules() error { if r.filterTable == nil { return nil } ipt, err := iptables.New() if err != nil { - log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) - return r.removeAcceptFilterRulesNftables() + log.Debugf("iptables not available, using nftables to remove filter rules: %v", err) + return r.removeAcceptRulesFromTable(r.filterTable) } return r.removeAcceptFilterRulesIptables(ipt) } -func (r *router) removeAcceptFilterRulesNftables() error { - chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) +func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error { + chains, err := r.conn.ListChainsOfTableFamily(table.Family) if err != nil { return fmt.Errorf("list chains: %v", err) } for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name { + if chain.Table.Name != table.Name { continue } @@ -1091,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error { continue } - rules, err := r.conn.GetRules(r.filterTable, chain) + if err := r.removeAcceptRulesFromChain(table, chain); err != nil { + return err + } + } + + return r.conn.Flush() +} + +func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error { + rules, err := r.conn.GetRules(table, chain) + if err != nil { + return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err) + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err) + } + } + } + return nil +} + +// removeExternalChainsRules removes our accept rules from all external chains. +// This is deterministic - it scans for chains at removal time rather than relying on saved state, +// ensuring cleanup works even after a crash or if chains changed. +func (r *router) removeExternalChainsRules() error { + chains := r.findExternalChains() + if len(chains) == 0 { + return nil + } + + for _, chain := range chains { + if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil { + log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err) + } + } + + return r.conn.Flush() +} + +// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks. +// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal). +func (r *router) findExternalChains() []*nftables.Chain { + var chains []*nftables.Chain + + families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet} + + for _, family := range families { + allChains, err := r.conn.ListChainsOfTableFamily(family) if err != nil { - return fmt.Errorf("get rules: %v", err) + log.Debugf("list chains for family %d: %v", family, err) + continue } - for _, rule := range rules { - if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) { - if err := r.conn.DelRule(rule); err != nil { - return fmt.Errorf("delete rule: %v", err) - } + for _, chain := range allChains { + if r.isExternalChain(chain) { + chains = append(chains, chain) } } } - if err := r.conn.Flush(); err != nil { - return fmt.Errorf(flushError, err) + return chains +} + +func (r *router) isExternalChain(chain *nftables.Chain) bool { + if r.workTable != nil && chain.Table.Name == r.workTable.Name { + return false } - return nil + // Skip all iptables-managed tables in the ip family + if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { + return false + } + + if chain.Type != nftables.ChainTypeFilter { + return false + } + + if chain.Hooknum == nil { + return false + } + + return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput +} + +func isIptablesTable(name string) bool { + switch name { + case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity: + return true + } + return false } func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { @@ -1119,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { for _, rule := range r.getAcceptForwardRules() { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err)) } } inputRule := r.getAcceptInputRule() if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err)) } return nberrors.FormatErrorOrNil(merr) From 71b6855e0920b377edb203999139056833585534 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 4 Dec 2025 19:51:50 +0100 Subject: [PATCH 109/120] [client] Fix engine shutdown deadlock and sync-signal message handling races (#4891) * Fix engine shutdown deadlock and message handling races - Release syncMsgMux before waiting for shutdownWg to prevent deadlock - Check context inside lock in handleSync and receiveSignalEvents - Prevents nil pointer access when messages arrive during engine stop --- client/internal/connect.go | 13 ++++++++----- client/internal/engine.go | 31 +++++++++++++++++++++---------- client/server/server.go | 2 ++ 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index 5a5f4f63c..e9d422a28 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -273,11 +273,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan checks := loginResp.GetChecks() c.engineMutex.Lock() - c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks) - c.engine.SetSyncResponsePersistence(c.persistSyncResponse) + engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks) + engine.SetSyncResponsePersistence(c.persistSyncResponse) + c.engine = engine c.engineMutex.Unlock() - if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { + if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { log.Errorf("error while starting Netbird Connection Engine: %s", err) return wrapErr(err) } @@ -293,12 +294,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan <-engineCtx.Done() c.engineMutex.Lock() - engine := c.engine c.engine = nil c.engineMutex.Unlock() - if engine != nil && engine.wgInterface != nil { + // todo: consider to remove this condition. Is not thread safe. + // We should always call Stop(), but we need to verify that it is idempotent + if engine.wgInterface != nil { log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name()) + if err := engine.Stop(); err != nil { log.Errorf("Failed to stop engine: %v", err) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 84bb8beea..ce46dac8c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -280,7 +280,6 @@ func (e *Engine) Stop() error { return nil } e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() if e.connMgr != nil { e.connMgr.Close() @@ -295,7 +294,7 @@ func (e *Engine) Stop() error { if os.Getenv("NB_REMOVE_BEFORE_DNS") == "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { log.Info("removing peers before dns") if err := e.removeAllPeers(); err != nil { - return fmt.Errorf("failed to remove all peers: %s", err) + log.Errorf("failed to remove all peers: %s", err) } } if err := e.stopSSHServer(); err != nil { @@ -319,7 +318,7 @@ func (e *Engine) Stop() error { if os.Getenv("NB_REMOVE_BEFORE_ROUTES") == "true" && os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" { log.Info("removing peers before routes") if err := e.removeAllPeers(); err != nil { - return fmt.Errorf("failed to remove all peers: %s", err) + log.Errorf("failed to remove all peers: %s", err) } } @@ -338,7 +337,7 @@ func (e *Engine) Stop() error { if os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { log.Info("removing peers after dns and routes") if err := e.removeAllPeers(); err != nil { - return fmt.Errorf("failed to remove all peers: %s", err) + log.Errorf("failed to remove all peers: %s", err) } } @@ -353,16 +352,18 @@ func (e *Engine) Stop() error { e.flowManager.Close() } - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() + stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second) + defer stateCancel() - if err := e.stateManager.Stop(ctx); err != nil { - return fmt.Errorf("failed to stop state manager: %w", err) + if err := e.stateManager.Stop(stateCtx); err != nil { + log.Errorf("failed to stop state manager: %v", err) } if err := e.stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } + e.syncMsgMux.Unlock() + timeout := e.calculateShutdownTimeout() log.Debugf("waiting for goroutines to finish with timeout: %v", timeout) shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) @@ -448,8 +449,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) if err != nil { return fmt.Errorf("create rosenpass manager: %w", err) } - err := e.rpManager.Run() - if err != nil { + if err := e.rpManager.Run(); err != nil { return fmt.Errorf("run rosenpass manager: %w", err) } } @@ -501,6 +501,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) } if err := e.createFirewall(); err != nil { + e.close() return err } @@ -766,6 +767,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() + // Check context INSIDE lock to ensure atomicity with shutdown + if e.ctx.Err() != nil { + return e.ctx.Err() + } + if update.GetNetbirdConfig() != nil { wCfg := update.GetNetbirdConfig() err := e.updateTURNs(wCfg.GetTurns()) @@ -1385,6 +1391,11 @@ func (e *Engine) receiveSignalEvents() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() + // Check context INSIDE lock to ensure atomicity with shutdown + if e.ctx.Err() != nil { + return e.ctx.Err() + } + conn, ok := e.peerStore.PeerConn(msg.Key) if !ok { return fmt.Errorf("wrongly addressed message %s", msg.Key) diff --git a/client/server/server.go b/client/server/server.go index cf79ec5ed..d33595115 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -822,6 +822,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes defer s.mutex.Unlock() if err := s.cleanupConnection(); err != nil { + // todo review to update the status in case any type of error log.Errorf("failed to shut down properly: %v", err) return nil, err } @@ -914,6 +915,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe } if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) { + // todo review to update the status in case any type of error log.Errorf("failed to cleanup connection: %v", err) return nil, err } From cb6b086164608d66c1dbaf33adaa2e4d9a6b1905 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 4 Dec 2025 21:01:22 +0100 Subject: [PATCH 110/120] [client] Reorder subsystem shutdown so peer removal goes first (#4914) Remove peers before DNS and routes --- client/internal/engine.go | 42 +++++++++++++-------------------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index ce46dac8c..ff1cec19a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -291,21 +291,12 @@ func (e *Engine) Stop() error { } log.Info("Network monitor: stopped") - if os.Getenv("NB_REMOVE_BEFORE_DNS") == "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { - log.Info("removing peers before dns") - if err := e.removeAllPeers(); err != nil { - log.Errorf("failed to remove all peers: %s", err) - } - } if err := e.stopSSHServer(); err != nil { log.Warnf("failed to stop SSH server: %v", err) } e.cleanupSSHConfig() - // stop/restore DNS first so dbus and friends don't complain because of a missing interface - e.stopDNSServer() - if e.ingressGatewayMgr != nil { if err := e.ingressGatewayMgr.Close(); err != nil { log.Warnf("failed to cleanup forward rules: %v", err) @@ -313,33 +304,28 @@ func (e *Engine) Stop() error { e.ingressGatewayMgr = nil } - e.stopDNSForwarder() + if e.srWatcher != nil { + e.srWatcher.Close() + } - if os.Getenv("NB_REMOVE_BEFORE_ROUTES") == "true" && os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" { - log.Info("removing peers before routes") - if err := e.removeAllPeers(); err != nil { - log.Errorf("failed to remove all peers: %s", err) - } + log.Info("cleaning up status recorder states") + e.statusRecorder.ReplaceOfflinePeers([]peer.State{}) + e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) + e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{}) + + if err := e.removeAllPeers(); err != nil { + log.Errorf("failed to remove all peers: %s", err) } if e.routeManager != nil { e.routeManager.Stop(e.stateManager) } - if e.srWatcher != nil { - e.srWatcher.Close() - } - log.Info("cleaning up status recorder states") - e.statusRecorder.ReplaceOfflinePeers([]peer.State{}) - e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) - e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{}) + e.stopDNSForwarder() - if os.Getenv("NB_REMOVE_BEFORE_DNS") != "true" && os.Getenv("NB_REMOVE_BEFORE_ROUTES") != "true" { - log.Info("removing peers after dns and routes") - if err := e.removeAllPeers(); err != nil { - log.Errorf("failed to remove all peers: %s", err) - } - } + // stop/restore DNS after peers are closed but before interface goes down + // so dbus and friends don't complain because of a missing interface + e.stopDNSServer() if e.cancel != nil { e.cancel() From f538e6e9ae215d1e4baf48f93bbe6fbcb0f1e303 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 5 Dec 2025 03:29:27 +0100 Subject: [PATCH 111/120] [client] Use setsid to avoid the parent process from being killed via HUP by login (#4900) --- client/ssh/server/command_execution_js.go | 5 +++ client/ssh/server/command_execution_unix.go | 26 ++++++++++++- .../ssh/server/command_execution_windows.go | 5 +++ client/ssh/server/server.go | 4 +- client/ssh/server/userswitching_unix.go | 39 ++++++++++++++++--- 5 files changed, 71 insertions(+), 8 deletions(-) diff --git a/client/ssh/server/command_execution_js.go b/client/ssh/server/command_execution_js.go index 6473f8273..01759a337 100644 --- a/client/ssh/server/command_execution_js.go +++ b/client/ssh/server/command_execution_js.go @@ -42,6 +42,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool { return false } +// detectUtilLinuxLogin always returns false on JS/WASM +func (s *Server) detectUtilLinuxLogin(context.Context) bool { + return false +} + // executeCommandWithPty is not supported on JS/WASM func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { logger.Errorf("PTY command execution not supported on JS/WASM") diff --git a/client/ssh/server/command_execution_unix.go b/client/ssh/server/command_execution_unix.go index da059fed9..db1a9bcfe 100644 --- a/client/ssh/server/command_execution_unix.go +++ b/client/ssh/server/command_execution_unix.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "os/user" + "runtime" "strings" "sync" "syscall" @@ -75,6 +76,29 @@ func (s *Server) detectSuPtySupport(ctx context.Context) bool { return supported } +// detectUtilLinuxLogin checks if login is from util-linux (vs shadow-utils). +// util-linux login uses vhangup() which requires setsid wrapper to avoid killing parent. +// See https://bugs.debian.org/1078023 for details. +func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool { + if runtime.GOOS != "linux" { + return false + } + + ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + + cmd := exec.CommandContext(ctx, "login", "--version") + output, err := cmd.CombinedOutput() + if err != nil { + log.Debugf("login --version failed (likely shadow-utils): %v", err) + return false + } + + isUtilLinux := strings.Contains(string(output), "util-linux") + log.Debugf("util-linux login detected: %v", isUtilLinux) + return isUtilLinux +} + // createSuCommand creates a command using su -l -c for privilege switching func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { suPath, err := exec.LookPath("su") @@ -144,7 +168,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu return false } - logger.Infof("starting interactive shell: %s", execCmd.Path) + logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " ")) return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) } diff --git a/client/ssh/server/command_execution_windows.go b/client/ssh/server/command_execution_windows.go index 37b3ae0ee..998796871 100644 --- a/client/ssh/server/command_execution_windows.go +++ b/client/ssh/server/command_execution_windows.go @@ -383,6 +383,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool { return false } +// detectUtilLinuxLogin always returns false on Windows +func (s *Server) detectUtilLinuxLogin(context.Context) bool { + return false +} + // executeCommandWithPty executes a command with PTY allocation on Windows using ConPty func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { command := session.RawCommand() diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 44612532b..37763ee0e 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -138,7 +138,8 @@ type Server struct { jwtExtractor *jwt.ClaimsExtractor jwtConfig *JWTConfig - suSupportsPty bool + suSupportsPty bool + loginIsUtilLinux bool } type JWTConfig struct { @@ -193,6 +194,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { } s.suSupportsPty = s.detectSuPtySupport(ctx) + s.loginIsUtilLinux = s.detectUtilLinuxLogin(ctx) ln, addrDesc, err := s.createListener(ctx, addr) if err != nil { diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go index 06fefabd7..bc1557419 100644 --- a/client/ssh/server/userswitching_unix.go +++ b/client/ssh/server/userswitching_unix.go @@ -87,11 +87,8 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st switch runtime.GOOS { case "linux": - // Special handling for Arch Linux without /etc/pam.d/remote - if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") { - return loginPath, []string{"-f", username, "-p"}, nil - } - return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil + p, a := s.getLinuxLoginCmd(loginPath, username, addrPort.Addr().String()) + return p, a, nil case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly": return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil default: @@ -99,7 +96,37 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st } } -// fileExists checks if a file exists (helper for login command logic) +// getLinuxLoginCmd returns the login command for Linux systems. +// Handles differences between util-linux and shadow-utils login implementations. +func (s *Server) getLinuxLoginCmd(loginPath, username, remoteIP string) (string, []string) { + // Special handling for Arch Linux without /etc/pam.d/remote + var loginArgs []string + if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") { + loginArgs = []string{"-f", username, "-p"} + } else { + loginArgs = []string{"-f", username, "-h", remoteIP, "-p"} + } + + // util-linux login requires setsid -c to create a new session and set the + // controlling terminal. Without this, vhangup() kills the parent process. + // See https://bugs.debian.org/1078023 for details. + // TODO: handle this via the executor using syscall.Setsid() + TIOCSCTTY + syscall.Exec() + // to avoid external setsid dependency. + if !s.loginIsUtilLinux { + return loginPath, loginArgs + } + + setsidPath, err := exec.LookPath("setsid") + if err != nil { + log.Warnf("setsid not available but util-linux login detected, login may fail: %v", err) + return loginPath, loginArgs + } + + args := append([]string{"-w", "-c", loginPath}, loginArgs...) + return setsidPath, args +} + +// fileExists checks if a file exists func (s *Server) fileExists(path string) bool { _, err := os.Stat(path) return err == nil From 3f4f825ec14c45dd149a743eb80e98e7bb5e9ec5 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 5 Dec 2025 17:42:49 +0100 Subject: [PATCH 112/120] [client] Fix DNS forwarder returning broken records on 4 to 6 mapped IP addresses (#4887) --- client/internal/dnsfwd/forwarder.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index aef16a8cf..6b8042ccb 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns return nil } + // Unmap IPv4-mapped IPv6 addresses that some resolvers may return + for i, ip := range ips { + ips[i] = ip.Unmap() + } + f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.addIPsToResponse(resp, domain, ips) f.cache.set(domain, question.Qtype, ips) From 44851e06fbaaf02f2cde642c40ff6cecea0a3c3f Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 10 Dec 2025 19:26:51 +0100 Subject: [PATCH 113/120] [management] cleanup logs (#4933) --- management/internals/shared/grpc/server.go | 29 +-------- management/internals/shared/grpc/token_mgr.go | 6 +- .../server/http/middleware/auth_middleware.go | 2 +- management/server/peer.go | 5 +- management/server/posture/os_version.go | 4 +- management/server/settings/manager.go | 8 --- management/server/store/sql_store.go | 63 +------------------ 7 files changed, 9 insertions(+), 108 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 62dc215d8..16950db5e 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -134,10 +134,6 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser } log.WithContext(ctx).Tracef("GetServerKey request from %s", ip) - start := time.Now() - defer func() { - log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start)) - }() // todo introduce something more meaningful with the key expiration/rotation if s.appMetrics != nil { @@ -222,8 +218,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S return err } - log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) - // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) @@ -235,7 +229,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S } }() log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start)) - log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart)) log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) @@ -352,7 +345,7 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID) s.secretsManager.CancelRefresh(peer.ID) - log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) + log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key) } func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) { @@ -561,16 +554,10 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) - log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) - defer func() { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) } - took := time.Since(reqStart) - if took > 7*time.Second { - log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart)) - } }() if loginReq.GetMeta() == nil { @@ -604,16 +591,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto return nil, mapError(ctx, err) } - log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart)) - loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) if err != nil { log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err) return nil, status.Errorf(codes.Internal, "failed logging in peer") } - log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart)) - key, err := s.secretsManager.GetWGKey() if err != nil { log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err) @@ -730,12 +713,10 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer return status.Errorf(codes.Internal, "error handling request") } - sendStart := time.Now() err = srv.Send(&proto.EncryptedMessage{ WgPubKey: key.PublicKey().String(), Body: encryptedResp, }) - log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart)) if err != nil { log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err) @@ -750,10 +731,6 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer // which will be used by our clients to Login func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey) - start := time.Now() - defer func() { - log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start)) - }() peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { @@ -813,10 +790,6 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr // which will be used by our clients to Login func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey) - start := time.Now() - defer func() { - log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start)) - }() peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { diff --git a/management/internals/shared/grpc/token_mgr.go b/management/internals/shared/grpc/token_mgr.go index 0f893ae3a..ccb32202f 100644 --- a/management/internals/shared/grpc/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -167,7 +167,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountI relayCancel := make(chan struct{}, 1) m.relayCancelMap[peerID] = relayCancel go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel) - log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID) + log.WithContext(ctx).Tracef("starting relay refresh for %s", peerID) } } @@ -178,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, acc for { select { case <-cancel: - log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID) + log.WithContext(ctx).Tracef("stopping TURN refresh for %s", peerID) return case <-ticker.C: m.pushNewTURNAndRelayTokens(ctx, accountID, peerID) @@ -193,7 +193,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, ac for { select { case <-cancel: - log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID) + log.WithContext(ctx).Tracef("stopping relay refresh for %s", peerID) return case <-ticker.C: m.pushNewRelayTokens(ctx, accountID, peerID) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index ffd7e0b93..38cf0c290 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -141,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] } if userAuth.AccountId != accountId { - log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId) + log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId) userAuth.AccountId = accountId } diff --git a/management/server/peer.go b/management/server/peer.go index f2de05f15..49f5bf2a5 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -172,7 +172,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio } } - log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected) + log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected) err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus) if err != nil { @@ -783,7 +783,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } - startTransaction := time.Now() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { @@ -853,8 +852,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } - log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) - if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) if err != nil { diff --git a/management/server/posture/os_version.go b/management/server/posture/os_version.go index 411f4c2c6..2ef97a066 100644 --- a/management/server/posture/os_version.go +++ b/management/server/posture/os_version.go @@ -82,7 +82,7 @@ func (c *OSVersionCheck) Validate() error { func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) { if check == nil { - log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) + log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS) return false, nil } @@ -107,7 +107,7 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) { if check == nil { - log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) + log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS) return false, nil } diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index f16b609f8..2b2896572 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -5,9 +5,6 @@ package settings import ( "context" "fmt" - "time" - - log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/extra_settings" @@ -48,11 +45,6 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager { } func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start)) - }() - if userID != activity.SystemInitiator { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) if err != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 94b7fc1cc..4aa3de499 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -27,7 +27,6 @@ import ( "gorm.io/gorm/logger" nbdns "github.com/netbirdio/netbird/dns" - nbcontext "github.com/netbirdio/netbird/management/server/context" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -288,7 +287,7 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er if s.metrics != nil { s.metrics.StoreMetrics().CountPersistenceDuration(took) } - log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds()) + log.WithContext(ctx).Tracef("took %d ms to delete an account to the store", took.Milliseconds()) return err } @@ -583,9 +582,6 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren } func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -2152,9 +2148,6 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -2171,9 +2164,6 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt } func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -2229,9 +2219,6 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - var user types.User result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { @@ -2491,9 +2478,6 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s } func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -2514,9 +2498,6 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - result := s.db.WithContext(ctx).Model(&types.SetupKey{}). Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ @@ -2537,9 +2518,6 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string // AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - var groupID string _ = s.db.WithContext(ctx).Model(types.Group{}). Select("id"). @@ -2569,9 +2547,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer // AddPeerToGroup adds a peer to a group func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - peer := &types.GroupPeer{ AccountID: accountID, GroupID: groupID, @@ -2768,9 +2743,6 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt } func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -2897,9 +2869,6 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri } func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - ctx, cancel := getDebuggingCtx(ctx) - defer cancel() - result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) @@ -4022,36 +3991,6 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin return groupPeers, nil } -func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string) - if ok { - //nolint - ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID) - } - - requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string) - if ok { - //nolint - ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID) - } - - accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string) - if ok { - //nolint - ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID) - } - - go func() { - select { - case <-ctx.Done(): - case <-grpcCtx.Done(): - log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err()) - } - }() - return ctx, cancel -} - func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { var info types.PrimaryAccountInfo result := s.db.Model(&types.Account{}). From 94d34dc0c5cc9ea353438adfbe4fd613566be2b2 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 11 Dec 2025 18:29:15 +0100 Subject: [PATCH 114/120] [management] monitoring updates (#4937) --- management/internals/shared/grpc/server.go | 13 ++++--------- management/server/account.go | 7 +++++++ management/server/telemetry/grpc_metrics.go | 18 ------------------ .../server/telemetry/http_api_metrics.go | 12 ++++++++++++ 4 files changed, 23 insertions(+), 27 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 16950db5e..6029dc8bf 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -171,8 +171,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S } s.syncSem.Add(1) - reqStart := time.Now() - ctx := srv.Context() syncReq := &proto.SyncRequest{} @@ -190,7 +188,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() } if s.logBlockedPeers { - log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) + log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) } if s.blockPeersWithSameConfig { s.syncSem.Add(-1) @@ -263,10 +261,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) - } - unlock() unlock = nil @@ -518,7 +512,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto reqStart := time.Now() realIP := getRealIP(ctx) sRealIP := realIP.String() - log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP) loginReq := &proto.LoginRequest{} peerKey, err := s.parseRequest(ctx, req, loginReq) @@ -530,7 +523,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto metahashed := metaHash(peerMeta, sRealIP) if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if s.logBlockedPeers { - log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) + log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) } if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestBlocked() @@ -554,6 +547,8 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP) + defer func() { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) diff --git a/management/server/account.go b/management/server/account.go index dac040db0..8a2155cb4 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -787,6 +787,13 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) accountIDString := fmt.Sprintf("%v", accountID) + if ctx == nil { + ctx = context.Background() + } + + // nolint:staticcheck + ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString) if err != nil { return nil, nil, err diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index d4301802f..4ba1ee14e 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -16,7 +16,6 @@ type GRPCMetrics struct { meter metric.Meter syncRequestsCounter metric.Int64Counter syncRequestsBlockedCounter metric.Int64Counter - syncRequestHighLatencyCounter metric.Int64Counter loginRequestsCounter metric.Int64Counter loginRequestsBlockedCounter metric.Int64Counter loginRequestHighLatencyCounter metric.Int64Counter @@ -46,14 +45,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } - syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter", - metric.WithUnit("1"), - metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"), - ) - if err != nil { - return nil, err - } - loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter", metric.WithUnit("1"), metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"), @@ -126,7 +117,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro meter: meter, syncRequestsCounter: syncRequestsCounter, syncRequestsBlockedCounter: syncRequestsBlockedCounter, - syncRequestHighLatencyCounter: syncRequestHighLatencyCounter, loginRequestsCounter: loginRequestsCounter, loginRequestsBlockedCounter: loginRequestsBlockedCounter, loginRequestHighLatencyCounter: loginRequestHighLatencyCounter, @@ -172,14 +162,6 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration } } -// CountSyncRequestDuration counts the duration of the sync gRPC requests -func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { - grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) - if duration > HighLatencyThreshold { - grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) - } -} - // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. func (grpcMetrics *GRPCMetrics) RegisterConnectedStreams(producer func() int64) error { _, err := grpcMetrics.meter.RegisterCallback( diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go index ae27466d9..0b6f8beb6 100644 --- a/management/server/telemetry/http_api_metrics.go +++ b/management/server/telemetry/http_api_metrics.go @@ -185,6 +185,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { h.ServeHTTP(w, r.WithContext(ctx)) + userAuth, err := nbContext.GetUserAuthFromContext(r.Context()) + if err == nil { + if userAuth.AccountId != "" { + //nolint + ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId) + } + if userAuth.UserId != "" { + //nolint + ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId) + } + } + if w.Status() > 399 { log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) } else { From 90e3b8009fb5228a8643143d775eb141f4abba91 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 11 Dec 2025 20:11:12 +0100 Subject: [PATCH 115/120] [management] Fix sync metrics (#4939) --- management/internals/shared/grpc/server.go | 6 ++++++ management/server/telemetry/grpc_metrics.go | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 6029dc8bf..462e2e6eb 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -171,6 +171,8 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S } s.syncSem.Add(1) + reqStart := time.Now() + ctx := srv.Context() syncReq := &proto.SyncRequest{} @@ -261,6 +263,10 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) + } + unlock() unlock = nil diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index 4ba1ee14e..bd7fbc235 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -162,6 +162,11 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration } } +// CountSyncRequestDuration counts the duration of the sync gRPC requests +func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { + grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) +} + // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. func (grpcMetrics *GRPCMetrics) RegisterConnectedStreams(producer func() int64) error { _, err := grpcMetrics.meter.RegisterCallback( From abcbde26f9294f5bd649386530c21d816bca02be Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:45:47 +0100 Subject: [PATCH 116/120] [management] remove context from store methods (#4940) --- management/server/store/sql_store.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 4aa3de499..74b03ce48 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -588,7 +588,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre } var user types.User - result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID) + result := tx.Take(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) @@ -2154,7 +2154,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt } var accountNetwork types.AccountNetwork - if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil { + if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } @@ -2170,7 +2170,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking } var peer nbpeer.Peer - result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey) + result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -2220,7 +2220,7 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { var user types.User - result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID) + result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewUserNotFoundError(userID) @@ -2484,7 +2484,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking } var setupKey types.SetupKey - result := tx.WithContext(ctx). + result := tx. Take(&setupKey, GetKeyQueryCondition(s), key) if result.Error != nil { @@ -2498,7 +2498,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - result := s.db.WithContext(ctx).Model(&types.SetupKey{}). + result := s.db.Model(&types.SetupKey{}). Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ "used_times": gorm.Expr("used_times + 1"), @@ -2519,7 +2519,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string // AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { var groupID string - _ = s.db.WithContext(ctx).Model(types.Group{}). + _ = s.db.Model(types.Group{}). Select("id"). Where("account_id = ? AND name = ?", accountID, "All"). Limit(1). @@ -2743,7 +2743,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt } func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { + if err := s.db.Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -2869,7 +2869,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri } func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) return status.Errorf(status.Internal, "failed to increment network serial count in store") @@ -4030,7 +4030,7 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i Network: &types.Network{Net: ipNet}, } - result := s.db.WithContext(ctx). + result := s.db. Model(&types.Account{}). Where(idQueryCondition, accountID). Updates(&patch) From 932c02eaabbcc3d314013969d85cd1aee353af7b Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 12 Dec 2025 18:49:57 +0300 Subject: [PATCH 117/120] [management] Approve all pending peers when peer approval is disabled (#4806) --- go.mod | 2 +- go.sum | 4 +- management/server/account.go | 29 +++---- management/server/account_test.go | 37 +++++++++ management/server/integrated_validator.go | 2 +- .../integrated_validator/interface.go | 2 +- management/server/store/sql_store.go | 12 +++ management/server/store/sql_store_test.go | 77 +++++++++++++++++++ management/server/store/store.go | 1 + 9 files changed, 148 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index 870118c88..8f4ec530b 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba + github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 96a303f79..f10e1e6da 100644 --- a/go.sum +++ b/go.sum @@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba h1:pD6eygRJ5EYAlgzeNskPU3WqszMz6/HhPuc6/Bc/580= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/server/account.go b/management/server/account.go index 8a2155cb4..a9becc4b6 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -295,10 +295,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return err } - if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil { + if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil { return err } + if oldSettings.Extra != nil && newSettings.Extra != nil && + oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled { + approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to approve pending peers: %w", err) + } + + if approvedCount > 0 { + log.WithContext(ctx).Debugf("approved %d pending peers in account %s", approvedCount, accountID) + updateAccountPeers = true + } + } + if oldSettings.NetworkRange != newSettings.NetworkRange { if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil { return err @@ -372,7 +385,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return newSettings, nil } -func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error { +func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -386,17 +399,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) } - peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") - if err != nil { - return err - } - - peersMap := make(map[string]*nbpeer.Peer, len(peers)) - for _, peer := range peers { - peersMap[peer.ID] = peer - } - - return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID) + return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID) } func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { diff --git a/management/server/account_test.go b/management/server/account_test.go index 8569f1b2f..7f125e3a0 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2058,6 +2058,43 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") } +func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) { + manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + accountID := account.Id + userID := account.Users[account.CreatedBy].Id + ctx := context.Background() + + newSettings := account.Settings.Copy() + newSettings.Extra = &types.ExtraSettings{ + PeerApprovalEnabled: true, + } + _, err := manager.UpdateAccountSettings(ctx, accountID, userID, newSettings) + require.NoError(t, err) + + peer1.Status.RequiresApproval = true + peer2.Status.RequiresApproval = true + peer3.Status.RequiresApproval = false + + require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer1)) + require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer2)) + require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer3)) + + newSettings = account.Settings.Copy() + newSettings.Extra = &types.ExtraSettings{ + PeerApprovalEnabled: false, + } + _, err = manager.UpdateAccountSettings(ctx, accountID, userID, newSettings) + require.NoError(t, err) + + accountPeers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + + for _, peer := range accountPeers { + assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval after disabling peer approval", peer.ID) + } +} + func TestAccount_GetExpiredPeers(t *testing.T) { type test struct { name string diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index e9a1c8701..69ea668ad 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -127,7 +127,7 @@ type MockIntegratedValidator struct { ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) } -func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { +func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error { return nil } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index 26c338cb6..326fbfaf0 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -10,7 +10,7 @@ import ( // IntegratedValidator interface exists to avoid the circle dependencies type IntegratedValidator interface { - ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error + ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 74b03ce48..2b8981b97 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -412,6 +412,18 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW return nil } +// ApproveAccountPeers marks all peers that currently require approval in the given account as approved. +func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) { + result := s.db.Model(&nbpeer.Peer{}). + Where("account_id = ? AND peer_status_requires_approval = ?", accountID, true). + Update("peer_status_requires_approval", false) + if result.Error != nil { + return 0, status.Errorf(status.Internal, "failed to approve pending account peers: %v", result.Error) + } + + return int(result.RowsAffected), nil +} + // SaveUsers saves the given list of users to the database. func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error { if len(users) == 0 { diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index d40c4664c..2e2623910 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3717,3 +3717,80 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { }) } } + +func TestSqlStore_ApproveAccountPeers(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + accountID := "test-account" + ctx := context.Background() + + account := newAccountWithId(ctx, accountID, "testuser", "example.com") + err := store.SaveAccount(ctx, account) + require.NoError(t, err) + + peers := []*nbpeer.Peer{ + { + ID: "peer1", + AccountID: accountID, + DNSLabel: "peer1.netbird.cloud", + Key: "peer1-key", + IP: net.ParseIP("100.64.0.1"), + Status: &nbpeer.PeerStatus{ + RequiresApproval: true, + LastSeen: time.Now().UTC(), + }, + }, + { + ID: "peer2", + AccountID: accountID, + DNSLabel: "peer2.netbird.cloud", + Key: "peer2-key", + IP: net.ParseIP("100.64.0.2"), + Status: &nbpeer.PeerStatus{ + RequiresApproval: true, + LastSeen: time.Now().UTC(), + }, + }, + { + ID: "peer3", + AccountID: accountID, + DNSLabel: "peer3.netbird.cloud", + Key: "peer3-key", + IP: net.ParseIP("100.64.0.3"), + Status: &nbpeer.PeerStatus{ + RequiresApproval: false, + LastSeen: time.Now().UTC(), + }, + }, + } + + for _, peer := range peers { + err = store.AddPeerToAccount(ctx, peer) + require.NoError(t, err) + } + + t.Run("approve all pending peers", func(t *testing.T) { + count, err := store.ApproveAccountPeers(ctx, accountID) + require.NoError(t, err) + assert.Equal(t, 2, count) + + allPeers, err := store.GetAccountPeers(ctx, LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + + for _, peer := range allPeers { + assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval", peer.ID) + } + }) + + t.Run("no peers to approve", func(t *testing.T) { + count, err := store.ApproveAccountPeers(ctx, accountID) + require.NoError(t, err) + assert.Equal(t, 0, count) + }) + + t.Run("non-existent account", func(t *testing.T) { + count, err := store.ApproveAccountPeers(ctx, "non-existent") + require.NoError(t, err) + assert.Equal(t, 0, count) + }) + }) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 007e2b739..0ec7949f9 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -143,6 +143,7 @@ type Store interface { SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error + ApproveAccountPeers(ctx context.Context, accountID string) (int, error) DeletePeer(ctx context.Context, accountID string, peerID string) error GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) From 08f31fbcb3198f309fb52a70a48d99d47ca35b83 Mon Sep 17 00:00:00 2001 From: Diego Romar Date: Fri, 12 Dec 2025 14:29:58 -0300 Subject: [PATCH 118/120] [iOS] Add force relay connection on iOS (#4928) * [ios] Add a bogus test to check iOS behavior when setting environment variables * [ios] Revert "Add a bogus test to check iOS behavior when setting environment variables" This reverts commit 90ca01105a6b0f4471aac07a63fc95e5d4eaef9b. * [ios] Add EnvList struct to export and import environment variables * [ios] Add envList parameter to the iOS Client Run method * [ios] Add some debug logging to exportEnvVarList * Add "//go:build ios" to client/ios/NetBirdSDK files --- client/ios/NetBirdSDK/client.go | 22 ++++++++++++++- client/ios/NetBirdSDK/env_list.go | 34 +++++++++++++++++++++++ client/ios/NetBirdSDK/gomobile.go | 2 ++ client/ios/NetBirdSDK/logger.go | 2 ++ client/ios/NetBirdSDK/login.go | 2 ++ client/ios/NetBirdSDK/peer_notifier.go | 2 ++ client/ios/NetBirdSDK/preferences.go | 2 ++ client/ios/NetBirdSDK/preferences_test.go | 2 ++ client/ios/NetBirdSDK/routes.go | 2 ++ 9 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 client/ios/NetBirdSDK/env_list.go diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 6d969bb12..463c93d57 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -1,9 +1,12 @@ +//go:build ios + package NetBirdSDK import ( "context" "fmt" "net/netip" + "os" "sort" "strings" "sync" @@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s } // Run start the internal client. It is a blocker function -func (c *Client) Run(fd int32, interfaceName string) error { +func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { + exportEnvList(envList) log.Infof("Starting NetBird client") log.Debugf("Tunnel uses interface: %s", interfaceName) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ @@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID { } return netIDs } + +func exportEnvList(list *EnvList) { + if list == nil { + return + } + for k, v := range list.AllItems() { + log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k)) + log.Debugf("Setting env variable %s: %s", k, v) + + if err := os.Setenv(k, v); err != nil { + log.Errorf("could not set env variable %s: %v", k, err) + } else { + log.Debugf("Env variable %s was set successfully", k) + } + } +} diff --git a/client/ios/NetBirdSDK/env_list.go b/client/ios/NetBirdSDK/env_list.go new file mode 100644 index 000000000..4800803d7 --- /dev/null +++ b/client/ios/NetBirdSDK/env_list.go @@ -0,0 +1,34 @@ +//go:build ios + +package NetBirdSDK + +import "github.com/netbirdio/netbird/client/internal/peer" + +// EnvList is an exported struct to be bound by gomobile +type EnvList struct { + data map[string]string +} + +// NewEnvList creates a new EnvList +func NewEnvList() *EnvList { + return &EnvList{data: make(map[string]string)} +} + +// Put adds a key-value pair +func (el *EnvList) Put(key, value string) { + el.data[key] = value +} + +// Get retrieves a value by key +func (el *EnvList) Get(key string) string { + return el.data[key] +} + +func (el *EnvList) AllItems() map[string]string { + return el.data +} + +// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client +func GetEnvKeyNBForceRelay() string { + return peer.EnvKeyNBForceRelay +} diff --git a/client/ios/NetBirdSDK/gomobile.go b/client/ios/NetBirdSDK/gomobile.go index 9eadd6a7f..79bf0c2ac 100644 --- a/client/ios/NetBirdSDK/gomobile.go +++ b/client/ios/NetBirdSDK/gomobile.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import _ "golang.org/x/mobile/bind" diff --git a/client/ios/NetBirdSDK/logger.go b/client/ios/NetBirdSDK/logger.go index f1ad1b9f6..531d0ba89 100644 --- a/client/ios/NetBirdSDK/logger.go +++ b/client/ios/NetBirdSDK/logger.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import ( diff --git a/client/ios/NetBirdSDK/login.go b/client/ios/NetBirdSDK/login.go index 570c44f80..1c2b38a61 100644 --- a/client/ios/NetBirdSDK/login.go +++ b/client/ios/NetBirdSDK/login.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import ( diff --git a/client/ios/NetBirdSDK/peer_notifier.go b/client/ios/NetBirdSDK/peer_notifier.go index 16c5039eb..9b00568be 100644 --- a/client/ios/NetBirdSDK/peer_notifier.go +++ b/client/ios/NetBirdSDK/peer_notifier.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK // PeerInfo describe information about the peers. It designed for the UI usage diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go index 5e7050465..39ae06538 100644 --- a/client/ios/NetBirdSDK/preferences.go +++ b/client/ios/NetBirdSDK/preferences.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import ( diff --git a/client/ios/NetBirdSDK/preferences_test.go b/client/ios/NetBirdSDK/preferences_test.go index 780443a7b..5f75e7c9a 100644 --- a/client/ios/NetBirdSDK/preferences_test.go +++ b/client/ios/NetBirdSDK/preferences_test.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK import ( diff --git a/client/ios/NetBirdSDK/routes.go b/client/ios/NetBirdSDK/routes.go index 30d0d0d0a..7b84d6e1c 100644 --- a/client/ios/NetBirdSDK/routes.go +++ b/client/ios/NetBirdSDK/routes.go @@ -1,3 +1,5 @@ +//go:build ios + package NetBirdSDK // RoutesSelectionInfoCollection made for Java layer to get non default types as collection From 5748bdd64edf0a0e77dab2311ceb700d049730c8 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 15 Dec 2025 10:28:25 +0100 Subject: [PATCH 119/120] Add health-check agent recognition to avoid error logs (#4917) Health-check connections now send a properly formatted auth message with a well-known peer ID instead of immediately closing. The server recognizes this peer ID and handles the connection gracefully with a debug log instead of error logs. --- relay/healthcheck/peerid/peerid.go | 31 ++++++++++++++++++++++++++++++ relay/healthcheck/ws.go | 15 ++++++++++++++- relay/server/handshake.go | 4 ++-- relay/server/relay.go | 7 ++++++- 4 files changed, 53 insertions(+), 4 deletions(-) create mode 100644 relay/healthcheck/peerid/peerid.go diff --git a/relay/healthcheck/peerid/peerid.go b/relay/healthcheck/peerid/peerid.go new file mode 100644 index 000000000..cd8696817 --- /dev/null +++ b/relay/healthcheck/peerid/peerid.go @@ -0,0 +1,31 @@ +package peerid + +import ( + "crypto/sha256" + + v2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2" + "github.com/netbirdio/netbird/shared/relay/messages" +) + +var ( + // HealthCheckPeerID is the hashed peer ID for health check connections + HealthCheckPeerID = messages.HashID("healthcheck-agent") + + // DummyAuthToken is a structurally valid auth token for health check. + // The signature is not valid but the format is correct (1 byte algo + 32 bytes signature + payload). + DummyAuthToken = createDummyToken() +) + +func createDummyToken() []byte { + token := v2.Token{ + AuthAlgo: v2.AuthAlgoHMACSHA256, + Signature: make([]byte, sha256.Size), + Payload: []byte("healthcheck"), + } + return token.Marshal() +} + +// IsHealthCheck checks if the given peer ID is the health check agent +func IsHealthCheck(peerID *messages.PeerID) bool { + return peerID != nil && *peerID == HealthCheckPeerID +} diff --git a/relay/healthcheck/ws.go b/relay/healthcheck/ws.go index db61ed802..9267096f5 100644 --- a/relay/healthcheck/ws.go +++ b/relay/healthcheck/ws.go @@ -7,8 +7,10 @@ import ( "github.com/coder/websocket" + "github.com/netbirdio/netbird/relay/healthcheck/peerid" "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/shared/relay" + "github.com/netbirdio/netbird/shared/relay/messages" ) func dialWS(ctx context.Context, address url.URL) error { @@ -30,7 +32,18 @@ func dialWS(ctx context.Context, address url.URL) error { if err != nil { return fmt.Errorf("failed to connect to websocket: %w", err) } + defer func() { + _ = conn.CloseNow() + }() + + authMsg, err := messages.MarshalAuthMsg(peerid.HealthCheckPeerID, peerid.DummyAuthToken) + if err != nil { + return fmt.Errorf("failed to marshal auth message: %w", err) + } + + if err := conn.Write(ctx, websocket.MessageBinary, authMsg); err != nil { + return fmt.Errorf("failed to write auth message: %w", err) + } - _ = conn.Close(websocket.StatusNormalClosure, "availability check complete") return nil } diff --git a/relay/server/handshake.go b/relay/server/handshake.go index 922369798..8c3ee1899 100644 --- a/relay/server/handshake.go +++ b/relay/server/handshake.go @@ -97,7 +97,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) { return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) } if err != nil { - return nil, err + return peerID, err } h.peerID = peerID return peerID, nil @@ -147,7 +147,7 @@ func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) { } if err := h.validator.Validate(authPayload); err != nil { - return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err) + return rawPeerID, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err) } return rawPeerID, nil diff --git a/relay/server/relay.go b/relay/server/relay.go index c1cfa13fd..bb355f58f 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -12,6 +12,7 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/metric" + "github.com/netbirdio/netbird/relay/healthcheck/peerid" //nolint:staticcheck "github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/server/store" @@ -123,7 +124,11 @@ func (r *Relay) Accept(conn net.Conn) { } peerID, err := h.handshakeReceive() if err != nil { - log.Errorf("failed to handshake: %s", err) + if peerid.IsHealthCheck(peerID) { + log.Debugf("health check connection from %s", conn.RemoteAddr()) + } else { + log.Errorf("failed to handshake: %s", err) + } if cErr := conn.Close(); cErr != nil { log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) } From 447cd287f5ef17c12242125ed154ddfa1dd82acf Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 15 Dec 2025 10:34:48 +0100 Subject: [PATCH 120/120] [ci] Add local lint setup with pre-push hook to catch issues early (#4925) * Add local lint setup with pre-push hook to catch issues early Developers can now catch lint issues before pushing, reducing CI failures and iteration time. The setup uses golangci-lint locally with the same configuration as CI. Setup: - Run `make setup-hooks` once after cloning - Pre-push hook automatically lints changed files (~90s) - Use `make lint` to manually check changed files - Use `make lint-all` to run full CI-equivalent lint The Makefile auto-installs golangci-lint to ./bin/ using go install to match the Go version in go.mod, avoiding version compatibility issues. --------- Co-authored-by: mlsmaycon --- .githooks/pre-push | 11 +++++++++++ CONTRIBUTING.md | 8 ++++++++ Makefile | 27 +++++++++++++++++++++++++++ 3 files changed, 46 insertions(+) create mode 100755 .githooks/pre-push create mode 100644 Makefile diff --git a/.githooks/pre-push b/.githooks/pre-push new file mode 100755 index 000000000..31898182e --- /dev/null +++ b/.githooks/pre-push @@ -0,0 +1,11 @@ +#!/bin/bash + +echo "Running pre-push hook..." +if ! make lint; then + echo "" + echo "Hint: To push without verification, run:" + echo " git push --no-verify" + exit 1 +fi + +echo "All checks passed!" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c82cfc763..efc7d9460 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -136,6 +136,14 @@ checked out and set up: go mod tidy ``` +6. Configure Git hooks for automatic linting: + + ```bash + make setup-hooks + ``` + + This will configure Git to run linting automatically before each push, helping catch issues early. + ### Dev Container Support If you prefer using a dev container for development, NetBird now includes support for dev containers. diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..43379e115 --- /dev/null +++ b/Makefile @@ -0,0 +1,27 @@ +.PHONY: lint lint-all lint-install setup-hooks +GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint + +# Install golangci-lint locally if needed +$(GOLANGCI_LINT): + @echo "Installing golangci-lint..." + @mkdir -p ./bin + @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + +# Lint only changed files (fast, for pre-push) +lint: $(GOLANGCI_LINT) + @echo "Running lint on changed files..." + @$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m + +# Lint entire codebase (slow, matches CI) +lint-all: $(GOLANGCI_LINT) + @echo "Running lint on all files..." + @$(GOLANGCI_LINT) run --timeout=12m + +# Just install the linter +lint-install: $(GOLANGCI_LINT) + +# Setup git hooks for all developers +setup-hooks: + @git config core.hooksPath .githooks + @chmod +x .githooks/pre-push + @echo "✅ Git hooks configured! Pre-push will now run 'make lint'"