From 4fad0e521f59ffdbccb26c10e5b0601e4e57f2cb Mon Sep 17 00:00:00 2001 From: benniekiss <63211101+benniekiss@users.noreply.github.com> Date: Tue, 16 Jul 2024 14:44:21 -0400 Subject: [PATCH 01/53] Support custom SSL certificates for the signal service (#2257) --- signal/README.md | 2 + signal/cmd/run.go | 94 ++++++++++++++++++++++++++++++++++------------- 2 files changed, 70 insertions(+), 26 deletions(-) diff --git a/signal/README.md b/signal/README.md index 2da47283e..9e3207cfa 100644 --- a/signal/README.md +++ b/signal/README.md @@ -18,6 +18,8 @@ Flags: --letsencrypt-domain string a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS --port int Server port to listen on (e.g. 10000) (default 10000) --ssl-dir string server ssl directory location. *Required only for Let's Encrypt certificates. (default "/var/lib/netbird/") + --cert-file string Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect + --cert-key string Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect Global Flags: --log-file string sets Netbird log path. 
If console is specified the the log will be output to stdout (default "/var/log/netbird/signal.log") diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 4b0dc583e..8f75c1e04 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "crypto/tls" "errors" "flag" "fmt" @@ -41,7 +42,8 @@ var ( signalLetsencryptDomain string signalSSLDir string defaultSignalSSLDir string - tlsEnabled bool + signalCertFile string + signalCertKey string signalKaep = grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, @@ -59,9 +61,13 @@ var ( Use: "run", Short: "start NetBird Signal Server daemon", PreRun: func(cmd *cobra.Command, args []string) { + flag.Parse() + // detect whether user specified a port userPort := cmd.Flag("port").Changed - if signalLetsencryptDomain != "" { + + tlsEnabled := false + if signalLetsencryptDomain != "" || (signalCertFile != "" && signalCertKey != "") { tlsEnabled = true } @@ -93,14 +99,24 @@ var ( var opts []grpc.ServerOption var certManager *autocert.Manager - if tlsEnabled { - // Let's encrypt enabled -> generate certificate automatically + var tlsConfig *tls.Config + if signalLetsencryptDomain != "" { certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) if err != nil { return err } transportCredentials := credentials.NewTLS(certManager.TLSConfig()) opts = append(opts, grpc.Creds(transportCredentials)) + log.Infof("setting up TLS with LetsEncrypt.") + } else if signalCertFile != "" && signalCertKey != "" { + tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) + if err != nil { + log.Errorf("cannot load TLS credentials: %v", err) + return err + } + transportCredentials := credentials.NewTLS(tlsConfig) + opts = append(opts, grpc.Creds(transportCredentials)) + log.Infof("setting up TLS with custom certificates.") } metricsServer := metrics.NewServer(metricsPort, "") @@ -124,7 +140,34 @@ var ( } 
proto.RegisterSignalExchangeServer(grpcServer, srv) + grpcRootHandler := grpcHandlerFunc(grpcServer) var compatListener net.Listener + var grpcListener net.Listener + var httpListener net.Listener + + if certManager != nil { + // a call to certManager.Listener() always creates a new listener so we do it once + httpListener := certManager.Listener() + if signalPort == 443 { + // running gRPC and HTTP cert manager on the same port + serveHTTP(httpListener, certManager.HTTPHandler(grpcRootHandler)) + log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String()) + } else { + // Start the HTTP cert manager server separately + serveHTTP(httpListener, certManager.HTTPHandler(nil)) + log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String()) + } + } + + // If certManager is configured and signalPort == 443, then the gRPC server has already been started + if certManager == nil || signalPort != 443 { + grpcListener, err = serveGRPC(grpcServer, signalPort) + if err != nil { + return err + } + log.Infof("running gRPC server: %s", grpcListener.Addr().String()) + } + if signalPort != 10000 { // The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal // are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000. 
@@ -135,28 +178,6 @@ var ( log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) } - var grpcListener net.Listener - var httpListener net.Listener - if tlsEnabled { - httpListener = certManager.Listener() - if signalPort == 443 { - // running gRPC and HTTP cert manager on the same port - serveHTTP(httpListener, certManager.HTTPHandler(grpcHandlerFunc(grpcServer))) - log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String()) - } else { - serveHTTP(httpListener, certManager.HTTPHandler(nil)) - log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String()) - } - } - - if signalPort != 443 || !tlsEnabled { - grpcListener, err = serveGRPC(grpcServer, signalPort) - if err != nil { - return err - } - log.Infof("running gRPC server: %s", grpcListener.Addr().String()) - } - log.Infof("signal server version %s", version.NetbirdVersion()) log.Infof("started Signal Service") @@ -232,6 +253,25 @@ func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) { return listener, nil } +func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) { + // Load server's certificate and private key + serverCert, err := tls.LoadX509KeyPair(certFile, certKey) + if err != nil { + return nil, err + } + + // Create the credentials and return it + config := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.NoClientCert, + NextProtos: []string{ + "h2", "http/1.1", // enable HTTP/2 + }, + } + + return config, nil +} + func cpFile(src, dst string) error { var err error var srcfd *os.File @@ -323,4 +363,6 @@ func init() { runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. 
*Required only for Let's Encrypt certificates.") runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") + runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") + runCmd.Flags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") } From 95d725f2c19d49a5993cac8b81a6cc11b21dd573 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 17 Jul 2024 16:26:06 +0200 Subject: [PATCH 02/53] Wait on daemon down (#2279) --- client/cmd/down.go | 2 +- client/internal/engine.go | 36 ++++++++++- client/internal/networkmonitor/monitor_bsd.go | 16 ++++- client/server/server.go | 22 ++++++- management/client/grpc.go | 64 ++++++++++--------- signal/client/grpc.go | 35 ++++------ util/grpc/dialer.go | 43 +++++++++++++ 7 files changed, 159 insertions(+), 59 deletions(-) diff --git a/client/cmd/down.go b/client/cmd/down.go index 1837b13da..4d9f1eba4 100644 --- a/client/cmd/down.go +++ b/client/cmd/down.go @@ -26,7 +26,7 @@ var downCmd = &cobra.Command{ return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) defer cancel() conn, err := DialClientGRPCServer(ctx, daemonAddr) diff --git a/client/internal/engine.go b/client/internal/engine.go index 21a765a96..9e275c007 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -266,8 
+266,23 @@ func (e *Engine) Stop() error { e.close() e.wgConnWorker.Wait() - log.Infof("stopped Netbird Engine") - return nil + + maxWaitTime := 5 * time.Second + timeout := time.After(maxWaitTime) + + for { + if !e.IsWGIfaceUp() { + log.Infof("stopped Netbird Engine") + return nil + } + + select { + case <-timeout: + return fmt.Errorf("timeout when waiting for interface shutdown") + default: + time.Sleep(100 * time.Millisecond) + } + } } // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services @@ -1533,3 +1548,20 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { return slices.Equal(checks.Files, oChecks.Files) }) } + +func (e *Engine) IsWGIfaceUp() bool { + if e == nil || e.wgInterface == nil { + return false + } + iface, err := net.InterfaceByName(e.wgInterface.Name()) + if err != nil { + log.Debugf("failed to get interface by name %s: %v", e.wgInterface.Name(), err) + return false + } + + if iface.Flags&net.FlagUp != 0 { + return true + } + + return false +} diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/monitor_bsd.go index 8d6ccd51b..29df7ea7f 100644 --- a/client/internal/networkmonitor/monitor_bsd.go +++ b/client/internal/networkmonitor/monitor_bsd.go @@ -4,6 +4,7 @@ package networkmonitor import ( "context" + "errors" "fmt" "syscall" "unsafe" @@ -21,11 +22,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca return fmt.Errorf("failed to open routing socket: %v", err) } defer func() { - if err := unix.Close(fd); err != nil { + err := unix.Close(fd) + if err != nil && !errors.Is(err, unix.EBADF) { log.Errorf("Network monitor: failed to close routing socket: %v", err) } }() + go func() { + <-ctx.Done() + err := unix.Close(fd) + if err != nil && !errors.Is(err, unix.EBADF) { + log.Debugf("Network monitor: closed routing socket") + } + }() + for { select { case <-ctx.Done(): @@ -34,7 +44,9 @@ func 
checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca buf := make([]byte, 2048) n, err := unix.Read(fd, buf) if err != nil { - log.Errorf("Network monitor: failed to read from routing socket: %v", err) + if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { + log.Errorf("Network monitor: failed to read from routing socket: %v", err) + } continue } if n < unix.SizeofRtMsghdr { diff --git a/client/server/server.go b/client/server/server.go index 2805c10f4..8173d0741 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -582,7 +582,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes } // Down engine work in the daemon. -func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) { +func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -593,7 +593,25 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo state := internal.CtxGetState(s.rootCtx) state.Set(internal.StatusIdle) - return &proto.DownResponse{}, nil + maxWaitTime := 5 * time.Second + timeout := time.After(maxWaitTime) + + engine := s.connectClient.Engine() + + for { + if !engine.IsWGIfaceUp() { + return &proto.DownResponse{}, nil + } + + select { + case <-ctx.Done(): + return &proto.DownResponse{}, nil + case <-timeout: + return nil, fmt.Errorf("failed to shut down properly") + default: + time.Sleep(100 * time.Millisecond) + } + } } // Status returns the daemon status diff --git a/management/client/grpc.go b/management/client/grpc.go index a8f4a91c7..568c15313 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -2,7 +2,6 @@ package client import ( "context" - "crypto/tls" "fmt" "io" "sync" @@ -11,15 +10,11 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" 
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/keepalive" - - "github.com/cenkalti/backoff/v4" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" @@ -51,26 +46,21 @@ type GrpcClient struct { // NewClient creates a new client to Management service func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { - transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) + var conn *grpc.ClientConn - if tlsEnabled { - transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) + operation := func() error { + var err error + conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + if err != nil { + log.Printf("createConnection error: %v", err) + return err + } + return nil } - mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) - defer cancel() - conn, err := grpc.DialContext( - mgmCtx, - addr, - transportOption, - nbgrpc.WithCustomDialer(), - grpc.WithBlock(), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 30 * time.Second, - Timeout: 10 * time.Second, - })) + err := backoff.Retry(operation, nbgrpc.Backoff(ctx)) if err != nil { - log.Errorf("failed creating connection to Management Service %v", err) + log.Errorf("failed creating connection to Management Service: %v", err) return nil, err } @@ -326,25 +316,41 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro if !c.ready() { return nil, fmt.Errorf(errMsgNoMgmtConnection) } + loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) if err != nil { log.Errorf("failed to encrypt message: %s", err) return nil, err } - mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout) - defer cancel() - resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ - 
WgPubKey: c.key.PublicKey().String(), - Body: loginReq, - }) + + var resp *proto.EncryptedMessage + operation := func() error { + mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout) + defer cancel() + + var err error + resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ + WgPubKey: c.key.PublicKey().String(), + Body: loginReq, + }) + if err != nil { + log.Printf("Login error: %v", err) + return err + } + + return nil + } + + err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx)) if err != nil { + log.Errorf("failed to login to Management Service: %v", err) return nil, err } loginResp := &proto.LoginResponse{} err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) if err != nil { - log.Errorf("failed to decrypt registration message: %s", err) + log.Errorf("failed to decrypt login response: %s", err) return nil, err } diff --git a/signal/client/grpc.go b/signal/client/grpc.go index c6f03ec86..7a3b502ff 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -2,7 +2,6 @@ package client import ( "context" - "crypto/tls" "fmt" "io" "sync" @@ -14,9 +13,6 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -64,28 +60,21 @@ func (c *GrpcClient) Close() error { // NewClient creates a new Signal client func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { + var conn *grpc.ClientConn - transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) - - if tlsEnabled { - transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) + operation := func() error { + var err error + conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + if err != nil { + log.Printf("createConnection error: 
%v", err) + return err + } + return nil } - sigCtx, cancel := context.WithTimeout(ctx, client.ConnectTimeout) - defer cancel() - conn, err := grpc.DialContext( - sigCtx, - addr, - transportOption, - nbgrpc.WithCustomDialer(), - grpc.WithBlock(), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 30 * time.Second, - Timeout: 10 * time.Second, - })) - + err := backoff.Retry(operation, nbgrpc.Backoff(ctx)) if err != nil { - log.Errorf("failed to connect to the signalling server %v", err) + log.Errorf("failed to connect to the signalling server: %v", err) return nil, err } @@ -408,7 +397,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient, if err != nil { log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error()) - //todo send something?? + // todo send something?? } } } diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go index 3fba0c84e..57ab8fd55 100644 --- a/util/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -2,12 +2,18 @@ package grpc import ( "context" + "crypto/tls" "net" "os/user" "runtime" + "time" + "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -35,3 +41,40 @@ func WithCustomDialer() grpc.DialOption { return conn, nil }) } + +// Backoff is the backoff mechanism for the grpc calls +func Backoff(ctx context.Context) backoff.BackOff { + b := backoff.NewExponentialBackOff() + b.MaxElapsedTime = 10 * time.Second + b.Clock = backoff.SystemClock + return backoff.WithContext(b, ctx) +} + +func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { + transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) + + if tlsEnabled { + transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) + } + + connCtx, 
cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + conn, err := grpc.DialContext( + connCtx, + addr, + transportOption, + WithCustomDialer(), + grpc.WithBlock(), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 30 * time.Second, + Timeout: 10 * time.Second, + }), + ) + if err != nil { + log.Printf("DialContext error: %v", err) + return nil, err + } + + return conn, nil +} From e78ec2e985111dc3b1470c199f2093804a317abb Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 17 Jul 2024 19:50:06 +0200 Subject: [PATCH 03/53] Don't add exclusion routes for IPs that are part of connected networks (#2258) This prevents arp/ndp issues on macOS leading to unreachability of that IP. --- .../systemops/systemops_generic.go | 37 ++++++++- .../systemops/systemops_windows_test.go | 75 ++++++------------- 2 files changed, 59 insertions(+), 53 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 0d1c16ca1..615e1b528 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -50,7 +50,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop) if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) { log.Tracef("Adding for prefix %s: %v", prefix, err) - // These errors are not critical but also we should not track and try to remove the routes either. + // These errors are not critical, but also we should not track and try to remove the routes either. 
return nexthop, refcounter.ErrIgnore } return nexthop, err @@ -135,6 +135,11 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac return Nexthop{}, vars.ErrRouteNotAllowed } + // Check if the prefix is part of any local subnets + if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal { + return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed) + } + // Determine the exit interface and next hop for the prefix, so we can add a specific route nexthop, err := GetNextHop(addr) if err != nil { @@ -167,6 +172,36 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac return exitNextHop, nil } +func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) { + localInterfaces, err := net.Interfaces() + if err != nil { + log.Errorf("Failed to get local interfaces: %v", err) + return false, nil + } + + for _, intf := range localInterfaces { + addrs, err := intf.Addrs() + if err != nil { + log.Errorf("Failed to get addresses for interface %s: %v", intf.Name, err) + continue + } + + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok { + log.Errorf("Failed to convert address to IPNet: %v", addr) + continue + } + + if ipnet.Contains(prefix.Addr().AsSlice()) { + return true, ipnet + } + } + } + + return false, nil +} + // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix // in two /1 prefixes to avoid replacing the existing default route func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { diff --git a/client/internal/routemanager/systemops/systemops_windows_test.go b/client/internal/routemanager/systemops/systemops_windows_test.go index 9180ed58c..19b006017 100644 --- a/client/internal/routemanager/systemops/systemops_windows_test.go +++ b/client/internal/routemanager/systemops/systemops_windows_test.go @@ -73,7 +73,7 @@ var testCases = []testCase{ { 
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence destination: "10.0.0.2:53", - expectedSourceIP: "10.0.0.1", + expectedSourceIP: "127.0.0.1", expectedDestPrefix: "10.0.0.0/8", expectedNextHop: "0.0.0.0", expectedInterface: "Loopback Pseudo-Interface 1", @@ -110,7 +110,7 @@ var testCases = []testCase{ { name: "To more specific route (local) without custom dialer via physical interface", destination: "127.0.10.2:53", - expectedSourceIP: "10.0.0.1", + expectedSourceIP: "127.0.0.1", expectedDestPrefix: "127.0.0.0/8", expectedNextHop: "0.0.0.0", expectedInterface: "Loopback Pseudo-Interface 1", @@ -181,31 +181,6 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut return combinedOutput } -func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { - t.Helper() - - ip, ipNet, err := net.ParseCIDR(ipAddressCIDR) - require.NoError(t, err) - subnetMaskSize, _ := ipNet.Mask.Size() - script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) - _, err = exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to assign IP address to loopback adapter") - - // Wait for the IP address to be applied - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - err = waitForIPAddress(ctx, interfaceName, ip.String()) - require.NoError(t, err, "IP address not applied within timeout") - - t.Cleanup(func() { - script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String()) - _, err = exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to remove IP address from loopback adapter") - }) - - return interfaceName -} - func fetchOriginalGateway() (*RouteInfo, error) { 
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json") output, err := cmd.CombinedOutput() @@ -231,30 +206,6 @@ func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") } -func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error { - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput() - if err != nil { - return err - } - - ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n") - for _, ip := range ipAddresses { - if strings.TrimSpace(ip) == expectedIPAddress { - return nil - } - } - } - } -} - func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { var combined FindNetRouteOutput @@ -285,5 +236,25 @@ func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { func setupDummyInterfacesAndRoutes(t *testing.T) { t.Helper() - createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") + addDummyRoute(t, "10.0.0.0/8") +} + +func addDummyRoute(t *testing.T, dstCIDR string) { + t.Helper() + + script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR) + + output, err := exec.Command("powershell", "-Command", script).CombinedOutput() + if err != nil { + t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output) + t.FailNow() + } + + t.Cleanup(func() { + script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR) + output, err := exec.Command("powershell", "-Command", script).CombinedOutput() + if err != nil { + 
t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output) + } + }) } From 19147f518ead57a89ac4e544ccbded54521d19dc Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 17 Jul 2024 23:48:37 +0200 Subject: [PATCH 04/53] Add faster availability DNS probe and update test domain to .com (#2280) * Add faster availability DNS probe and update test domain to .com - Count success queries and compare it before doing after network map probes. - Reduce the first dns probe to 500ms - Updated test domain with com instead of . due to Palo alto DNS proxy server issues * use fqdn * Update client/internal/dns/upstream.go Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com> --------- Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com> --- client/internal/dns/upstream.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index b502bf5eb..b3baf2fa8 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -24,7 +24,7 @@ const ( probeTimeout = 2 * time.Second ) -const testRecord = "." +const testRecord = "com." 
type upstreamClient interface { exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) @@ -42,6 +42,7 @@ type upstreamResolverBase struct { upstreamServers []string disabled bool failsCount atomic.Int32 + successCount atomic.Int32 failsTillDeact int32 mutex sync.Mutex reactivatePeriod time.Duration @@ -124,6 +125,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } + u.successCount.Add(1) log.Tracef("took %s to query the upstream %s", t, upstream) err = w.WriteMsg(rm) @@ -172,6 +174,11 @@ func (u *upstreamResolverBase) probeAvailability() { default: } + // avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact + if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact { + return + } + var success bool var mu sync.Mutex var wg sync.WaitGroup @@ -183,7 +190,7 @@ func (u *upstreamResolverBase) probeAvailability() { wg.Add(1) go func() { defer wg.Done() - err := u.testNameserver(upstream) + err := u.testNameserver(upstream, 500*time.Millisecond) if err != nil { errors = multierror.Append(errors, err) log.Warnf("probing upstream nameserver %s: %s", upstream, err) @@ -224,7 +231,7 @@ func (u *upstreamResolverBase) waitUntilResponse() { } for _, upstream := range u.upstreamServers { - if err := u.testNameserver(upstream); err != nil { + if err := u.testNameserver(upstream, probeTimeout); err != nil { log.Tracef("upstream check for %s: %s", upstream, err) } else { // at least one upstream server is available, stop probing @@ -244,6 +251,7 @@ func (u *upstreamResolverBase) waitUntilResponse() { log.Infof("upstreams %s are responsive again. 
Adding them back to system", u.upstreamServers) u.failsCount.Store(0) + u.successCount.Add(1) u.reactivate() u.disabled = false } @@ -265,13 +273,14 @@ func (u *upstreamResolverBase) disable(err error) { } log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod) + u.successCount.Store(0) u.deactivate(err) u.disabled = true go u.waitUntilResponse() } -func (u *upstreamResolverBase) testNameserver(server string) error { - ctx, cancel := context.WithTimeout(u.ctx, probeTimeout) +func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(u.ctx, timeout) defer cancel() r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) From 9a6de52dd09fa2edc9ef54cc48ea3acc84f56033 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 17 Jul 2024 23:49:09 +0200 Subject: [PATCH 05/53] Check if route interface is a Microsoft ISATAP device (#2282) check if the nexthop interfaces are Microsoft ISATAP devices and ignore their suffixes when comparing them --- .../internal/networkmonitor/monitor_windows.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/client/internal/networkmonitor/monitor_windows.go b/client/internal/networkmonitor/monitor_windows.go index e24bdd066..f58802e4e 100644 --- a/client/internal/networkmonitor/monitor_windows.go +++ b/client/internal/networkmonitor/monitor_windows.go @@ -232,14 +232,20 @@ func stateFromInt(state uint8) string { } func compareIntf(a, b *net.Interface) int { - if a == nil && b == nil { + switch { + case a == nil && b == nil: return 0 - } - if a == nil { + case a == nil: return -1 - } - if b == nil { + case b == nil: return 1 + case isIsatapInterface(a.Name) && isIsatapInterface(b.Name): + return 0 + default: + return a.Index - b.Index } - return a.Index - b.Index +} + +func isIsatapInterface(name string) bool { + return strings.HasPrefix(strings.ToLower(name), "isatap") } From c900fa81bbcaca684e54ea346c6a34b0ef6095b3 Mon Sep 
17 00:00:00 2001 From: Maycon Santos Date: Thu, 18 Jul 2024 12:15:14 +0200 Subject: [PATCH 06/53] Remove copy functions from signal (#2285) remove migration function for wiretrustee directories to netbird --- signal/cmd/run.go | 207 +++++++++++++++------------------------------- 1 file changed, 67 insertions(+), 140 deletions(-) diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 8f75c1e04..cfc140acb 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -6,12 +6,8 @@ import ( "errors" "flag" "fmt" - "io" - "io/fs" "net" "net/http" - "os" - "path" "strings" "time" @@ -58,9 +54,15 @@ var ( }) runCmd = &cobra.Command{ - Use: "run", - Short: "start NetBird Signal Server daemon", + Use: "run", + Short: "start NetBird Signal Server daemon", + SilenceUsage: true, PreRun: func(cmd *cobra.Command, args []string) { + err := util.InitLog(logLevel, logFile) + if err != nil { + log.Fatalf("failed initializing log %v", err) + } + flag.Parse() // detect whether user specified a port @@ -83,40 +85,9 @@ var ( RunE: func(cmd *cobra.Command, args []string) error { flag.Parse() - err := util.InitLog(logLevel, logFile) + opts, certManager, err := getTLSConfigurations() if err != nil { - log.Fatalf("failed initializing log %v", err) - } - - if signalSSLDir == "" { - oldPath := "/var/lib/wiretrustee" - if migrateToNetbird(oldPath, defaultSignalSSLDir) { - if err := cpDir(oldPath, defaultSignalSSLDir); err != nil { - log.Fatal(err) - } - } - } - - var opts []grpc.ServerOption - var certManager *autocert.Manager - var tlsConfig *tls.Config - if signalLetsencryptDomain != "" { - certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) - if err != nil { - return err - } - transportCredentials := credentials.NewTLS(certManager.TLSConfig()) - opts = append(opts, grpc.Creds(transportCredentials)) - log.Infof("setting up TLS with LetsEncrypt.") - } else if signalCertFile != "" && signalCertKey != "" { - tlsConfig, err = loadTLSConfig(signalCertFile, 
signalCertKey) - if err != nil { - log.Errorf("cannot load TLS credentials: %v", err) - return err - } - transportCredentials := credentials.NewTLS(tlsConfig) - opts = append(opts, grpc.Creds(transportCredentials)) - log.Infof("setting up TLS with custom certificates.") + return err } metricsServer := metrics.NewServer(metricsPort, "") @@ -141,24 +112,15 @@ var ( proto.RegisterSignalExchangeServer(grpcServer, srv) grpcRootHandler := grpcHandlerFunc(grpcServer) + + if certManager != nil { + startServerWithCertManager(certManager, grpcRootHandler) + } + var compatListener net.Listener var grpcListener net.Listener var httpListener net.Listener - if certManager != nil { - // a call to certManager.Listener() always creates a new listener so we do it once - httpListener := certManager.Listener() - if signalPort == 443 { - // running gRPC and HTTP cert manager on the same port - serveHTTP(httpListener, certManager.HTTPHandler(grpcRootHandler)) - log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String()) - } else { - // Start the HTTP cert manager server separately - serveHTTP(httpListener, certManager.HTTPHandler(nil)) - log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String()) - } - } - // If certManager is configured and signalPort == 443, then the gRPC server has already been started if certManager == nil || signalPort != 443 { grpcListener, err = serveGRPC(grpcServer, signalPort) @@ -211,6 +173,58 @@ var ( } ) +func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { + var ( + err error + certManager *autocert.Manager + tlsConfig *tls.Config + ) + + if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { + log.Infof("running without TLS") + return nil, nil, nil + } + + if signalLetsencryptDomain != "" { + certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) + if err != nil { 
+ return nil, certManager, err + } + tlsConfig = certManager.TLSConfig() + log.Infof("setting up TLS with LetsEncrypt.") + } else { + if signalCertFile == "" || signalCertKey == "" { + log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") + return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") + } + + tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) + if err != nil { + log.Errorf("cannot load TLS credentials: %v", err) + return nil, certManager, err + } + log.Infof("setting up TLS with custom certificates.") + } + + transportCredentials := credentials.NewTLS(tlsConfig) + + return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err +} + +func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) { + // a call to certManager.Listener() always creates a new listener so we do it once + httpListener := certManager.Listener() + if signalPort == 443 { + // running gRPC and HTTP cert manager on the same port + serveHTTP(httpListener, certManager.HTTPHandler(grpcRootHandler)) + log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String()) + } else { + // Start the HTTP cert manager server separately + serveHTTP(httpListener, certManager.HTTPHandler(nil)) + log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String()) + } +} + func grpcHandlerFunc(grpcServer *grpc.Server) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { grpcHeader := strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") || @@ -272,93 +286,6 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) { return config, nil } -func cpFile(src, dst string) error { - var err error - var srcfd *os.File - var dstfd *os.File - var srcinfo os.FileInfo - - if srcfd, err = os.Open(src); err != 
nil { - return err - } - defer srcfd.Close() - - if dstfd, err = os.Create(dst); err != nil { - return err - } - defer dstfd.Close() - - if _, err = io.Copy(dstfd, srcfd); err != nil { - return err - } - if srcinfo, err = os.Stat(src); err != nil { - return err - } - return os.Chmod(dst, srcinfo.Mode()) -} - -func copySymLink(source, dest string) error { - link, err := os.Readlink(source) - if err != nil { - return err - } - return os.Symlink(link, dest) -} - -func cpDir(src string, dst string) error { - var err error - var fds []os.DirEntry - var srcinfo os.FileInfo - - if srcinfo, err = os.Stat(src); err != nil { - return err - } - - if err = os.MkdirAll(dst, srcinfo.Mode()); err != nil { - return err - } - - if fds, err = os.ReadDir(src); err != nil { - return err - } - for _, fd := range fds { - srcfp := path.Join(src, fd.Name()) - dstfp := path.Join(dst, fd.Name()) - - fileInfo, err := os.Stat(srcfp) - if err != nil { - log.Fatalf("Couldn't get fileInfo; %v", err) - } - - switch fileInfo.Mode() & os.ModeType { - case os.ModeSymlink: - if err = copySymLink(srcfp, dstfp); err != nil { - log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err) - } - case os.ModeDir: - if err = cpDir(srcfp, dstfp); err != nil { - log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err) - } - default: - if err = cpFile(srcfp, dstfp); err != nil { - log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err) - } - } - } - return nil -} - -func migrateToNetbird(oldPath, newPath string) bool { - _, errOld := os.Stat(oldPath) - _, errNew := os.Stat(newPath) - - if errors.Is(errOld, fs.ErrNotExist) || errNew == nil { - return false - } - - return true -} - func init() { runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. 
*Required only for Let's Encrypt certificates.") From ef1a39cb01d1231d76e503870360479f7addb957 Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Thu, 18 Jul 2024 08:39:41 -0600 Subject: [PATCH 07/53] Refactor macOS system DNS configuration (#2284) On macOS use the recommended settings for providing split DNS. As per the docs an empty string will force the configuration to be the default. In order to support split DNS an additional service config is added for the local server and search domain settings. see: https://developer.apple.com/documentation/devicemanagement/vpn/dns --- client/internal/dns/host.go | 6 ++ client/internal/dns/host_darwin.go | 161 +++++++++++++++++++---------- 2 files changed, 115 insertions(+), 52 deletions(-) diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index 3070763a6..e55a07055 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -15,6 +15,12 @@ type hostManager interface { restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error } +type SystemDNSSettings struct { + Domains []string + ServerIP string + ServerPort int +} + type HostDNSConfig struct { Domains []DomainConfig `json:"domains"` RouteAll bool `json:"routeAll"` diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 5ae84fb91..70d2443ef 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -7,6 +7,7 @@ import ( "bytes" "fmt" "io" + "net" "net/netip" "os/exec" "strconv" @@ -18,7 +19,7 @@ import ( const ( netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS" globalIPv4State = "State:/Network/Global/IPv4" - primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS" + primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS" keySupplementalMatchDomains = "SupplementalMatchDomains" keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch" keyServerAddresses = "ServerAddresses" @@ -28,12 +29,12 @@ const ( scutilPath =
"/usr/sbin/scutil" searchSuffix = "Search" matchSuffix = "Match" + localSuffix = "Local" ) type systemConfigurator struct { - // primaryServiceID primary interface in the system. AKA the interface with the default route - primaryServiceID string - createdKeys map[string]struct{} + createdKeys map[string]struct{} + systemDNSSettings SystemDNSSettings } func newHostManager() (hostManager, error) { @@ -49,20 +50,6 @@ func (s *systemConfigurator) supportCustomPort() bool { func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { var err error - if config.RouteAll { - err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort) - if err != nil { - return fmt.Errorf("add dns setup for all: %w", err) - } - } else if s.primaryServiceID != "" { - err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID)) - if err != nil { - return fmt.Errorf("remote key from system config: %w", err) - } - s.primaryServiceID = "" - log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort) - } - // create a file for unclean shutdown detection if err := createUncleanShutdownIndicator(); err != nil { log.Errorf("failed to create unclean shutdown file: %s", err) @@ -73,6 +60,19 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { matchDomains []string ) + err = s.recordSystemDNSSettings(true) + if err != nil { + log.Errorf("unable to update record of System's DNS config: %s", err.Error()) + } + + if config.RouteAll { + searchDomains = append(searchDomains, "\"\"") + err = s.addLocalDNS() + if err != nil { + log.Infof("failed to enable split DNS") + } + } + for _, dConf := range config.Domains { if dConf.Disabled { continue @@ -119,10 +119,6 @@ func (s *systemConfigurator) restoreHostDNS() error { } log.Infof("removing %s domains from system", keyType) } - if s.primaryServiceID != "" { - lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, 
s.primaryServiceID)) - log.Infof("restoring DNS resolver configuration for system") - } _, err := runSystemConfigCommand(wrapCommand(lines)) if err != nil { log.Errorf("got an error while cleaning the system configuration: %s", err) @@ -148,6 +144,97 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { return nil } +func (s *systemConfigurator) addLocalDNS() error { + if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 { + err := s.recordSystemDNSSettings(true) + log.Errorf("Unable to get system DNS configuration") + return err + } + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 { + err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) + if err != nil { + return fmt.Errorf("couldn't add local network DNS conf: %w", err) + } + } else { + log.Info("Not enabling local DNS server") + } + + return nil +} + +func (s *systemConfigurator) recordSystemDNSSettings(force bool) error { + if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force { + return nil + } + + systemDNSSettings, err := s.getSystemDNSSettings() + if err != nil { + return fmt.Errorf("couldn't get current DNS config: %w", err) + } + s.systemDNSSettings = systemDNSSettings + + return nil +} + +func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { + primaryServiceKey, _, err := s.getPrimaryService() + if err != nil || primaryServiceKey == "" { + return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err) + } + dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey) + line := buildCommandLine("show", dnsServiceKey, "") + stdinCommands := wrapCommand(line) + + b, err := runSystemConfigCommand(stdinCommands) + if err != nil { + return SystemDNSSettings{}, 
fmt.Errorf("sending the command: %w", err) + } + + var dnsSettings SystemDNSSettings + inSearchDomainsArray := false + inServerAddressesArray := false + + scanner := bufio.NewScanner(bytes.NewReader(b)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + switch { + case strings.HasPrefix(line, "DomainName :"): + domainName := strings.TrimSpace(strings.Split(line, ":")[1]) + dnsSettings.Domains = append(dnsSettings.Domains, domainName) + case line == "SearchDomains : {": + inSearchDomainsArray = true + continue + case line == "ServerAddresses : {": + inServerAddressesArray = true + continue + case line == "}": + inSearchDomainsArray = false + inServerAddressesArray = false + } + + if inSearchDomainsArray { + searchDomains := strings.Split(line, " : ") + dnsSettings.Domains = append(dnsSettings.Domains, searchDomains...) + } else if inServerAddressesArray { + address := strings.Split(line, " : ")[1] + if ip := net.ParseIP(address); ip != nil && ip.To4() != nil { + dnsSettings.ServerIP = address + inServerAddressesArray = false // Stop reading after finding the first IPv4 address + } + } + } + + if err := scanner.Err(); err != nil { + return dnsSettings, err + } + + // default to 53 port + dnsSettings.ServerPort = 53 + + return dnsSettings, nil +} + func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error { err := s.addDNSState(key, domains, ip, port, true) if err != nil { @@ -194,23 +281,6 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port return nil } -func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error { - primaryServiceKey, existingNameserver, err := s.getPrimaryService() - if err != nil || primaryServiceKey == "" { - return fmt.Errorf("couldn't find the primary service key: %w", err) - } - - err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver) - if err != nil { - return 
fmt.Errorf("add dns setup: %w", err) - } - - log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port) - s.primaryServiceID = primaryServiceKey - - return nil -} - func (s *systemConfigurator) getPrimaryService() (string, string, error) { line := buildCommandLine("show", globalIPv4State, "") stdinCommands := wrapCommand(line) @@ -239,19 +309,6 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) { return primaryService, router, nil } -func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error { - lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0)) - lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer) - lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port)) - addDomainCommand := buildCreateStateWithOperation(setupKey, lines) - stdinCommands := wrapCommand(addDomainCommand) - _, err := runSystemConfigCommand(stdinCommands) - if err != nil { - return fmt.Errorf("applying dns setup, error: %w", err) - } - return nil -} - func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { if err := s.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns via scutil: %w", err) From c815ad86fd7ccf2804995f8c956540f76c188535 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 18 Jul 2024 18:06:09 +0200 Subject: [PATCH 08/53] Fix macOS DNS unclean shutdown restore call on startup (#2286) previously, we called the restore method from the startup when there was an unclean shutdown. 
But it never had the state keys to clean since they are stored in memory. This change addresses the issue by falling back to default values when restoring the host's DNS --- client/internal/dns/host_darwin.go | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 70d2443ef..b5f6f9295 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -110,19 +110,17 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { } func (s *systemConfigurator) restoreHostDNS() error { - lines := "" - for key := range s.createdKeys { - lines += buildRemoveKeyOperation(key) + keys := s.getRemovableKeysWithDefaults() + for _, key := range keys { keyType := "search" if strings.Contains(key, matchSuffix) { keyType = "match" } log.Infof("removing %s domains from system", keyType) - } - _, err := runSystemConfigCommand(wrapCommand(lines)) - if err != nil { - log.Errorf("got an error while cleaning the system configuration: %s", err) - return fmt.Errorf("clean system: %w", err) + err := s.removeKeyFromSystemConfig(key) + if err != nil { + log.Errorf("failed to remove %s domains from system: %s", keyType, err) + } } if err := removeUncleanShutdownIndicator(); err != nil { @@ -132,6 +130,19 @@ func (s *systemConfigurator) restoreHostDNS() error { return nil } +func (s *systemConfigurator) getRemovableKeysWithDefaults() []string { + if len(s.createdKeys) == 0 { + // return defaults for startup calls + return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)} + } + + keys := make([]string, 0, len(s.createdKeys)) + for key := range s.createdKeys { + keys = append(keys, key) + } + return keys +} + func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { line := buildRemoveKeyOperation(key) _, err := runSystemConfigCommand(wrapCommand(line))
From 0a8c78deb142bf69544bc3206b7721f23f6cd69c Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Fri, 19 Jul 2024 08:44:12 -0600 Subject: [PATCH 09/53] Minor fix local dns search domain (#2287) --- client/internal/dns/host_darwin.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b5f6f9295..5dee305c2 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -225,8 +225,8 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { } if inSearchDomainsArray { - searchDomains := strings.Split(line, " : ") - dnsSettings.Domains = append(dnsSettings.Domains, searchDomains...) + searchDomain := strings.Split(line, " : ")[1] + dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) } else if inServerAddressesArray { address := strings.Split(line, " : ")[1] if ip := net.ParseIP(address); ip != nil && ip.To4() != nil { From 926e11b086b9b2753d0b28dee888af17c3026638 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 22 Jul 2024 15:35:17 +0200 Subject: [PATCH 10/53] Remove default allow for UDP on unmatched packet (#2300) This fixes an issue where UDP rules were ineffective for userspace clients (Windows/macOS) --- client/firewall/uspfilter/uspfilter.go | 1 - 1 file changed, 1 deletion(-) diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 427a73825..75792e9c0 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -337,7 +337,6 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) { return rule.drop, true } - return rule.drop, true case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: return rule.drop, true } From 788f1309415fbd819656f1355f11f1a8dd53919c Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 22 Jul 2024 15:49:25 
+0200 Subject: [PATCH 11/53] Retry management connection only on context canceled (#2301) --- management/client/grpc.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/management/client/grpc.go b/management/client/grpc.go index 568c15313..eaadcd317 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -334,8 +334,11 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro Body: loginReq, }) if err != nil { - log.Printf("Login error: %v", err) - return err + // retry only on context canceled + if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled { + return err + } + return backoff.Permanent(err) } return nil From 268e801ec565e1dd3d4002f725bd5c6b5fc765d6 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 22 Jul 2024 19:44:15 +0200 Subject: [PATCH 12/53] Ignore network monitor checks for software interfaces (#2302) ignore checks for Teredo and ISATAP interfaces --- client/internal/networkmonitor/monitor_windows.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/client/internal/networkmonitor/monitor_windows.go b/client/internal/networkmonitor/monitor_windows.go index f58802e4e..308b2aa45 100644 --- a/client/internal/networkmonitor/monitor_windows.go +++ b/client/internal/networkmonitor/monitor_windows.go @@ -99,6 +99,11 @@ func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []syste return false } + if isSoftInterface(nexthop.Intf.Name) { + log.Tracef("network monitor: ignoring default route change for soft interface %s", nexthop.Intf.Name) + return false + } + unspec := getUnspecifiedPrefix(nexthop.IP) defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec) @@ -119,7 +124,7 @@ func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix { return netip.PrefixFrom(netip.IPv4Unspecified(), 0) } -func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) 
([]string, bool) { +func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) { var defaultRoutes []string foundMatchingRoute := false @@ -128,7 +133,7 @@ func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []syst routeInfo := formatRouteInfo(r) defaultRoutes = append(defaultRoutes, routeInfo) - if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 { + if r.Nexthop == nexthop.IP && compareIntf(r.Interface, nexthopIntf) == 0 { foundMatchingRoute = true log.Debugf("network monitor: found matching default route: %s", routeInfo) } @@ -239,13 +244,11 @@ func compareIntf(a, b *net.Interface) int { return -1 case b == nil: return 1 - case isIsatapInterface(a.Name) && isIsatapInterface(b.Name): - return 0 default: return a.Index - b.Index } } -func isIsatapInterface(name string) bool { - return strings.HasPrefix(strings.ToLower(name), "isatap") +func isSoftInterface(name string) bool { + return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo") } From 63aeeb834dc9a7834f9aab2ff29bba11b060b9c7 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 24 Jul 2024 13:27:01 +0200 Subject: [PATCH 13/53] Fix error handling (#2316) --- signal/cmd/run.go | 2 +- signal/metrics/metrics.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/signal/cmd/run.go b/signal/cmd/run.go index cfc140acb..61f7a32a7 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -90,7 +90,7 @@ var ( return err } - metricsServer := metrics.NewServer(metricsPort, "") + metricsServer, err := metrics.NewServer(metricsPort, "") if err != nil { return fmt.Errorf("setup metrics: %v", err) } diff --git a/signal/metrics/metrics.go b/signal/metrics/metrics.go index 30db1600a..f411501cb 100644 --- a/signal/metrics/metrics.go +++ b/signal/metrics/metrics.go @@ -26,10 +26,10 @@ type Metrics struct { } // NewServer initializes and 
returns a new Metrics instance -func NewServer(port int, endpoint string) *Metrics { +func NewServer(port int, endpoint string) (*Metrics, error) { exporter, err := prometheus.New() if err != nil { - return nil + return nil, err } provider := metric.NewMeterProvider(metric.WithReader(exporter)) @@ -57,7 +57,7 @@ func NewServer(port int, endpoint string) *Metrics { provider: provider, Endpoint: endpoint, Server: server, - } + }, nil } // Shutdown stops the metrics server From 45fd1e9c21dc520bf66c83bb4d52ba4ccbe656c9 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 25 Jul 2024 16:22:04 +0200 Subject: [PATCH 14/53] add save peer status test for connected peers (#2321) --- management/server/sql_store_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index f46ca7e5d..7f48810d7 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -402,6 +402,17 @@ func TestSqlite_SavePeerStatus(t *testing.T) { actual := account.Peers["testpeer"].Status assert.Equal(t, newStatus, *actual) + + newStatus.Connected = true + + err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + require.NoError(t, err) + + account, err = store.GetAccount(context.Background(), account.Id) + require.NoError(t, err) + + actual = account.Peers["testpeer"].Status + assert.Equal(t, newStatus, *actual) } func TestSqlite_SavePeerLocation(t *testing.T) { if runtime.GOOS == "windows" { From 1f48fdf6cadca8576bca578dca829e9672cb98ce Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 26 Jul 2024 07:49:05 +0200 Subject: [PATCH 15/53] Add SavePeer method to prevent a possible account inconsistency (#2296) SyncPeer was storing the account with a simple read lock This change introduces the SavePeer method to the store to be used in these cases --- management/server/file_store.go | 20 ++++++++++++ management/server/peer.go | 5 +-- management/server/sql_store.go | 48 
+++++++++++++++++++++++----- management/server/sql_store_test.go | 49 +++++++++++++++++++++++++++++ management/server/store.go | 4 ++- 5 files changed, 116 insertions(+), 10 deletions(-) diff --git a/management/server/file_store.go b/management/server/file_store.go index c649602e2..9a1462832 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -666,6 +666,26 @@ func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error { return s.persist(ctx, s.storeFile) } +// SavePeer saves the peer in the account +func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return err + } + + newPeer := peer.Copy() + + account.Peers[peer.ID] = newPeer + + s.PeerKeyID2AccountID[peer.Key] = accountID + s.PeerID2AccountID[peer.ID] = accountID + + return nil +} + // SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. // PeerStatus will be saved eventually when some other changes occur. 
func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { diff --git a/management/server/peer.go b/management/server/peer.go index b8605fbb7..ec8b773b0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -7,10 +7,11 @@ import ( "strings" "time" - "github.com/netbirdio/netbird/management/server/posture" "github.com/rs/xid" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -539,7 +540,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac peer, updated := updatePeerMeta(peer, sync.Meta, account) if updated { - err = am.Store.SaveAccount(ctx, account) + err = am.Store.SavePeer(ctx, account.Id, peer) if err != nil { return nil, nil, nil, err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 37cc10d8b..7648538c3 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -31,8 +31,10 @@ import ( ) const ( - storeSqliteFileName = "store.db" - idQueryCondition = "id = ?" + storeSqliteFileName = "store.db" + idQueryCondition = "id = ?" + accountAndIDQueryCondition = "account_id = ? and id = ?" + peerNotFoundFMT = "peer %s not found" ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -271,6 +273,38 @@ func (s *SqlStore) GetInstallationID() string { return installation.InstallationIDValue } +func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { + // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. 
+ peerCopy := peer.Copy() + peerCopy.AccountID = accountID + + err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // check if peer exists before saving + var peerID string + result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID) + if result.Error != nil { + return result.Error + } + + if peerID == "" { + return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID) + } + + result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy) + if result.Error != nil { + return result.Error + } + + return nil + }) + + if err != nil { + return err + } + + return nil +} + func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -281,14 +315,14 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe } result := s.db.Model(&nbpeer.Peer{}). Select(fieldsToUpdate). - Where("account_id = ? AND id = ?", accountID, peerID). + Where(accountAndIDQueryCondition, accountID, peerID). Updates(&peerCopy) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "peer %s not found", peerID) + return status.Errorf(status.NotFound, peerNotFoundFMT, peerID) } return nil @@ -302,7 +336,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P peerCopy.Location = peerWithLocation.Location result := s.db.Model(&nbpeer.Peer{}). - Where("account_id = ? and id = ?", accountID, peerWithLocation.ID). + Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID). 
Updates(peerCopy) if result.Error != nil { @@ -310,7 +344,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID) + return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID) } return nil @@ -644,7 +678,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*S func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { var user User - result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID) + result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "user %s not found", userID) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 7f48810d7..ce4ee531a 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -362,6 +362,54 @@ func TestSqlite_GetAccount(t *testing.T) { require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } +func TestSqlite_SavePeer(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/store.json") + + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") + require.NoError(t, err) + + // save status of non-existing peer + peer := &nbpeer.Peer{ + Key: "peerkey", + ID: "testpeer", + SetupKey: "peerkeysetupkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } + ctx := context.Background() + err = store.SavePeer(ctx, account.Id, peer) + assert.Error(t, err) + parsedErr, ok := 
status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + + // save new status of existing peer + account.Peers[peer.ID] = peer + + err = store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + updatedPeer := peer.Copy() + updatedPeer.Status.Connected = false + updatedPeer.Meta.Hostname = "updatedpeer" + + err = store.SavePeer(ctx, account.Id, updatedPeer) + require.NoError(t, err) + + account, err = store.GetAccount(context.Background(), account.Id) + require.NoError(t, err) + + actual := account.Peers[peer.ID] + assert.Equal(t, updatedPeer.Status, actual.Status) + assert.Equal(t, updatedPeer.Meta, actual.Meta) +} + func TestSqlite_SavePeerStatus(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") @@ -414,6 +462,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { actual = account.Peers["testpeer"].Status assert.Equal(t, newStatus, *actual) } + func TestSqlite_SavePeerLocation(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") diff --git a/management/server/store.go b/management/server/store.go index 3ba73e8c7..15a419c78 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -12,10 +12,11 @@ import ( "strings" "time" - nbgroup "github.com/netbirdio/netbird/management/server/group" log "github.com/sirupsen/logrus" "gorm.io/gorm" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" @@ -54,6 +55,7 @@ type Store interface { AcquireAccountReadLock(ctx context.Context, accountID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock AcquireGlobalLock(ctx context.Context) func() + SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error 
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error From 1a15b0f900e6e2dfc49c937b443675f2b95de74d Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 26 Jul 2024 16:27:51 +0200 Subject: [PATCH 16/53] Fix race issue in set listener (#2332) --- client/internal/dns/service_listener.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index 89cf4daf6..e0f9da26f 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -128,6 +128,9 @@ func (s *serviceViaListener) RuntimeIP() string { } func (s *serviceViaListener) setListenerStatus(running bool) { + s.listenerFlagLock.Lock() + defer s.listenerFlagLock.Unlock() + s.listenerIsRunning = running } From ea3205643a1de32522019757b848d4689664fabc Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 26 Jul 2024 16:33:20 +0200 Subject: [PATCH 17/53] Save daemon address on service install (#2328) --- client/cmd/service_installer.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 5e147262b..99a4821b0 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -31,6 +31,8 @@ var installCmd = &cobra.Command{ configPath, "--log-level", logLevel, + "--daemon-addr", + daemonAddr, } if managementURL != "" { From 7321046cd6a9a16c42eceeea98e4b60b94414baa Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 26 Jul 2024 17:33:54 +0300 Subject: [PATCH 18/53] Remove redundant check for empty JWT groups (#2323) * Remove redundant check for empty group names in SetJWTGroups * add test --- management/server/account.go | 4 ---- management/server/account_test.go | 7 +++++++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/management/server/account.go 
b/management/server/account.go index 558de6fbb..16a2851bd 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -770,10 +770,6 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer { // SetJWTGroups updates the user's auto groups by synchronizing JWT groups. // Returns true if there are changes in the JWT group membership. func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { - if len(groupsNames) == 0 { - return false - } - user, ok := a.Users[userID] if !ok { return false diff --git a/management/server/account_test.go b/management/server/account_test.go index 71b43bd65..45b4fbd6f 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2219,6 +2219,13 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added") assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups") }) + + t.Run("remove all JWT groups", func(t *testing.T) { + updated := account.SetJWTGroups("user1", []string{}) + assert.True(t, updated, "account should be updated") + assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain") + assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present") + }) } func TestAccount_UserGroupsAddToPeers(t *testing.T) { From da39c8bbca24785925a7a0bd0ec21ca796475ec1 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 29 Jul 2024 13:30:27 +0200 Subject: [PATCH 19/53] Refactor login with store.SavePeer (#2334) This pull request refactors the login functionality by integrating store.SavePeer. The changes aim to improve the handling of peer login processes, particularly focusing on synchronization and error handling. Changes: - Refactored login logic to use store.SavePeer. 
- Added checks for login without lock for login necessary checks from the client and utilized write lock for full login flow. - Updated error handling with status.NewPeerLoginExpiredError(). - Moved geoIP check logic to a more appropriate place. - Removed redundant calls and improved documentation. - Moved the code to smaller methods to improve readability. --- management/server/peer.go | 199 ++++++++++++++---------------- management/server/status/error.go | 5 + 2 files changed, 100 insertions(+), 104 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index ec8b773b0..a740067a8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -453,6 +453,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s Location: peer.Location, } + if am.geo != nil && newPeer.Location.ConnectionIP != nil { + location, err := am.geo.Lookup(newPeer.Location.ConnectionIP) + if err != nil { + log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err) + } else { + newPeer.Location.CountryCode = location.Country.ISOCode + newPeer.Location.CityName = location.City.Names.En + newPeer.Location.GeoNameID = location.City.GeonameID + } + } + // add peer to 'All' group group, err := account.GetGroupAll() if err != nil { @@ -535,7 +546,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if peerLoginExpired(ctx, peer, account.Settings) { - return nil, nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") + return nil, nil, nil, status.NewPeerLoginExpiredError() } peer, updated := updatePeerMeta(peer, sync.Meta, account) @@ -586,21 +597,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. 
newPeer := &nbpeer.Peer{ - Key: login.WireGuardPubKey, - Meta: login.Meta, - SSHKey: login.SSHKey, - } - if am.geo != nil && login.ConnectionIP != nil { - location, err := am.geo.Lookup(login.ConnectionIP) - if err != nil { - log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err) - } else { - newPeer.Location.ConnectionIP = login.ConnectionIP - newPeer.Location.CountryCode = location.Country.ISOCode - newPeer.Location.CityName = location.City.Names.En - newPeer.Location.GeoNameID = location.City.GeonameID - - } + Key: login.WireGuardPubKey, + Meta: login.Meta, + SSHKey: login.SSHKey, + Location: nbpeer.Location{ConnectionIP: login.ConnectionIP}, } return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) @@ -610,44 +610,17 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer") } - peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) - if err != nil { - return nil, nil, nil, status.NewPeerNotRegisteredError() - } - - accSettings, err := am.Store.GetAccountSettings(ctx, accountID) - if err != nil { - return nil, nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err) - } - - var isWriteLock bool - - // duplicated logic from after the lock to have an early exit - expired := peerLoginExpired(ctx, peer, accSettings) - switch { - case expired: - if err := checkAuth(ctx, login.UserID, peer); err != nil { + // when the client sends a login request with a JWT which is used to get the user ID, + // it means that the client has already checked if it needs login and had been through the SSO flow + // so, we can skip this check and directly proceed with the login + if login.UserID == "" { + err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login) + if err != nil { return nil, nil, nil, err } - isWriteLock = true - log.WithContext(ctx).Debugf("peer 
login expired, acquiring write lock") - - case peer.UpdateMetaIfNew(login.Meta): - isWriteLock = true - log.WithContext(ctx).Debugf("peer changed meta, acquiring write lock") - - default: - isWriteLock = false - log.WithContext(ctx).Debugf("peer meta is the same, acquiring read lock") } - var unlock func() - - if isWriteLock { - unlock = am.Store.AcquireAccountWriteLock(ctx, accountID) - } else { - unlock = am.Store.AcquireAccountReadLock(ctx, accountID) - } + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer func() { if unlock != nil { unlock() @@ -660,7 +633,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) return nil, nil, nil, err } - peer, err = account.FindPeerByPubKey(login.WireGuardPubKey) + peer, err := account.FindPeerByPubKey(login.WireGuardPubKey) if err != nil { return nil, nil, nil, status.NewPeerNotRegisteredError() } @@ -671,53 +644,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } // this flag prevents unnecessary calls to the persistent store. - shouldStoreAccount := false + shouldStorePeer := false updateRemotePeers := false if peerLoginExpired(ctx, peer, account.Settings) { - err = checkAuth(ctx, login.UserID, peer) + err = am.handleExpiredPeer(ctx, login, account, peer) if err != nil { return nil, nil, nil, err } - // If peer was expired before and if it reached this point, it is re-authenticated. - // UserID is present, meaning that JWT validation passed successfully in the API layer. 
- updatePeerLastLogin(peer, account) updateRemotePeers = true - shouldStoreAccount = true - - // sync user last login with peer last login - user, err := account.FindUser(login.UserID) - if err != nil { - return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user") - } - user.updateLastLogin(peer.LastLogin) - - am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) + shouldStorePeer = true } isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { return nil, nil, nil, err } + peer, updated := updatePeerMeta(peer, login.Meta, account) if updated { - shouldStoreAccount = true + shouldStorePeer = true } - peer, err = am.checkAndUpdatePeerSSHKey(ctx, peer, account, login.SSHKey) - if err != nil { - return nil, nil, nil, err + if peer.SSHKey != login.SSHKey { + peer.SSHKey = login.SSHKey + shouldStorePeer = true } - if shouldStoreAccount { - if !isWriteLock { - log.WithContext(ctx).Errorf("account %s should be stored but is not write locked", accountID) - return nil, nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked") - } - err = am.Store.SaveAccount(ctx, account) + if shouldStorePeer { + err = am.Store.SavePeer(ctx, accountID, peer) if err != nil { return nil, nil, nil, err } } + unlock() unlock = nil @@ -725,13 +684,46 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) am.updateAccountPeers(ctx, account) } + return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) +} + +// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO +// and if the peer login is expired. 
+// The NetBird client doesn't have a way to check if the peer needs login besides sending a login request +// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired +// and before starting the engine, we do the checks without an account lock to avoid piling up requests. +func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) + if err != nil { + return err + } + + // if the peer was not added with SSO login we can exit early because peers activated with setup-key + // doesn't expire, and we avoid extra databases calls. + if !peer.AddedWithSSOLogin() { + return nil + } + + settings, err := am.Store.GetAccountSettings(ctx, accountID) + if err != nil { + return err + } + + if peerLoginExpired(ctx, peer, settings) { + return status.NewPeerLoginExpiredError() + } + + return nil +} + +func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { var postureChecks []*posture.Checks if isRequiresApproval { emptyMap := &NetworkMap{ Network: account.Network.Copy(), } - return peer, emptyMap, postureChecks, nil + return peer, emptyMap, nil, nil } approvedPeersMap, err := am.GetValidatedPeers(account) @@ -743,6 +735,30 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil } +func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error { + err := checkAuth(ctx, login.UserID, peer) + if err != nil { + return err + } + // If peer was expired before and if it reached this point, it is re-authenticated. 
+ // UserID is present, meaning that JWT validation passed successfully in the API layer. + updatePeerLastLogin(peer, account) + + // sync user last login with peer last login + user, err := account.FindUser(login.UserID) + if err != nil { + return status.Errorf(status.Internal, "couldn't find user") + } + + err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin) + if err != nil { + return err + } + + am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) + return nil +} + func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { if peer.AddedWithSSOLogin() { user, err := account.FindUser(peer.UserID) @@ -759,11 +775,11 @@ func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error { if loginUserID == "" { // absence of a user ID indicates that JWT wasn't provided. - return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") + return status.NewPeerLoginExpiredError() } if peer.UserID != loginUserID { log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID) - return status.Errorf(status.Unauthenticated, "can't login") + return status.Errorf(status.Unauthenticated, "can't login with this credentials") } return nil } @@ -783,31 +799,6 @@ func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) { account.UpdatePeer(peer) } -func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(ctx context.Context, peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) { - if len(newSSHKey) == 0 { - log.WithContext(ctx).Debugf("no new SSH key provided for peer %s, skipping update", peer.ID) - return peer, nil - } - - if peer.SSHKey == newSSHKey { - log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peer.ID) - return peer, nil - 
} - - peer.SSHKey = newSSHKey - account.UpdatePeer(peer) - - err := am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - - // trigger network map update - am.updateAccountPeers(ctx, account) - - return peer, nil -} - // UpdatePeerSSHKey updates peer's public SSH key func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { if sshKey == "" { diff --git a/management/server/status/error.go b/management/server/status/error.go index 39cd6c613..58b9a84a0 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -95,3 +95,8 @@ func NewUserNotFoundError(userKey string) error { func NewPeerNotRegisteredError() error { return Errorf(Unauthenticated, "peer is not registered") } + +// NewPeerLoginExpiredError creates a new Error with PermissionDenied type for an expired peer +func NewPeerLoginExpiredError() error { + return Errorf(PermissionDenied, "peer login has expired, please log in once more") +} From 9d2047a08af6b68bbd65382ab5f276567f6ac994 Mon Sep 17 00:00:00 2001 From: Evgenii Date: Wed, 31 Jul 2024 08:58:04 +0100 Subject: [PATCH 20/53] Fix freebsd tests (#2346) --- .github/workflows/golang-test-freebsd.yml | 24 ++++++++++------------- client/ui/font_bsd.go | 2 +- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml index 15fc6a729..6cb5f7427 100644 --- a/.github/workflows/golang-test-freebsd.yml +++ b/.github/workflows/golang-test-freebsd.yml @@ -13,7 +13,7 @@ concurrency: jobs: test: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 - name: Test in FreeBSD @@ -21,19 +21,15 @@ jobs: uses: vmactions/freebsd-vm@v1 with: usesh: true + copyback: false + release: "14.1" prepare: | - pkg install -y curl - pkg install -y git + pkg install -y go + # -x - to print all executed commands + # -e - to faile on first error run: | - set -x - curl -o 
go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L - tar zxf go.tar.gz - mv go /usr/local/go - ln -s /usr/local/go/bin/go /usr/local/bin/go - go mod tidy - go test -timeout 5m -p 1 ./iface/... - go test -timeout 5m -p 1 ./client/... - cd client - go build . - cd .. \ No newline at end of file + set -e -x + go build -o netbird client/main.go + go test -timeout 5m -p 1 -failfast ./iface/... + go test -timeout 5m -p 1 -failfast ./client/... diff --git a/client/ui/font_bsd.go b/client/ui/font_bsd.go index 41bccceca..84cb5993d 100644 --- a/client/ui/font_bsd.go +++ b/client/ui/font_bsd.go @@ -1,4 +1,4 @@ -//go:build darwin || dragonfly || freebsd || netbsd || openbsd +//go:build darwin package main From 165988429c5406bd721f0813c26977fafc41862a Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 31 Jul 2024 14:53:32 +0200 Subject: [PATCH 21/53] Add write lock for peer when saving its connection status (#2359) --- management/server/account.go | 28 +++++++++++++---------- management/server/dns.go | 4 ++-- management/server/event.go | 2 +- management/server/file_store.go | 20 ++++++++-------- management/server/group.go | 16 ++++++------- management/server/integrated_validator.go | 2 +- management/server/nameserver.go | 10 ++++---- management/server/peer.go | 14 ++++++------ management/server/policy.go | 8 +++---- management/server/posture_checks.go | 8 +++---- management/server/route.go | 10 ++++---- management/server/setupkey.go | 8 +++---- management/server/sql_store.go | 20 ++++++++-------- management/server/store.go | 8 +++---- management/server/user.go | 22 +++++++++--------- 15 files changed, 93 insertions(+), 87 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 16a2851bd..b2b23dcb9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -974,7 +974,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, status.Errorf(status.InvalidArgument, "peer 
login expiration can't be smaller than one hour") } - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -1025,7 +1025,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -1124,7 +1124,7 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { @@ -1584,7 +1584,7 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string return err } - unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id) + unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) defer unlock() account, err = am.Store.GetAccountByUser(ctx, user.Id) @@ -1667,7 +1667,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims if err != nil { return nil, nil, err } - unlock := am.Store.AcquireAccountWriteLock(ctx, newAcc.Id) + unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id) alreadyUnlocked := false defer func() { if !alreadyUnlocked { @@ -1823,7 +1823,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C account, err := 
am.Store.GetAccountByUser(ctx, claims.UserId) if err == nil { - unlockAccount := am.Store.AcquireAccountWriteLock(ctx, account.Id) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id) defer unlockAccount() account, err = am.Store.GetAccountByUser(ctx, claims.UserId) if err != nil { @@ -1843,7 +1843,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C return account, nil } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if domainAccount != nil { - unlockAccount := am.Store.AcquireAccountWriteLock(ctx, domainAccount.Id) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id) defer unlockAccount() domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) if err != nil { @@ -1866,8 +1866,10 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey return nil, nil, nil, err } - unlock := am.Store.AcquireAccountReadLock(ctx, accountID) - defer unlock() + accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) + defer accountUnlock() + peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer peerUnlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { @@ -1896,8 +1898,10 @@ func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *n return err } - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) - defer unlock() + accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) + defer accountUnlock() + peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peer.Key) + defer peerUnlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { @@ -1919,7 +1923,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st return err } - unlock := am.Store.AcquireAccountReadLock(ctx, accountID) + unlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) diff --git 
a/management/server/dns.go b/management/server/dns.go index 8a889df3f..08732ad78 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -36,7 +36,7 @@ func (d DNSSettings) Copy() DNSSettings { // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -58,7 +58,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s // SaveDNSSettings validates a user role and updates the account's DNS settings func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) diff --git a/management/server/event.go b/management/server/event.go index 616cea287..93b809226 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -13,7 +13,7 @@ import ( // GetEvents returns a list of activity events of an account func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) diff --git a/management/server/file_store.go b/management/server/file_store.go index 9a1462832..6e3536bcd 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -39,8 +39,8 @@ type FileStore struct { mux sync.Mutex `json:"-"` storeFile string `json:"-"` - // sync.Mutex 
indexed by accountID - accountLocks sync.Map `json:"-"` + // sync.Mutex indexed by resource ID + resourceLocks sync.Map `json:"-"` globalAccountLock sync.Mutex `json:"-"` metrics telemetry.AppMetrics `json:"-"` @@ -281,26 +281,26 @@ func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { return unlock } -// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock -func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) { - log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID) +// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock +func (s *FileStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) { + log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID) start := time.Now() - value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) + value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{}) mtx := value.(*sync.Mutex) mtx.Lock() unlock = func() { mtx.Unlock() - log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start)) + log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start)) } return unlock } -// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock +// AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock // This method is still returns a write lock as file store can't handle read locks -func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) { - return s.AcquireAccountWriteLock(ctx, accountID) +func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) { + return s.AcquireWriteLockByUID(ctx, uniqueID) } func (s *FileStore) 
SaveAccount(ctx context.Context, account *Account) error { diff --git a/management/server/group.go b/management/server/group.go index 45c51bda2..d7711c126 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -23,7 +23,7 @@ func (e *GroupLinkError) Error() string { // GetGroup object of the peers func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -50,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI // GetAllGroups returns all groups in an account func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -77,7 +77,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID str // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -110,7 +110,7 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, // SaveGroup object of the peers func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() 
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup}) } @@ -245,7 +245,7 @@ func difference(a, b []string) []string { // DeleteGroup object of the peers func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountId) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) defer unlock() account, err := am.Store.GetAccount(ctx, accountId) @@ -359,7 +359,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use // ListGroups objects of the peers func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -377,7 +377,7 @@ func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID strin // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -413,7 +413,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 05537ada4..99e6b204c 100644 --- a/management/server/integrated_validator.go +++ 
b/management/server/integrated_validator.go @@ -32,7 +32,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con return errors.New("invalid groups") } - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() a, err := am.Store.GetAccountByUser(ctx, userID) diff --git a/management/server/nameserver.go b/management/server/nameserver.go index f8d644ded..636f7cfee 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -20,7 +20,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -48,7 +48,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account // CreateNameServerGroup creates and saves a new nameserver group func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -95,7 +95,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco // SaveNameServerGroup saves nameserver group func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error 
{ - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() if nsGroupToSave == nil { @@ -130,7 +130,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -160,7 +160,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco // ListNameServerGroups returns a list of nameserver groups from account func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) diff --git a/management/server/peer.go b/management/server/peer.go index a740067a8..998a9e53b 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -150,7 +150,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. 
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -272,7 +272,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou // DeletePeer removes peer from the account by its IP func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -356,7 +356,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") } - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer func() { if unlock != nil { unlock() @@ -380,7 +380,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } // This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice. - // Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow) + // Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow) // and the peer disconnects with a timeout and tries to register again. // We just check if this machine has been registered before and reject the second registration. // The connecting peer should be able to recover with a retry. 
@@ -620,7 +620,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } } - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer func() { if unlock != nil { unlock() @@ -811,7 +811,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID st return err } - unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id) + unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) defer unlock() // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) @@ -846,7 +846,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID st // GetPeer for a given accountID, peerID and userID error if not found. func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) diff --git a/management/server/policy.go b/management/server/policy.go index a70d7f0ed..30614ed2d 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -315,7 +315,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, // GetPolicy from the store func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -343,7 +343,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic // SavePolicy in the store func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error { - unlock := 
am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -371,7 +371,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user // DeletePolicy from the store func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -398,7 +398,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po // ListPolicies from the store func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 851d4d31f..4a7c9755d 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -15,7 +15,7 @@ const ( ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -42,7 +42,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID } func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := 
am.Store.GetAccount(ctx, accountID) @@ -89,7 +89,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI } func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -121,7 +121,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun } func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) diff --git a/management/server/route.go b/management/server/route.go index 6db00a255..064f3c105 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -17,7 +17,7 @@ import ( // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -126,7 +126,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string { // CreateRoute creates and saves a new route func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + 
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -214,7 +214,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri // SaveRoute saves route func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() if routeToSave == nil { @@ -283,7 +283,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI // DeleteRoute deletes route with routeID func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -311,7 +311,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri // ListRoutes returns a list of routes from account func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index dcaee357c..8ef91755c 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -210,7 +210,7 @@ func Hash(s string) uint32 { // and adds it to the specified account. A list of autoGroups IDs can be empty. 
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() keyDuration := DefaultSetupKeyDuration @@ -256,7 +256,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s // (e.g. the key itself, creation date, ID, etc). // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() if keyToSave == nil { @@ -328,7 +328,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str // ListSetupKeys returns a list of all setup keys of the account func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { @@ -360,7 +360,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. 
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 7648538c3..c44ab7f09 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -40,7 +40,7 @@ const ( // SqlStore represents an account storage backed by a Sql DB persisted to disk type SqlStore struct { db *gorm.DB - accountLocks sync.Map + resourceLocks sync.Map globalAccountLock sync.Mutex metrics telemetry.AppMetrics installationPK int @@ -98,33 +98,35 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { return unlock } -func (s *SqlStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) { - log.WithContext(ctx).Tracef("acquiring write lock for account %s", accountID) +// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock +func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID) start := time.Now() - value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{}) + value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) mtx := value.(*sync.RWMutex) mtx.Lock() unlock = func() { mtx.Unlock() - log.WithContext(ctx).Tracef("released write lock for account %s in %v", accountID, time.Since(start)) + log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start)) } return unlock } -func (s *SqlStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) { - log.WithContext(ctx).Tracef("acquiring read lock for account %s", accountID) +// AcquireReadLockByUID 
acquires an ID lock for reading a resource and returns a function that releases the lock +func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID) start := time.Now() - value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{}) + value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) mtx := value.(*sync.RWMutex) mtx.RLock() unlock = func() { mtx.RUnlock() - log.WithContext(ctx).Tracef("released read lock for account %s in %v", accountID, time.Since(start)) + log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start)) } return unlock diff --git a/management/server/store.go b/management/server/store.go index 15a419c78..864871c8e 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -49,10 +49,10 @@ type Store interface { DeleteTokenID2UserIDIndex(tokenID string) error GetInstallationID() string SaveInstallationID(ctx context.Context, ID string) error - // AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock - AcquireAccountWriteLock(ctx context.Context, accountID string) func() - // AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock - AcquireAccountReadLock(ctx context.Context, accountID string) func() + // AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock + AcquireWriteLockByUID(ctx context.Context, uniqueID string) func() + // AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock + AcquireReadLockByUID(ctx context.Context, uniqueID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock AcquireGlobalLock(ctx context.Context) func() 
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error diff --git a/management/server/user.go b/management/server/user.go index 65b5c7878..2cb6b45aa 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -211,7 +211,7 @@ func NewOwnerUser(id string) *User { // createServiceUser creates a new service user under the given account. func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -267,7 +267,7 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user // inviteNewUser Invites a USer to a given account and creates reference in datastore func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() if am.idpManager == nil { @@ -368,7 +368,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A return nil, fmt.Errorf("failed to get account with token claims %v", err) } - unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id) + unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) defer unlock() account, err = am.Store.GetAccount(ctx, account.Id) @@ -401,7 +401,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A // ListUsers returns lists of all users under the account. // It doesn't populate user information such as email or name. 
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -428,7 +428,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init if initiatorUserID == targetUserID { return status.Errorf(status.InvalidArgument, "self deletion is not allowed") } - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -538,7 +538,7 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() if am.idpManager == nil { @@ -578,7 +578,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin // CreatePAT creates a new PAT for the given user func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() if tokenName == "" { @@ -628,7 +628,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string // DeletePAT deletes a specific PAT from a user func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID 
string) error { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -678,7 +678,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string // GetPAT returns a specific PAT from a user func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -710,7 +710,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i // GetAllPATs returns all PATs for a user func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) @@ -752,7 +752,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } - unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists) From c832cef44c2c588c2008dc637fb369d8b4ed6720 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 31 Jul 2024 19:48:12 +0300 Subject: [PATCH 22/53] Update SaveUsers and SaveGroups to SaveAccount (#2362) Changed SaveUsers and SaveGroups method calls to SaveAccount for consistency in data persistence operations. 
--- management/server/group.go | 2 +- management/server/user.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index d7711c126..37a6fc305 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -163,7 +163,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } account.Network.IncSerial() - if err = am.Store.SaveGroups(account.Id, account.Groups); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } diff --git a/management/server/user.go b/management/server/user.go index 2cb6b45aa..b8afcda3a 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -859,7 +859,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } account.Network.IncSerial() - if err = am.Store.SaveUsers(account.Id, account.Users); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } From 5ee9c77e907d3dfbf1f5886d0cf4be37eb09b7dd Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 31 Jul 2024 21:51:45 +0200 Subject: [PATCH 23/53] Move write peer lock (#2364) Moved the write peer lock to avoid latency caused by disk access Updated the method CancelPeerRoutines to use the peer public key --- management/server/account.go | 24 ++++++++++++------- management/server/grpcserver.go | 2 +- management/server/mock_server/account_mock.go | 2 +- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index b2b23dcb9..4648c00cd 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -136,7 +136,7 @@ type AccountManager interface { GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(account *Account) (map[string]struct{}, error) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP 
net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) - CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error + CancelPeerRoutines(ctx context.Context, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) @@ -1858,6 +1858,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { + // acquiring peer write lock here is ok since we only modify peer information that is supplied by the + // peer itself which can't be modified by API, and it only happens after an account read lock is acquired + peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer peerUnlock() + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { @@ -1868,8 +1873,6 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() - peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) - defer peerUnlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { @@ -1889,8 +1892,13 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey return peer, netMap, postureChecks, nil } -func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error { - accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peer.Key) +func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peerPubKey string) error { + // acquiring 
peer write lock here is ok since we only modify peer information that is supplied by the + // peer itself which can't be modified by API, and it only happens after an account read lock is acquired + peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer peerUnlock() + + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { return status.Errorf(status.Unauthenticated, "peer not registered") @@ -1900,17 +1908,15 @@ func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *n accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() - peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peer.Key) - defer peerUnlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - err = am.MarkPeerConnected(ctx, peer.Key, false, nil, account) + err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account) if err != nil { - log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peer.Key, err) + log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } return nil diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 170e72dd0..4a12a5c3e 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -236,7 +236,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer * func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) { s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.turnCredentialsManager.CancelRefresh(peer.ID) - _ = s.accountManager.CancelPeerRoutines(ctx, peer) + _ = s.accountManager.CancelPeerRoutines(ctx, peer.Key) s.ephemeralManager.OnPeerDisconnected(ctx, peer) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 25bcdfcee..1adf9a2d6 100644 --- 
a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -112,7 +112,7 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey st return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peer *nbpeer.Peer) error { +func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peerPubKey string) error { // TODO implement me panic("implement me") } From 02f3105e480dddea1e3ac4085737ad7fc41459b0 Mon Sep 17 00:00:00 2001 From: Evgenii Date: Thu, 1 Aug 2024 10:56:18 +0100 Subject: [PATCH 24/53] Freebsd test all root component (#2361) * chore(tests): add all root component into FreeBSD check * change timeout for each component * add client tests execution measure * revert -p1 for client tests and explain why * measure duration of all test run --- .github/workflows/golang-test-freebsd.yml | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml index 6cb5f7427..4f13ee30e 100644 --- a/.github/workflows/golang-test-freebsd.yml +++ b/.github/workflows/golang-test-freebsd.yml @@ -30,6 +30,17 @@ jobs: # -e - to faile on first error run: | set -e -x - go build -o netbird client/main.go - go test -timeout 5m -p 1 -failfast ./iface/... - go test -timeout 5m -p 1 -failfast ./client/... + time go build -o netbird client/main.go + # check all component except management, since we do not support management server on freebsd + time go test -timeout 1m -failfast ./base62/... + # NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use` + time go test -timeout 8m -failfast -p 1 ./client/... + time go test -timeout 1m -failfast ./dns/... + time go test -timeout 1m -failfast ./encryption/... 
+ time go test -timeout 1m -failfast ./formatter/... + time go test -timeout 1m -failfast ./iface/... + time go test -timeout 1m -failfast ./route/... + time go test -timeout 1m -failfast ./sharedsock/... + time go test -timeout 1m -failfast ./signal/... + time go test -timeout 1m -failfast ./util/... + time go test -timeout 1m -failfast ./version/... From cbf9f2058e1410d9a1e32ca5a6e63c914b8cb590 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 1 Aug 2024 16:21:43 +0200 Subject: [PATCH 25/53] Use accountID retrieved from the sync call to acquire read lock sooner (#2369) Use accountID retrieved from the sync call to acquire read lock sooner and avoiding extra DB calls. - Use the account ID across sync calls - Moved account read lock - Renamed CancelPeerRoutines to OnPeerDisconnected - Added race tests --- management/server/account.go | 38 +--- management/server/grpcserver.go | 22 +-- management/server/management_proto_test.go | 183 +++++++++++++++++- management/server/mock_server/account_mock.go | 8 +- 4 files changed, 198 insertions(+), 53 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 4648c00cd..5d3ee6dc1 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -135,8 +135,8 @@ type AccountManager interface { UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(account *Account) (map[string]struct{}, error) - SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) - CancelPeerRoutines(ctx context.Context, peerPubKey string) error + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) + OnPeerDisconnected(ctx 
context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) @@ -1857,22 +1857,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C } } -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { - // acquiring peer write lock here is ok since we only modify peer information that is supplied by the - // peer itself which can't be modified by API, and it only happens after an account read lock is acquired - peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) - defer peerUnlock() - - accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) - if err != nil { - if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { - return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered") - } - return nil, nil, nil, err - } - +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() + peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer peerUnlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { @@ -1892,22 +1881,11 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey return peer, netMap, postureChecks, nil } -func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peerPubKey string) error { - // acquiring peer write lock here is ok since we only modify peer information 
that is supplied by the - // peer itself which can't be modified by API, and it only happens after an account read lock is acquired - peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) - defer peerUnlock() - - accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) - if err != nil { - if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { - return status.Errorf(status.Unauthenticated, "peer not registered") - } - return err - } - +func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() + peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer peerUnlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 4a12a5c3e..f71a45d99 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -156,7 +156,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. 
Peer %s, remote addr %s", peerKey.String(), realIP) } - peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) + peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) if err != nil { return mapError(ctx, err) } @@ -179,11 +179,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) } - return s.handleUpdates(ctx, peerKey, peer, updates, srv) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } // handleUpdates sends updates to the connected peer until the updates channel is closed. -func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { for { select { // condition when there are some updates @@ -194,12 +194,12 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee if !open { log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String()) - s.cancelPeerRoutines(ctx, peer) + s.cancelPeerRoutines(ctx, accountID, peer) return nil } log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) - if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil { + if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { return err } @@ -207,7 +207,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee case <-srv.Context().Done(): // happens when connection drops, e.g. 
client disconnects log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String()) - s.cancelPeerRoutines(ctx, peer) + s.cancelPeerRoutines(ctx, accountID, peer) return srv.Context().Err() } } @@ -215,10 +215,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. -func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) if err != nil { - s.cancelPeerRoutines(ctx, peer) + s.cancelPeerRoutines(ctx, accountID, peer) return status.Errorf(codes.Internal, "failed processing update message") } err = srv.SendMsg(&proto.EncryptedMessage{ @@ -226,17 +226,17 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer * Body: encryptedResp, }) if err != nil { - s.cancelPeerRoutines(ctx, peer) + s.cancelPeerRoutines(ctx, accountID, peer) return status.Errorf(codes.Internal, "failed sending update message") } log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) return nil } -func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) { +func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.turnCredentialsManager.CancelRefresh(peer.ID) - _ = s.accountManager.CancelPeerRoutines(ctx, peer.Key) + _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) s.ephemeralManager.OnPeerDisconnected(ctx, peer) } diff --git 
a/management/server/management_proto_test.go b/management/server/management_proto_test.go index e1f7787f2..2c9d43948 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "fmt" "net" "os" "path/filepath" @@ -16,6 +17,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/formatter" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/util" @@ -83,7 +85,7 @@ func Test_SyncProtocol(t *testing.T) { defer func() { os.Remove(filepath.Join(dir, "store.json")) //nolint }() - mgmtServer, mgmtAddr, err := startManagement(t, &Config{ + mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -399,32 +401,35 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { } } -func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) { +func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) { t.Helper() lis, err := net.Listen("tcp", "localhost:0") if err != nil { - return nil, "", err + return nil, nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { - return nil, "", err + return nil, nil, "", err } t.Cleanup(cleanUp) peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", + + ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck + + accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, 
"", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) if err != nil { - return nil, "", err + return nil, nil, "", err } turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) ephemeralMgr := NewEphemeralManager(store, accountManager) mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr) if err != nil { - return nil, "", err + return nil, nil, "", err } mgmtProto.RegisterManagementServiceServer(s, mgmtServer) @@ -434,7 +439,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) } }() - return s, lis.Addr().String(), nil + return s, accountManager, lis.Addr().String(), nil } func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) { @@ -454,3 +459,165 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie return mgmtProto.NewManagementServiceClient(conn), conn, nil } +func Test_SyncStatusRace(t *testing.T) { + if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { + t.Skip("Skipping on CI and Postgres store") + } + for i := 0; i < 500; i++ { + t.Run(fmt.Sprintf("TestRun-%d", i), func(t *testing.T) { + testSyncStatusRace(t) + }) + } +} +func testSyncStatusRace(t *testing.T) { + t.Helper() + dir := t.TempDir() + err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) + if err != nil { + t.Fatal(err) + } + defer func() { + os.Remove(filepath.Join(dir, "store.json")) //nolint + }() + + mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{ + Stuns: []*Host{{ + Proto: "udp", + URI: "stun:stun.wiretrustee.com:3468", + }}, + TURNConfig: &TURNConfig{ + TimeBasedCredentials: false, + CredentialsTTL: util.Duration{}, + Secret: "whatever", + Turns: []*Host{{ + Proto: "udp", + URI: "turn:stun.wiretrustee.com:3468", + }}, + }, + Signal: &Host{ + Proto: "http", + URI: 
"signal.wiretrustee.com:10000", + }, + Datadir: dir, + HttpConfig: nil, + }) + if err != nil { + t.Fatal(err) + return + } + defer mgmtServer.GracefulStop() + + client, clientConn, err := createRawClient(mgmtAddr) + if err != nil { + t.Fatal(err) + return + } + + defer clientConn.Close() + + // there are two peers already in the store, add two more + peers, err := registerPeers(2, client) + if err != nil { + t.Fatal(err) + return + } + + serverKey, err := getServerKey(client) + if err != nil { + t.Fatal(err) + return + } + + concurrentPeerKey2 := peers[1] + t.Log("Public key of concurrent peer: ", concurrentPeerKey2.PublicKey().String()) + + syncReq2 := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + message2, err := encryption.EncryptMessage(*serverKey, *concurrentPeerKey2, syncReq2) + if err != nil { + t.Fatal(err) + return + } + + ctx2, cancelFunc2 := context.WithCancel(context.Background()) + + //client. + sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{ + WgPubKey: concurrentPeerKey2.PublicKey().String(), + Body: message2, + }) + if err != nil { + t.Fatal(err) + return + } + + resp2 := &mgmtProto.EncryptedMessage{} + err = sync2.RecvMsg(resp2) + if err != nil { + t.Fatal(err) + return + } + + peerWithInvalidStatus := peers[0] + t.Log("Public key of peer with invalid status: ", peerWithInvalidStatus.PublicKey().String()) + + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + message, err := encryption.EncryptMessage(*serverKey, *peerWithInvalidStatus, syncReq) + if err != nil { + t.Fatal(err) + return + } + + ctx, cancelFunc := context.WithCancel(context.Background()) + + //client. + sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{ + WgPubKey: peerWithInvalidStatus.PublicKey().String(), + Body: message, + }) + if err != nil { + t.Fatal(err) + return + } + + // take the first registered peer as a base for the test. Total four. 
+ + resp := &mgmtProto.EncryptedMessage{} + err = sync.RecvMsg(resp) + if err != nil { + t.Fatal(err) + return + } + + cancelFunc2() + time.Sleep(1 * time.Millisecond) + cancelFunc() + time.Sleep(10 * time.Millisecond) + + ctx, cancelFunc = context.WithCancel(context.Background()) + defer cancelFunc() + sync, err = client.Sync(ctx, &mgmtProto.EncryptedMessage{ + WgPubKey: peerWithInvalidStatus.PublicKey().String(), + Body: message, + }) + if err != nil { + t.Fatal(err) + return + } + + resp = &mgmtProto.EncryptedMessage{} + err = sync.RecvMsg(resp) + if err != nil { + t.Fatal(err) + return + } + + time.Sleep(10 * time.Millisecond) + peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String()) + if err != nil { + t.Fatal(err) + return + } + if !peer.Status.Connected { + t.Fatal("Peer should be connected") + } +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 1adf9a2d6..a66bdee2b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -31,7 +31,7 @@ type MockAccountManager struct { ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error) GetPeerNetworkFunc func(ctx 
context.Context, peerKey string) (*server.Network, error) @@ -105,14 +105,14 @@ type MockAccountManager struct { GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.SyncAndMarkPeerFunc != nil { - return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP) + return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peerPubKey string) error { +func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { // TODO implement me panic("implement me") } From 0c8f8a62c75eca5f8d2e9496179e610f54f34a28 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 1 Aug 2024 16:46:55 +0200 Subject: [PATCH 26/53] Handling invalid UTF-8 character in sys info (#2360) In some operation systems, the sys info contains invalid characters. In this patch try to keep the original fallback logic but filter out the cases when the character is invalid. 
--- client/system/info_darwin_test.go | 5 +- client/system/info_linux.go | 63 +++++++-- client/system/sysinfo_linux.go | 12 ++ client/system/sysinfo_linux_test.go | 198 ++++++++++++++++++++++++++++ encryption/message.go | 2 +- 5 files changed, 264 insertions(+), 16 deletions(-) create mode 100644 client/system/sysinfo_linux.go create mode 100644 client/system/sysinfo_linux_test.go diff --git a/client/system/info_darwin_test.go b/client/system/info_darwin_test.go index 94e0b9e5e..5608bc776 100644 --- a/client/system/info_darwin_test.go +++ b/client/system/info_darwin_test.go @@ -1,11 +1,12 @@ package system import ( - log "github.com/sirupsen/logrus" "testing" + + log "github.com/sirupsen/logrus" ) -func Test_sysInfo(t *testing.T) { +func Test_sysInfoMac(t *testing.T) { t.Skip("skipping darwin test") serialNum, prodName, manufacturer := sysInfo() if serialNum == "" { diff --git a/client/system/info_linux.go b/client/system/info_linux.go index db58d913f..b6a142bce 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -21,6 +21,26 @@ import ( "github.com/netbirdio/netbird/version" ) +type SysInfoGetter interface { + GetSysInfo() SysInfo +} + +type SysInfoWrapper struct { + si sysinfo.SysInfo +} + +func (s SysInfoWrapper) GetSysInfo() SysInfo { + s.si.GetSysInfo() + return SysInfo{ + ChassisSerial: s.si.Chassis.Serial, + ProductSerial: s.si.Product.Serial, + BoardSerial: s.si.Board.Serial, + ProductName: s.si.Product.Name, + BoardName: s.si.Board.Name, + ProductVendor: s.si.Product.Vendor, + } +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { info := _getInfo() @@ -45,7 +65,8 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network addresses: %s", err) } - serialNum, prodName, manufacturer := sysInfo() + si := SysInfoWrapper{} + serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo()) env := Environment{ Cloud: detect_cloud.Detect(ctx), @@ -87,20 +108,36 @@ func 
_getInfo() string { return out.String() } -func sysInfo() (serialNumber string, productName string, manufacturer string) { - var si sysinfo.SysInfo - si.GetSysInfo() +func sysInfo(si SysInfo) (string, string, string) { isascii := regexp.MustCompile("^[[:ascii:]]+$") - serial := si.Chassis.Serial - if (serial == "Default string" || serial == "") && si.Product.Serial != "" { - serial = si.Product.Serial + + serials := []string{si.ChassisSerial, si.ProductSerial} + serial := "" + + for _, s := range serials { + if isascii.MatchString(s) { + serial = s + if s != "Default string" { + break + } + } } - if (!isascii.MatchString(serial)) && si.Board.Serial != "" { - serial = si.Board.Serial + + if serial == "" && isascii.MatchString(si.BoardSerial) { + serial = si.BoardSerial } - name := si.Product.Name - if (!isascii.MatchString(name)) && si.Board.Name != "" { - name = si.Board.Name + + var name string + for _, n := range []string{si.ProductName, si.BoardName} { + if isascii.MatchString(n) { + name = n + break + } } - return serial, name, si.Product.Vendor + + var manufacturer string + if isascii.MatchString(si.ProductVendor) { + manufacturer = si.ProductVendor + } + return serial, name, manufacturer } diff --git a/client/system/sysinfo_linux.go b/client/system/sysinfo_linux.go new file mode 100644 index 000000000..df0f5574c --- /dev/null +++ b/client/system/sysinfo_linux.go @@ -0,0 +1,12 @@ +package system + +// SysInfo used to moc out the sysinfo getter +type SysInfo struct { + ChassisSerial string + ProductSerial string + BoardSerial string + + ProductName string + BoardName string + ProductVendor string +} diff --git a/client/system/sysinfo_linux_test.go b/client/system/sysinfo_linux_test.go new file mode 100644 index 000000000..f6a0b7058 --- /dev/null +++ b/client/system/sysinfo_linux_test.go @@ -0,0 +1,198 @@ +package system + +import "testing" + +func Test_sysInfo(t *testing.T) { + tests := []struct { + name string + sysInfo SysInfo + wantSerialNum string + 
wantProdName string + wantManufacturer string + }{ + { + name: "Test Case 1", + sysInfo: SysInfo{ + ChassisSerial: "Default string", + ProductSerial: "Default string", + BoardSerial: "M80-G8013200245", + ProductName: "B650M-HDV/M.2", + BoardName: "B650M-HDV/M.2", + ProductVendor: "ASRock", + }, + wantSerialNum: "Default string", + wantProdName: "B650M-HDV/M.2", + wantManufacturer: "ASRock", + }, + { + name: "Empty Chassis Serial", + sysInfo: SysInfo{ + ChassisSerial: "", + ProductSerial: "Default string", + BoardSerial: "M80-G8013200245", + ProductName: "B650M-HDV/M.2", + BoardName: "B650M-HDV/M.2", + ProductVendor: "ASRock", + }, + wantSerialNum: "Default string", + wantProdName: "B650M-HDV/M.2", + wantManufacturer: "ASRock", + }, + { + name: "Empty Chassis Serial", + sysInfo: SysInfo{ + ChassisSerial: "", + ProductSerial: "Default string", + BoardSerial: "M80-G8013200245", + ProductName: "B650M-HDV/M.2", + BoardName: "B650M-HDV/M.2", + ProductVendor: "ASRock", + }, + wantSerialNum: "Default string", + wantProdName: "B650M-HDV/M.2", + wantManufacturer: "ASRock", + }, + { + name: "Fallback to Product Serial", + sysInfo: SysInfo{ + ChassisSerial: "Default string", + ProductSerial: "Product serial", + BoardSerial: "M80-G8013200245", + ProductName: "B650M-HDV/M.2", + BoardName: "B650M-HDV/M.2", + ProductVendor: "ASRock", + }, + wantSerialNum: "Product serial", + wantProdName: "B650M-HDV/M.2", + wantManufacturer: "ASRock", + }, + { + name: "Fallback to Product Serial with default string", + sysInfo: SysInfo{ + ChassisSerial: "Default string", + ProductSerial: "Default string", + BoardSerial: "M80-G8013200245", + ProductName: "B650M-HDV/M.2", + BoardName: "B650M-HDV/M.2", + ProductVendor: "ASRock", + }, + wantSerialNum: "Default string", + wantProdName: "B650M-HDV/M.2", + wantManufacturer: "ASRock", + }, + { + name: "Non UTF-8 in Chassis Serial", + sysInfo: SysInfo{ + ChassisSerial: "\x80", + ProductSerial: "Product serial", + BoardSerial: "M80-G8013200245", + 
ProductName: "B650M-HDV/M.2", + BoardName: "B650M-HDV/M.2", + ProductVendor: "ASRock", + }, + wantSerialNum: "Product serial", + wantProdName: "B650M-HDV/M.2", + wantManufacturer: "ASRock", + }, + { + name: "Non UTF-8 in Chassis Serial and Product Serial", + sysInfo: SysInfo{ + ChassisSerial: "\x80", + ProductSerial: "\x80", + BoardSerial: "M80-G8013200245", + ProductName: "B650M-HDV/M.2", + BoardName: "B650M-HDV/M.2", + ProductVendor: "ASRock", + }, + wantSerialNum: "M80-G8013200245", + wantProdName: "B650M-HDV/M.2", + wantManufacturer: "ASRock", + }, + { + name: "Non UTF-8 in Chassis Serial and Product Serial and BoardSerial", + sysInfo: SysInfo{ + ChassisSerial: "\x80", + ProductSerial: "\x80", + BoardSerial: "\x80", + ProductName: "B650M-HDV/M.2", + BoardName: "B650M-HDV/M.2", + ProductVendor: "ASRock", + }, + wantSerialNum: "", + wantProdName: "B650M-HDV/M.2", + wantManufacturer: "ASRock", + }, + + { + name: "Empty Product Name", + sysInfo: SysInfo{ + ChassisSerial: "Default string", + ProductSerial: "Default string", + BoardSerial: "M80-G8013200245", + ProductName: "", + BoardName: "boardname", + ProductVendor: "ASRock", + }, + wantSerialNum: "Default string", + wantProdName: "boardname", + wantManufacturer: "ASRock", + }, + { + name: "Invalid Product Name", + sysInfo: SysInfo{ + ChassisSerial: "Default string", + ProductSerial: "Default string", + BoardSerial: "M80-G8013200245", + ProductName: "\x80", + BoardName: "boardname", + ProductVendor: "ASRock", + }, + wantSerialNum: "Default string", + wantProdName: "boardname", + wantManufacturer: "ASRock", + }, + { + name: "Invalid BoardName Name", + sysInfo: SysInfo{ + ChassisSerial: "Default string", + ProductSerial: "Default string", + BoardSerial: "M80-G8013200245", + ProductName: "\x80", + BoardName: "\x80", + ProductVendor: "ASRock", + }, + wantSerialNum: "Default string", + wantProdName: "", + wantManufacturer: "ASRock", + }, + { + name: "Invalid chars", + sysInfo: SysInfo{ + ChassisSerial: "\x80", + 
ProductSerial: "\x80", + BoardSerial: "\x80", + ProductName: "\x80", + BoardName: "\x80", + ProductVendor: "\x80", + }, + wantSerialNum: "", + wantProdName: "", + wantManufacturer: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo) + if gotSerialNum != tt.wantSerialNum { + t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum) + } + if gotProdName != tt.wantProdName { + t.Errorf("sysInfo() gotProdName = %v, want %v", gotProdName, tt.wantProdName) + } + if gotManufacturer != tt.wantManufacturer { + t.Errorf("sysInfo() gotManufacturer = %v, want %v", gotManufacturer, tt.wantManufacturer) + } + }) + } +} diff --git a/encryption/message.go b/encryption/message.go index a646fa679..6e4cd7391 100644 --- a/encryption/message.go +++ b/encryption/message.go @@ -10,7 +10,7 @@ import ( func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) { byteResp, err := pb.Marshal(message) if err != nil { - log.Errorf("failed marshalling message %v", err) + log.Errorf("failed marshalling message %v, %+v", err, message.String()) return nil, err } From 3506ac4234f00ddced815f0d237bac08516ce770 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 1 Aug 2024 17:13:58 +0200 Subject: [PATCH 27/53] When creating new setup key, "revoked" field doesn't do anything (#2357) Remove unused field from API --- management/server/http/api/openapi.yml | 39 ++++++++++++++++++++++++- management/server/http/api/types.gen.go | 23 ++++++++++++++- 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 30cb19c0c..45887dc2e 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -526,6 +526,43 @@ components: - revoked - auto_groups - usage_limit + CreateSetupKeyRequest: + type: object + properties: + 
name: + description: Setup Key name + type: string + example: Default key + type: + description: Setup key type, one-off for single time usage and reusable + type: string + example: reusable + expires_in: + description: Expiration time in seconds + type: integer + minimum: 86400 + maximum: 31536000 + example: 86400 + auto_groups: + description: List of group IDs to auto-assign to peers registered with this key + type: array + items: + type: string + example: "ch8i4ug6lnn4g9hqv7m0" + usage_limit: + description: A number of times this key can be used. The value of 0 indicates the unlimited usage. + type: integer + example: 0 + ephemeral: + description: Indicate that the peer will be ephemeral or not + type: boolean + example: true + required: + - name + - type + - expires_in + - auto_groups + - usage_limit PersonalAccessToken: type: object properties: @@ -1806,7 +1843,7 @@ paths: content: 'application/json': schema: - $ref: '#/components/schemas/SetupKeyRequest' + $ref: '#/components/schemas/CreateSetupKeyRequest' responses: '200': description: A Setup Keys Object diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index f731356ee..77a6c643d 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -254,6 +254,27 @@ type Country struct { // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country type CountryCode = string +// CreateSetupKeyRequest defines model for CreateSetupKeyRequest. 
+type CreateSetupKeyRequest struct { + // AutoGroups List of group IDs to auto-assign to peers registered with this key + AutoGroups []string `json:"auto_groups"` + + // Ephemeral Indicate that the peer will be ephemeral or not + Ephemeral *bool `json:"ephemeral,omitempty"` + + // ExpiresIn Expiration time in seconds + ExpiresIn int `json:"expires_in"` + + // Name Setup Key name + Name string `json:"name"` + + // Type Setup key type, one-off for single time usage and reusable + Type string `json:"type"` + + // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage. + UsageLimit int `json:"usage_limit"` +} + // DNSSettings defines model for DNSSettings. type DNSSettings struct { // DisabledManagementGroups Groups whose DNS management is disabled @@ -1241,7 +1262,7 @@ type PostApiRoutesJSONRequestBody = RouteRequest type PutApiRoutesRouteIdJSONRequestBody = RouteRequest // PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType. -type PostApiSetupKeysJSONRequestBody = SetupKeyRequest +type PostApiSetupKeysJSONRequestBody = CreateSetupKeyRequest // PutApiSetupKeysKeyIdJSONRequestBody defines body for PutApiSetupKeysKeyId for application/json ContentType. type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest From df8b8db068c15c23a56cd650b5d9226d93d93149 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 1 Aug 2024 17:20:15 +0200 Subject: [PATCH 28/53] Bump github.com/docker/docker (#2356) Bumps [github.com/docker/docker](https://github.com/docker/docker) from 26.1.3+incompatible to 26.1.4+incompatible. - [Release notes](https://github.com/docker/docker/releases) - [Commits](https://github.com/docker/docker/compare/v26.1.3...v26.1.4) --- updated-dependencies: - dependency-name: github.com/docker/docker dependency-type: indirect ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 1da44da3b..7f242c203 100644 --- a/go.mod +++ b/go.mod @@ -115,7 +115,7 @@ require ( github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect - github.com/docker/docker v26.1.3+incompatible // indirect + github.com/docker/docker v26.1.4+incompatible // indirect github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect diff --git a/go.sum b/go.sum index 842311344..1251a6fd7 100644 --- a/go.sum +++ b/go.sum @@ -81,8 +81,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= -github.com/docker/docker v26.1.3+incompatible h1:lLCzRbrVZrljpVNobJu1J2FHk8V0s4BawoZippkc+xo= -github.com/docker/docker v26.1.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v26.1.4+incompatible h1:vuTpXDuoga+Z38m1OZHzl7NKisKWaWlhjQk7IDPSLsU= +github.com/docker/docker v26.1.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= From 24e031ab74112e7dc8fb6733ed1fc38517545794 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Jakub=20Ko=C5=82odziejczak?= <31549762+mrl5@users.noreply.github.com> Date: Thu, 1 Aug 2024 18:22:02 +0200 Subject: [PATCH 29/53] Fix syslog output containing duplicated timestamps (#2292) ```console journalctl ``` ```diff - Jul 19 14:41:01 rpi /usr/bin/netbird[614]: 2024-07-19T14:41:01+02:00 ERRO %!s(): error while handling message of Peer [key: REDACTED] error: [wrongly addressed message REDACTED] - Jul 19 21:53:03 rpi /usr/bin/netbird[614]: 2024-07-19T21:53:03+02:00 WARN %!s(): disconnected from the Signal service but will retry silently. Reason: rpc error: code = Internal desc = server closed the stream without sending trailers - Jul 19 21:53:04 rpi /usr/bin/netbird[614]: 2024-07-19T21:53:04+02:00 INFO %!s(): connected to the Signal Service stream - Jul 19 22:24:10 rpi /usr/bin/netbird[614]: 2024-07-19T22:24:10+02:00 WARN [error: read udp 192.168.1.11:48398->9.9.9.9:53: i/o timeout, upstream: 9.9.9.9:53] %!s(): got an error while connecting to upstream + Jul 19 14:41:01 rpi /usr/bin/netbird[614]: error while handling message of Peer [key: REDACTED] error: [wrongly addressed message REDACTED] + Jul 19 21:53:03 rpi /usr/bin/netbird[614]: disconnected from the Signal service but will retry silently. 
Reason: rpc error: code = Internal desc = server closed the stream without sending trailers + Jul 19 21:53:04 rpi /usr/bin/netbird[614]: connected to the Signal Service stream + Jul 19 22:24:10 rpi /usr/bin/netbird[614]: [error: read udp 192.168.1.11:48398->9.9.9.9:53: i/o timeout, upstream: 9.9.9.9:53] got an error while connecting to upstream ``` please notice that although log level is no longer present in the syslog message it is still respected by syslog logger, so the log levels are not lost: ```console journalctl -p 3 ``` ```diff - Jul 19 14:41:01 rpi /usr/bin/netbird[614]: 2024-07-19T14:41:01+02:00 ERRO %!s(): error while handling message of Peer [key: REDACTED] error: [wrongly addressed message REDACTED] + Jul 19 14:41:01 rpi /usr/bin/netbird[614]: error while handling message of Peer [key: REDACTED] error: [wrongly addressed message REDACTED] ``` --- .editorconfig | 8 ++++++++ formatter/formatter.go | 34 +++++++++++++++++++++++++++++++++- formatter/formatter_test.go | 19 ++++++++++++++++++- formatter/set.go | 6 ++++++ util/log.go | 3 +++ 5 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 .editorconfig diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..3dcb869d2 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,8 @@ +root = true + +[*] +end_of_line = lf +insert_final_newline = true + +[*.go] +indent_style = tab diff --git a/formatter/formatter.go b/formatter/formatter.go index a37c67914..74de38603 100644 --- a/formatter/formatter.go +++ b/formatter/formatter.go @@ -14,14 +14,29 @@ type TextFormatter struct { levelDesc []string } +// SyslogFormatter formats logs into text +type SyslogFormatter struct { + levelDesc []string +} + +var validLevelDesc = []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"} + + // NewTextFormatter create new MyTextFormatter instance func NewTextFormatter() *TextFormatter { return &TextFormatter{ - levelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}, + 
levelDesc: validLevelDesc, timestampFormat: time.RFC3339, // or RFC3339 } } +// NewSyslogFormatter create new MySyslogFormatter instance +func NewSyslogFormatter() *SyslogFormatter { + return &SyslogFormatter{ + levelDesc: validLevelDesc, + } +} + // Format renders a single log entry func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) { var fields string @@ -49,3 +64,20 @@ func (f *TextFormatter) parseLevel(level logrus.Level) string { return f.levelDesc[level] } + +// Format renders a single log entry +func (f *SyslogFormatter) Format(entry *logrus.Entry) ([]byte, error) { + var fields string + keys := make([]string, 0, len(entry.Data)) + for k, v := range entry.Data { + if k == "source" { + continue + } + keys = append(keys, fmt.Sprintf("%s: %v", k, v)) + } + + if len(keys) > 0 { + fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", ")) + } + return []byte(fmt.Sprintf("%s%s\n", fields, entry.Message)), nil +} diff --git a/formatter/formatter_test.go b/formatter/formatter_test.go index 54bc8a756..1ed207958 100644 --- a/formatter/formatter_test.go +++ b/formatter/formatter_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestLogMessageFormat(t *testing.T) { +func TestLogTextFormat(t *testing.T) { someEntry := &logrus.Entry{ Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"}, @@ -24,3 +24,20 @@ func TestLogMessageFormat(t *testing.T) { expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$" assert.Regexp(t, expectedString, parsedString) } + +func TestLogSyslogFormat(t *testing.T) { + + someEntry := &logrus.Entry{ + Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"}, + Time: time.Date(2021, time.Month(2), 21, 1, 10, 30, 0, time.UTC), + Level: 3, + Message: "Some Message", + } + + formatter := NewSyslogFormatter() + result, _ := formatter.Format(someEntry) + + parsedString := string(result) + 
expectedString := "^\\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] Some Message\\s+$" + assert.Regexp(t, expectedString, parsedString) +} diff --git a/formatter/set.go b/formatter/set.go index f9ccef601..9dfea5a7f 100644 --- a/formatter/set.go +++ b/formatter/set.go @@ -10,6 +10,12 @@ func SetTextFormatter(logger *logrus.Logger) { logger.ReportCaller = true logger.AddHook(NewContextHook()) } +// SetSyslogFormatter set the text formatter for given logger. +func SetSyslogFormatter(logger *logrus.Logger) { + logger.Formatter = NewSyslogFormatter() + logger.ReportCaller = true + logger.AddHook(NewContextHook()) +} // SetJSONFormatter set the JSON formatter for given logger. func SetJSONFormatter(logger *logrus.Logger) { diff --git a/util/log.go b/util/log.go index 74b99311e..4bce75e4a 100644 --- a/util/log.go +++ b/util/log.go @@ -35,8 +35,11 @@ func InitLog(logLevel string, logPath string) error { AddSyslogHook() } + //nolint:gocritic if os.Getenv("NB_LOG_FORMAT") == "json" { formatter.SetJSONFormatter(log.StandardLogger()) + } else if logPath == "syslog" { + formatter.SetSyslogFormatter(log.StandardLogger()) } else { formatter.SetTextFormatter(log.StandardLogger()) } From 57624203c94c686ee538f7e8f6d076967d6e8a6b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 1 Aug 2024 18:38:19 +0200 Subject: [PATCH 30/53] Allow route updates even if some domains failed resolution (#2368) --- client/internal/routemanager/dynamic/route.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 8429b4534..e95710798 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -189,9 +189,14 @@ func (r *Route) startResolver(ctx context.Context) { } func (r *Route) update(ctx context.Context) error { - if resolved, err := r.resolveDomains(); err != nil { - return 
fmt.Errorf("resolve domains: %w", err) - } else if err := r.updateDynamicRoutes(ctx, resolved); err != nil { + resolved, err := r.resolveDomains() + if err != nil { + if len(resolved) == 0 { + return fmt.Errorf("resolve domains: %w", err) + } + log.Warnf("Failed to resolve domains: %v", err) + } + if err := r.updateDynamicRoutes(ctx, resolved); err != nil { return fmt.Errorf("update dynamic routes: %w", err) } From 216d9f2ee8bd0c7cbcd6e0bf3c4318aa8785c10e Mon Sep 17 00:00:00 2001 From: keacwu Date: Fri, 2 Aug 2024 00:52:38 +0800 Subject: [PATCH 31/53] Adding geolocation download log message. (#2085) * Adding geolocation download prompt message. * import log file and remove unnecessary else --------- Co-authored-by: Maycon Santos --- management/server/geolocation/database.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/management/server/geolocation/database.go b/management/server/geolocation/database.go index 1bada6075..c9b2eafff 100644 --- a/management/server/geolocation/database.go +++ b/management/server/geolocation/database.go @@ -9,6 +9,7 @@ import ( "path" "strconv" + log "github.com/sirupsen/logrus" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -30,6 +31,8 @@ func loadGeolocationDatabases(dataDir string) error { continue } + log.Infof("geo location file %s not found , file will be downloaded", file) + switch file { case MMDBFileName: extractFunc := func(src string, dst string) error { From f84b606506f6082581e4bb1983dd31bb1f9ffbd9 Mon Sep 17 00:00:00 2001 From: David Fry Date: Thu, 1 Aug 2024 18:52:50 +0200 Subject: [PATCH 32/53] add extra auth audience (#2350) --- management/server/config.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/management/server/config.go b/management/server/config.go index 3e5ff1eaf..4efe4fe74 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -56,6 +56,10 @@ type Config struct { func (c Config) GetAuthAudiences() []string { audiences := 
[]string{c.HttpConfig.AuthAudience} + if c.HttpConfig.ExtraAuthAudience != "" { + audiences = append(audiences, c.HttpConfig.ExtraAuthAudience) + } + if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" { audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience) } @@ -90,6 +94,8 @@ type HttpServerConfig struct { OIDCConfigEndpoint string // IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not IdpSignKeyRefreshEnabled bool + // Extra audience + ExtraAuthAudience string } // Host represents a Wiretrustee host (e.g. STUN, TURN, Signal) From 5ad4ae769a9e9629b0ed52e2f8eeec0f03812639 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 2 Aug 2024 11:47:12 +0200 Subject: [PATCH 33/53] Extend client debug bundle (#2341) Adds readme (with --anonymize) Fixes archive file timestamps Adds routes info Adds interfaces Adds client config --- client/anonymize/anonymize.go | 15 + client/cmd/debug.go | 31 +- client/cmd/root.go | 4 + client/cmd/status.go | 17 +- .../routemanager/systemops/systemops_bsd.go | 2 +- .../systemops/systemops_generic.go | 4 +- .../routemanager/systemops/systemops_linux.go | 4 +- .../systemops/systemops_nonlinux.go | 2 +- .../systemops/systemops_windows.go | 2 +- client/proto/daemon.pb.go | 18 +- client/proto/daemon.proto | 1 + client/server/debug.go | 431 +++++++++++++++--- 12 files changed, 435 insertions(+), 96 deletions(-) diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index acbd0441e..208e74d53 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -178,6 +178,21 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string { }) } +// AnonymizeRoute anonymizes a route string by replacing IP addresses with anonymized versions and +// domain names with random strings. 
+func (a *Anonymizer) AnonymizeRoute(route string) string { + prefix, err := netip.ParsePrefix(route) + if err == nil { + ip := a.AnonymizeIPString(prefix.Addr().String()) + return fmt.Sprintf("%s/%d", ip, prefix.Bits()) + } + domains := strings.Split(route, ", ") + for i, domain := range domains { + domains[i] = a.AnonymizeDomain(domain) + } + return strings.Join(domains, ", ") +} + func isWellKnown(addr netip.Addr) bool { wellKnown := []string{ "8.8.8.8", "8.8.4.4", // Google DNS IPv4 diff --git a/client/cmd/debug.go b/client/cmd/debug.go index da5e0945a..dbdce91ab 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc/status" @@ -13,6 +14,8 @@ import ( "github.com/netbirdio/netbird/client/server" ) +const errCloseConnection = "Failed to close connection: %v" + var debugCmd = &cobra.Command{ Use: "debug", Short: "Debugging commands", @@ -63,12 +66,17 @@ func debugBundle(cmd *cobra.Command, _ []string) error { if err != nil { return err } - defer conn.Close() + defer func() { + if err := conn.Close(); err != nil { + log.Errorf(errCloseConnection, err) + } + }() client := proto.NewDaemonServiceClient(conn) resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ - Anonymize: anonymizeFlag, - Status: getStatusOutput(cmd), + Anonymize: anonymizeFlag, + Status: getStatusOutput(cmd), + SystemInfo: debugSystemInfoFlag, }) if err != nil { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) @@ -84,7 +92,11 @@ func setLogLevel(cmd *cobra.Command, args []string) error { if err != nil { return err } - defer conn.Close() + defer func() { + if err := conn.Close(); err != nil { + log.Errorf(errCloseConnection, err) + } + }() client := proto.NewDaemonServiceClient(conn) level := server.ParseLogLevel(args[0]) @@ -113,7 +125,11 @@ func runForDuration(cmd *cobra.Command, args []string) error { if err != nil { 
return err } - defer conn.Close() + defer func() { + if err := conn.Close(); err != nil { + log.Errorf(errCloseConnection, err) + } + }() client := proto.NewDaemonServiceClient(conn) @@ -189,8 +205,9 @@ func runForDuration(cmd *cobra.Command, args []string) error { cmd.Println("Creating debug bundle...") resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ - Anonymize: anonymizeFlag, - Status: statusOutput, + Anonymize: anonymizeFlag, + Status: statusOutput, + SystemInfo: debugSystemInfoFlag, }) if err != nil { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) diff --git a/client/cmd/root.go b/client/cmd/root.go index 1e5c56366..010ffb25a 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -37,6 +37,7 @@ const ( serverSSHAllowedFlag = "allow-server-ssh" extraIFaceBlackListFlag = "extra-iface-blacklist" dnsRouteIntervalFlag = "dns-router-interval" + systemInfoFlag = "system-info" ) var ( @@ -69,6 +70,7 @@ var ( autoConnectDisabled bool extraIFaceBlackList []string anonymizeFlag bool + debugSystemInfoFlag bool dnsRouteInterval time.Duration rootCmd = &cobra.Command{ @@ -165,6 +167,8 @@ func init() { upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. 
If enabled, then the client won't connect automatically when the service starts.") + + debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle") } // SetupCloseHandler handles SIGTERM signal and exits with success diff --git a/client/cmd/status.go b/client/cmd/status.go index e6c7b8be8..d9b7a9c91 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -807,7 +807,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) { } for i, route := range peer.Routes { - peer.Routes[i] = anonymizeRoute(a, route) + peer.Routes[i] = a.AnonymizeRoute(route) } } @@ -843,21 +843,8 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) } for i, route := range overview.Routes { - overview.Routes[i] = anonymizeRoute(a, route) + overview.Routes[i] = a.AnonymizeRoute(route) } overview.FQDN = a.AnonymizeDomain(overview.FQDN) } - -func anonymizeRoute(a *anonymize.Anonymizer, route string) string { - prefix, err := netip.ParsePrefix(route) - if err == nil { - ip := a.AnonymizeIPString(prefix.Addr().String()) - return fmt.Sprintf("%s/%d", ip, prefix.Bits()) - } - domains := strings.Split(route, ", ") - for i, domain := range domains { - domains[i] = a.AnonymizeDomain(domain) - } - return strings.Join(domains, ", ") -} diff --git a/client/internal/routemanager/systemops/systemops_bsd.go b/client/internal/routemanager/systemops/systemops_bsd.go index b7fb554db..5e3b20a86 100644 --- a/client/internal/routemanager/systemops/systemops_bsd.go +++ b/client/internal/routemanager/systemops/systemops_bsd.go @@ -22,7 +22,7 @@ type Route struct { Interface *net.Interface } -func getRoutesFromTable() ([]netip.Prefix, error) { +func GetRoutesFromTable() ([]netip.Prefix, error) { tab, err := retryFetchRIB() if err != nil { return nil, fmt.Errorf("fetch RIB: %v", err) diff --git a/client/internal/routemanager/systemops/systemops_generic.go 
b/client/internal/routemanager/systemops/systemops_generic.go index 615e1b528..671545b86 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -427,7 +427,7 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) { } func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() + routes, err := GetRoutesFromTable() if err != nil { return false, fmt.Errorf("get routes from table: %w", err) } @@ -440,7 +440,7 @@ func existsInRouteTable(prefix netip.Prefix) (bool, error) { } func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() + routes, err := GetRoutesFromTable() if err != nil { return false, fmt.Errorf("get routes from table: %w", err) } diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index c4f69fba5..2d0c57826 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -206,7 +206,7 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error return nil } -func getRoutesFromTable() ([]netip.Prefix, error) { +func GetRoutesFromTable() ([]netip.Prefix, error) { v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4) if err != nil { return nil, fmt.Errorf("get v4 routes: %w", err) @@ -504,7 +504,7 @@ func getAddressFamily(prefix netip.Prefix) int { func hasSeparateRouting() ([]netip.Prefix, error) { if isLegacy() { - return getRoutesFromTable() + return GetRoutesFromTable() } return nil, ErrRoutingIsSeparate } diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go index 0adeb0992..3b52fc7af 100644 --- a/client/internal/routemanager/systemops/systemops_nonlinux.go +++ 
b/client/internal/routemanager/systemops/systemops_nonlinux.go @@ -24,5 +24,5 @@ func EnableIPForwarding() error { } func hasSeparateRouting() ([]netip.Prefix, error) { - return getRoutesFromTable() + return GetRoutesFromTable() } diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 88bdce7c9..0d3630cb8 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -94,7 +94,7 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro return nil } -func getRoutesFromTable() ([]netip.Prefix, error) { +func GetRoutesFromTable() ([]netip.Prefix, error) { mux.Lock() defer mux.Unlock() diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 813540246..fb10a38d3 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1828,8 +1828,9 @@ type DebugBundleRequest struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"` - Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` + Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"` + Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` + SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"` } func (x *DebugBundleRequest) Reset() { @@ -1878,6 +1879,13 @@ func (x *DebugBundleRequest) GetStatus() string { return "" } +func (x *DebugBundleRequest) GetSystemInfo() bool { + if x != nil { + return x.SystemInfo + } + return false +} + type DebugBundleResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -2370,11 +2378,13 @@ var file_daemon_proto_rawDesc = []byte{ 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 
0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, - 0x02, 0x38, 0x01, 0x22, 0x4a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, + 0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, + 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 267eec279..43c379fb5 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -263,6 +263,7 @@ message Route { message DebugBundleRequest { bool anonymize = 1; string status = 2; + bool systemInfo = 3; } message DebugBundleResponse { diff --git a/client/server/debug.go b/client/server/debug.go index 9b6a52659..1187f3187 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -1,3 +1,5 @@ +//go:build !android && !ios + package server import ( @@ -6,16 +8,70 @@ import ( "context" "fmt" "io" + "net" + "net/netip" "os" + "sort" 
"strings" + "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/proto" ) +const readmeContent = `Netbird debug bundle +This debug bundle contains the following files: + +status.txt: Anonymized status information of the NetBird client. +client.log: Most recent, anonymized log file of the NetBird client. +routes.txt: Anonymized system routes, if --system-info flag was provided. +interfaces.txt: Anonymized network interface information, if --system-info flag was provided. +config.txt: Anonymized configuration information of the NetBird client. + + +Anonymization Process +The files in this bundle have been anonymized to protect sensitive information. Here's how the anonymization was applied: + +IP Addresses + +IPv4 addresses are replaced with addresses starting from 192.51.100.0 +IPv6 addresses are replaced with addresses starting from 100:: + +IP addresses from non public ranges and well known addresses are not anonymized (e.g. 8.8.8.8, 100.64.0.0/10, addresses starting with 192.168., 172.16., 10., etc.). +Reoccuring IP addresses are replaced with the same anonymized address. + +Note: The anonymized IP addresses in the status file do not match those in the log and routes files. However, the anonymized IP addresses are consistent within the status file and across the routes and log files. + +Domains +All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle. +Reoccuring domain names are replaced with the same anonymized domain. + +Routes +For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. 
Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct. +Network Interfaces +The interfaces.txt file contains information about network interfaces, including: +- Interface name +- Interface index +- MTU (Maximum Transmission Unit) +- Flags +- IP addresses associated with each interface + +The IP addresses in the interfaces file are anonymized using the same process as described above. Interface names, indexes, MTUs, and flags are not anonymized. + +Configuration +The config.txt file contains anonymized configuration information of the NetBird client. Sensitive information such as private keys and SSH keys are excluded. The following fields are anonymized: +- ManagementURL +- AdminURL +- NATExternalIPs +- CustomDNSAddress + +Other non-sensitive configuration options are included without anonymization. +` + // DebugBundle creates a debug bundle and returns the location. func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) { s.mutex.Lock() @@ -30,93 +86,211 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( return nil, fmt.Errorf("create zip file: %w", err) } defer func() { - if err := bundlePath.Close(); err != nil { - log.Errorf("failed to close zip file: %v", err) + if closeErr := bundlePath.Close(); closeErr != nil && err == nil { + err = fmt.Errorf("close zip file: %w", closeErr) } if err != nil { - if err2 := os.Remove(bundlePath.Name()); err2 != nil { - log.Errorf("Failed to remove zip file: %v", err2) + if removeErr := os.Remove(bundlePath.Name()); removeErr != nil { + log.Errorf("Failed to remove zip file: %v", removeErr) } } }() - archive := zip.NewWriter(bundlePath) - defer func() { - if err := archive.Close(); err != nil { - log.Errorf("failed to close archive writer: %v", err) - } - }() - - if status := req.GetStatus(); status != "" { - filename := "status.txt" - if req.GetAnonymize() { - 
filename = "status.anon.txt" - } - statusReader := strings.NewReader(status) - if err := addFileToZip(archive, statusReader, filename); err != nil { - return nil, fmt.Errorf("add status file to zip: %w", err) - } - } - - logFile, err := os.Open(s.logFile) - if err != nil { - return nil, fmt.Errorf("open log file: %w", err) - } - defer func() { - if err := logFile.Close(); err != nil { - log.Errorf("failed to close original log file: %v", err) - } - }() - - filename := "client.log.txt" - var logReader io.Reader - errChan := make(chan error, 1) - if req.GetAnonymize() { - filename = "client.anon.log.txt" - var writer io.WriteCloser - logReader, writer = io.Pipe() - - go s.anonymize(logFile, writer, errChan) - } else { - logReader = logFile - } - if err := addFileToZip(archive, logReader, filename); err != nil { - return nil, fmt.Errorf("add log file to zip: %w", err) - } - - select { - case err := <-errChan: - if err != nil { - return nil, err - } - default: + if err := s.createArchive(bundlePath, req); err != nil { + return nil, err } return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil } -func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan<- error) { - scanner := bufio.NewScanner(reader) - anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) +func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleRequest) error { + archive := zip.NewWriter(bundlePath) + if err := s.addReadme(req, archive); err != nil { + return fmt.Errorf("add readme: %w", err) + } + if err := s.addStatus(req, archive); err != nil { + return fmt.Errorf("add status: %w", err) + } + + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) status := s.statusRecorder.GetFullStatus() seedFromStatus(anonymizer, &status) + if err := s.addConfig(req, anonymizer, archive); err != nil { + return fmt.Errorf("add config: %w", err) + } + + if req.GetSystemInfo() { + if err := s.addRoutes(req, anonymizer, archive); err != nil 
{ + return fmt.Errorf("add routes: %w", err) + } + + if err := s.addInterfaces(req, anonymizer, archive); err != nil { + return fmt.Errorf("add interfaces: %w", err) + } + } + + if err := s.addLogfile(req, anonymizer, archive); err != nil { + return fmt.Errorf("add log file: %w", err) + } + + if err := archive.Close(); err != nil { + return fmt.Errorf("close archive writer: %w", err) + } + return nil +} + +func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error { + if req.GetAnonymize() { + readmeReader := strings.NewReader(readmeContent) + if err := addFileToZip(archive, readmeReader, "README.txt"); err != nil { + return fmt.Errorf("add README file to zip: %w", err) + } + } + return nil +} + +func (s *Server) addStatus(req *proto.DebugBundleRequest, archive *zip.Writer) error { + if status := req.GetStatus(); status != "" { + statusReader := strings.NewReader(status) + if err := addFileToZip(archive, statusReader, "status.txt"); err != nil { + return fmt.Errorf("add status file to zip: %w", err) + } + } + return nil +} + +func (s *Server) addConfig(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + var configContent strings.Builder + s.addCommonConfigFields(&configContent) + + if req.GetAnonymize() { + if s.config.ManagementURL != nil { + configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", anonymizer.AnonymizeURI(s.config.ManagementURL.String()))) + } + if s.config.AdminURL != nil { + configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", anonymizer.AnonymizeURI(s.config.AdminURL.String()))) + } + configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", anonymizeNATExternalIPs(s.config.NATExternalIPs, anonymizer))) + if s.config.CustomDNSAddress != "" { + configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", anonymizer.AnonymizeString(s.config.CustomDNSAddress))) + } + } else { + if s.config.ManagementURL != nil { + 
configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", s.config.ManagementURL.String())) + } + if s.config.AdminURL != nil { + configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", s.config.AdminURL.String())) + } + configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", s.config.NATExternalIPs)) + if s.config.CustomDNSAddress != "" { + configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", s.config.CustomDNSAddress)) + } + } + + // Add config content to zip file + configReader := strings.NewReader(configContent.String()) + if err := addFileToZip(archive, configReader, "config.txt"); err != nil { + return fmt.Errorf("add config file to zip: %w", err) + } + + return nil +} + +func (s *Server) addCommonConfigFields(configContent *strings.Builder) { + configContent.WriteString("NetBird Client Configuration:\n\n") + + // Add non-sensitive fields + configContent.WriteString(fmt.Sprintf("WgIface: %s\n", s.config.WgIface)) + configContent.WriteString(fmt.Sprintf("WgPort: %d\n", s.config.WgPort)) + if s.config.NetworkMonitor != nil { + configContent.WriteString(fmt.Sprintf("NetworkMonitor: %v\n", *s.config.NetworkMonitor)) + } + configContent.WriteString(fmt.Sprintf("IFaceBlackList: %v\n", s.config.IFaceBlackList)) + configContent.WriteString(fmt.Sprintf("DisableIPv6Discovery: %v\n", s.config.DisableIPv6Discovery)) + configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", s.config.RosenpassEnabled)) + configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", s.config.RosenpassPermissive)) + if s.config.ServerSSHAllowed != nil { + configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *s.config.ServerSSHAllowed)) + } + configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", s.config.DisableAutoConnect)) + configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", s.config.DNSRouteInterval)) +} + +func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) 
error { + if routes, err := systemops.GetRoutesFromTable(); err != nil { + log.Errorf("Failed to get routes: %v", err) + } else { + // TODO: get routes including nexthop + routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer) + routesReader := strings.NewReader(routesContent) + if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil { + return fmt.Errorf("add routes file to zip: %w", err) + } + } + return nil +} + +func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + interfaces, err := net.Interfaces() + if err != nil { + return fmt.Errorf("get interfaces: %w", err) + } + + interfacesContent := formatInterfaces(interfaces, req.GetAnonymize(), anonymizer) + interfacesReader := strings.NewReader(interfacesContent) + if err := addFileToZip(archive, interfacesReader, "interfaces.txt"); err != nil { + return fmt.Errorf("add interfaces file to zip: %w", err) + } + + return nil +} + +func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) { + logFile, err := os.Open(s.logFile) + if err != nil { + return fmt.Errorf("open log file: %w", err) + } defer func() { - if err := writer.Close(); err != nil { - log.Errorf("Failed to close writer: %v", err) + if err := logFile.Close(); err != nil { + log.Errorf("Failed to close original log file: %v", err) } }() + + var logReader io.Reader + if req.GetAnonymize() { + var writer *io.PipeWriter + logReader, writer = io.Pipe() + + go s.anonymize(logFile, writer, anonymizer) + } else { + logReader = logFile + } + if err := addFileToZip(archive, logReader, "client.log"); err != nil { + return fmt.Errorf("add log file to zip: %w", err) + } + + return nil +} + +func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) { + defer func() { + // always nil + _ = writer.Close() + }() + + scanner := bufio.NewScanner(reader) for 
scanner.Scan() { line := anonymizer.AnonymizeString(scanner.Text()) if _, err := writer.Write([]byte(line + "\n")); err != nil { - errChan <- fmt.Errorf("write line to writer: %w", err) + writer.CloseWithError(fmt.Errorf("anonymize write: %w", err)) return } } if err := scanner.Err(); err != nil { - errChan <- fmt.Errorf("read line from scanner: %w", err) + writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err)) return } } @@ -141,8 +315,22 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) ( func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error { header := &zip.FileHeader{ - Name: filename, - Method: zip.Deflate, + Name: filename, + Method: zip.Deflate, + Modified: time.Now(), + + CreatorVersion: 20, // Version 2.0 + ReaderVersion: 20, // Version 2.0 + Flags: 0x800, // UTF-8 filename + } + + // If the reader is a file, we can get more accurate information + if f, ok := reader.(*os.File); ok { + if stat, err := f.Stat(); err != nil { + log.Tracef("Failed to get file stat for %s: %v", filename, err) + } else { + header.Modified = stat.ModTime() + } } writer, err := archive.CreateHeader(header) @@ -165,6 +353,13 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) { for _, peer := range status.Peers { a.AnonymizeDomain(peer.FQDN) + for route := range peer.GetRoutes() { + a.AnonymizeRoute(route) + } + } + + for route := range status.LocalPeerState.Routes { + a.AnonymizeRoute(route) } for _, nsGroup := range status.NSGroupStates { @@ -179,3 +374,113 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) { } } } + +func formatRoutes(routes []netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string { + var ipv4Routes, ipv6Routes []netip.Prefix + + // Separate IPv4 and IPv6 routes + for _, route := range routes { + if route.Addr().Is4() { + ipv4Routes = append(ipv4Routes, route) + } else { + ipv6Routes = append(ipv6Routes, route) + } + } + + // Sort IPv4 and 
IPv6 routes separately + sort.Slice(ipv4Routes, func(i, j int) bool { + return ipv4Routes[i].Bits() > ipv4Routes[j].Bits() + }) + sort.Slice(ipv6Routes, func(i, j int) bool { + return ipv6Routes[i].Bits() > ipv6Routes[j].Bits() + }) + + var builder strings.Builder + + // Format IPv4 routes + builder.WriteString("IPv4 Routes:\n") + for _, route := range ipv4Routes { + formatRoute(&builder, route, anonymize, anonymizer) + } + + // Format IPv6 routes + builder.WriteString("\nIPv6 Routes:\n") + for _, route := range ipv6Routes { + formatRoute(&builder, route, anonymize, anonymizer) + } + + return builder.String() +} + +func formatRoute(builder *strings.Builder, route netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) { + if anonymize { + anonymizedIP := anonymizer.AnonymizeIP(route.Addr()) + builder.WriteString(fmt.Sprintf("%s/%d\n", anonymizedIP, route.Bits())) + } else { + builder.WriteString(fmt.Sprintf("%s\n", route)) + } +} + +func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string { + sort.Slice(interfaces, func(i, j int) bool { + return interfaces[i].Name < interfaces[j].Name + }) + + var builder strings.Builder + builder.WriteString("Network Interfaces:\n") + + for _, iface := range interfaces { + builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name)) + builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index)) + builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU)) + builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags)) + + addrs, err := iface.Addrs() + if err != nil { + builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err)) + } else { + builder.WriteString(" Addresses:\n") + for _, addr := range addrs { + prefix, err := netip.ParsePrefix(addr.String()) + if err != nil { + builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err)) + continue + } + ip := prefix.Addr() + if anonymize { + ip = anonymizer.AnonymizeIP(ip) + } + 
builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits())) + } + } + } + + return builder.String() +} + +func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string { + anonymizedIPs := make([]string, len(ips)) + for i, ip := range ips { + parts := strings.SplitN(ip, "/", 2) + + ip1, err := netip.ParseAddr(parts[0]) + if err != nil { + anonymizedIPs[i] = ip + continue + } + ip1anon := anonymizer.AnonymizeIP(ip1) + + if len(parts) == 2 { + ip2, err := netip.ParseAddr(parts[1]) + if err != nil { + anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, parts[1]) + } else { + ip2anon := anonymizer.AnonymizeIP(ip2) + anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, ip2anon) + } + } else { + anonymizedIPs[i] = ip1anon.String() + } + } + return anonymizedIPs +} From bfc33a3f6f49d1d6b43c276d243a48f33c77bc09 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 2 Aug 2024 14:54:37 +0200 Subject: [PATCH 34/53] Move Bundle to before netbird down (#2377) This allows to get interface and route information added by the agent --- client/cmd/debug.go | 55 +++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/client/cmd/debug.go b/client/cmd/debug.go index dbdce91ab..9abd2039d 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -138,17 +138,20 @@ func runForDuration(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) } - restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting) + stateWasDown := stat.Status != string(internal.StatusConnected) && stat.Status != string(internal.StatusConnecting) initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{}) if err != nil { return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message()) } - if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { - return 
fmt.Errorf("failed to down: %v", status.Convert(err).Message()) + if stateWasDown { + if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { + return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) + } + cmd.Println("Netbird up") + time.Sleep(time.Second * 10) } - cmd.Println("Netbird down") initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE if !initialLevelTrace { @@ -161,6 +164,11 @@ func runForDuration(cmd *cobra.Command, args []string) error { cmd.Println("Log level set to trace.") } + if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { + return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) + } + cmd.Println("Netbird down") + time.Sleep(1 * time.Second) if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { @@ -178,32 +186,11 @@ func runForDuration(cmd *cobra.Command, args []string) error { } cmd.Println("\nDuration completed") + cmd.Println("Creating debug bundle...") + headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd)) - if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { - return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) - } - cmd.Println("Netbird down") - - time.Sleep(1 * time.Second) - - if restoreUp { - if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { - return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) - } - cmd.Println("Netbird up") - } - - if !initialLevelTrace { - if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil { - return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message()) - } - cmd.Println("Log level restored to", initialLogLevel.GetLevel()) - } - - cmd.Println("Creating debug bundle...") - 
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ Anonymize: anonymizeFlag, Status: statusOutput, @@ -213,6 +200,20 @@ func runForDuration(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) } + if stateWasDown { + if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { + return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) + } + cmd.Println("Netbird down") + } + + if !initialLevelTrace { + if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil { + return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message()) + } + cmd.Println("Log level restored to", initialLogLevel.GetLevel()) + } + cmd.Println(resp.GetPath()) return nil From e6f7222034bb715d4983d02d2009192ca36d973e Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 2 Aug 2024 18:07:57 +0200 Subject: [PATCH 35/53] Fix Windows file version (#2380) Systems that validates the binary version didn't like the build number as we set This fixes the versioning and will use a static build number --- .github/workflows/release.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1889b58e7..89eb2ded9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -80,13 +80,13 @@ jobs: - name: Install goversioninfo run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e - name: Generate windows syso 386 - run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ 
steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_386.syso + run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_386.syso - name: Generate windows syso arm - run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm.syso + run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm.syso - name: Generate windows syso arm64 - run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm64.syso + run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ 
steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso - name: Generate windows syso amd64 - run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_amd64.syso + run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso - name: Run GoReleaser uses: goreleaser/goreleaser-action@v4 @@ -170,7 +170,7 @@ jobs: - name: Install goversioninfo run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e - name: Generate windows syso amd64 - run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/ui/resources_windows_amd64.syso + run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso - name: Run GoReleaser uses: 
goreleaser/goreleaser-action@v4 From 727a4f075311279e2f9004f51a6759bfa4f94fd7 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Fri, 2 Aug 2024 18:20:13 +0200 Subject: [PATCH 36/53] Remove Codacy badge as it is broken (#2379) --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 5be1826b4..98dd18d37 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,6 @@ -
From 501fd93e479cc58d4a447e84437aa52023145f80 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 2 Aug 2024 18:43:00 +0200 Subject: [PATCH 37/53] Fix DNS resolution for routes on iOS (#2378) --- client/internal/dns/server.go | 6 +- client/internal/dns/server_test.go | 2 +- client/internal/dns/service_memory.go | 20 +++---- client/internal/dns/upstream_ios.go | 38 +++++++------ client/internal/routemanager/client.go | 8 ++- client/internal/routemanager/dynamic/route.go | 19 ++++++- .../routemanager/dynamic/route_generic.go | 13 +++++ .../routemanager/dynamic/route_ios.go | 55 +++++++++++++++++++ 8 files changed, 124 insertions(+), 37 deletions(-) create mode 100644 client/internal/routemanager/dynamic/route_generic.go create mode 100644 client/internal/routemanager/dynamic/route_ios.go diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 267c1ed80..a4651ebb5 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -94,7 +94,7 @@ func NewDefaultServer( var dnsService service if wgInterface.IsUserspaceBind() { - dnsService = newServiceViaMemory(wgInterface) + dnsService = NewServiceViaMemory(wgInterface) } else { dnsService = newServiceViaListener(wgInterface, addrPort) } @@ -112,7 +112,7 @@ func NewDefaultServerPermanentUpstream( statusRecorder *peer.Status, ) *DefaultServer { log.Debugf("host dns address list is: %v", hostsDnsList) - ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true ds.addHostRootZone() @@ -130,7 +130,7 @@ func NewDefaultServerIos( iosDnsManager IosDnsManager, statusRecorder *peer.Status, ) *DefaultServer { - ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder) + ds := newDefaultServer(ctx, wgInterface, 
NewServiceViaMemory(wgInterface), statusRecorder) ds.iosDnsManager = iosDnsManager return ds } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 6cbd9ea15..b9552bc17 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -534,7 +534,7 @@ func TestDNSServerStartStop(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { hostManager := &mockHostConfigurator{} server := DefaultServer{ - service: newServiceViaMemory(&mocWGIface{}), + service: NewServiceViaMemory(&mocWGIface{}), localResolver: &localResolver{ registeredMap: make(registrationMap), }, diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 757cd962a..729b90cc0 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -12,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" ) -type serviceViaMemory struct { +type ServiceViaMemory struct { wgInterface WGIface dnsMux *dns.ServeMux runtimeIP string @@ -22,8 +22,8 @@ type serviceViaMemory struct { listenerFlagLock sync.Mutex } -func newServiceViaMemory(wgIface WGIface) *serviceViaMemory { - s := &serviceViaMemory{ +func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { + s := &ServiceViaMemory{ wgInterface: wgIface, dnsMux: dns.NewServeMux(), @@ -33,7 +33,7 @@ func newServiceViaMemory(wgIface WGIface) *serviceViaMemory { return s } -func (s *serviceViaMemory) Listen() error { +func (s *ServiceViaMemory) Listen() error { s.listenerFlagLock.Lock() defer s.listenerFlagLock.Unlock() @@ -52,7 +52,7 @@ func (s *serviceViaMemory) Listen() error { return nil } -func (s *serviceViaMemory) Stop() { +func (s *ServiceViaMemory) Stop() { s.listenerFlagLock.Lock() defer s.listenerFlagLock.Unlock() @@ -67,23 +67,23 @@ func (s *serviceViaMemory) Stop() { s.listenerIsRunning = false } -func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) { +func (s *ServiceViaMemory) 
RegisterMux(pattern string, handler dns.Handler) { s.dnsMux.Handle(pattern, handler) } -func (s *serviceViaMemory) DeregisterMux(pattern string) { +func (s *ServiceViaMemory) DeregisterMux(pattern string) { s.dnsMux.HandleRemove(pattern) } -func (s *serviceViaMemory) RuntimePort() int { +func (s *ServiceViaMemory) RuntimePort() int { return s.runtimePort } -func (s *serviceViaMemory) RuntimeIP() string { +func (s *ServiceViaMemory) RuntimeIP() string { return s.runtimeIP } -func (s *serviceViaMemory) filterDNSTraffic() (string, error) { +func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { filter := s.wgInterface.GetFilter() if filter == nil { return "", fmt.Errorf("can't set DNS filter, filter not initialized") diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 0c01a013e..60ed79d87 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -4,6 +4,7 @@ package dns import ( "context" + "fmt" "net" "syscall" "time" @@ -17,9 +18,9 @@ import ( type upstreamResolverIOS struct { *upstreamResolverBase - lIP net.IP - lNet *net.IPNet - iIndex int + lIP net.IP + lNet *net.IPNet + interfaceName string } func newUpstreamResolver( @@ -32,17 +33,11 @@ func newUpstreamResolver( ) (*upstreamResolverIOS, error) { upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) - index, err := getInterfaceIndex(interfaceName) - if err != nil { - log.Debugf("unable to get interface index for %s: %s", interfaceName, err) - return nil, err - } - ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, lIP: ip, lNet: net, - iIndex: index, + interfaceName: interfaceName, } ios.upstreamClient = ios @@ -53,7 +48,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * client := &dns.Client{} upstreamHost, _, err := net.SplitHostPort(upstream) if err != nil { - log.Errorf("error while parsing upstream host: %s", err) + return nil, 0, fmt.Errorf("error 
while parsing upstream host: %s", err) } timeout := upstreamTimeout @@ -65,26 +60,35 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * upstreamIP := net.ParseIP(upstreamHost) if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { log.Debugf("using private client to query upstream: %s", upstream) - client = u.getClientPrivate(timeout) + client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) + if err != nil { + return nil, 0, fmt.Errorf("error while creating private client: %s", err) + } } // Cannot use client.ExchangeContext because it overwrites our Dialer return client.Exchange(r, upstream) } -// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface +// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // This method is needed for iOS -func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client { +func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { + index, err := getInterfaceIndex(interfaceName) + if err != nil { + log.Debugf("unable to get interface index for %s: %s", interfaceName, err) + return nil, err + } + dialer := &net.Dialer{ LocalAddr: &net.UDPAddr{ - IP: u.lIP, + IP: ip, Port: 0, // Let the OS pick a free port }, Timeout: dialTimeout, Control: func(network, address string, c syscall.RawConn) error { var operr error fn := func(s uintptr) { - operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex) + operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index) } if err := c.Control(fn); err != nil { @@ -101,7 +105,7 @@ func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.C client := &dns.Client{ Dialer: dialer, } - return client + return client, nil } func getInterfaceIndex(interfaceName string) (int, error) { diff --git a/client/internal/routemanager/client.go 
b/client/internal/routemanager/client.go index 92c71b1e0..1566d10dd 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" @@ -65,7 +66,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration routePeersNotifiers: make(map[string]chan struct{}), routeUpdate: make(chan routesUpdate), peerStateUpdate: make(chan struct{}), - handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder), + handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface), } return client } @@ -383,9 +384,10 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { } } -func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler { +func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler { if rt.IsDynamic() { - return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder) + dns := nbdns.NewServiceViaMemory(wgInterface) + return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())) } return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) } diff --git 
a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index e95710798..3296f3ddf 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) @@ -47,6 +48,8 @@ type Route struct { currentPeerKey string cancel context.CancelFunc statusRecorder *peer.Status + wgInterface *iface.WGIface + resolverAddr string } func NewRoute( @@ -55,6 +58,8 @@ func NewRoute( allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, interval time.Duration, statusRecorder *peer.Status, + wgInterface *iface.WGIface, + resolverAddr string, ) *Route { return &Route{ route: rt, @@ -63,6 +68,8 @@ func NewRoute( interval: interval, dynamicDomains: domainMap{}, statusRecorder: statusRecorder, + wgInterface: wgInterface, + resolverAddr: resolverAddr, } } @@ -228,11 +235,17 @@ func (r *Route) resolve(results chan resolveResult) { wg.Add(1) go func(domain domain.Domain) { defer wg.Done() - ips, err := net.LookupIP(string(domain)) + + ips, err := r.getIPsFromResolver(domain) if err != nil { - results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)} - return + log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err) + ips, err = net.LookupIP(string(domain)) + if err != nil { + results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)} + return + } } + for _, ip := range ips { prefix, err := util.GetPrefixFromIP(ip) if err != nil { diff --git a/client/internal/routemanager/dynamic/route_generic.go 
b/client/internal/routemanager/dynamic/route_generic.go new file mode 100644 index 000000000..cf3d913a4 --- /dev/null +++ b/client/internal/routemanager/dynamic/route_generic.go @@ -0,0 +1,13 @@ +//go:build !ios + +package dynamic + +import ( + "net" + + "github.com/netbirdio/netbird/management/domain" +) + +func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { + return net.LookupIP(string(domain)) +} diff --git a/client/internal/routemanager/dynamic/route_ios.go b/client/internal/routemanager/dynamic/route_ios.go new file mode 100644 index 000000000..67138222f --- /dev/null +++ b/client/internal/routemanager/dynamic/route_ios.go @@ -0,0 +1,55 @@ +//go:build ios + +package dynamic + +import ( + "fmt" + "net" + "time" + + "github.com/miekg/dns" + + nbdns "github.com/netbirdio/netbird/client/internal/dns" + + "github.com/netbirdio/netbird/management/domain" +) + +const dialTimeout = 10 * time.Second + +func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { + privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout) + if err != nil { + return nil, fmt.Errorf("error while creating private client: %s", err) + } + + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA) + + startTime := time.Now() + + response, _, err := privateClient.Exchange(msg, r.resolverAddr) + if err != nil { + return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err) + } + + if response.Rcode != dns.RcodeSuccess { + return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode]) + } + + ips := make([]net.IP, 0) + + for _, answ := range response.Answer { + if aRecord, ok := answ.(*dns.A); ok { + ips = append(ips, aRecord.A) + } + if aaaaRecord, ok := answ.(*dns.AAAA); ok { + ips = append(ips, aaaaRecord.AAAA) + } + } + + if len(ips) == 0 { + return nil, fmt.Errorf("no A or AAAA records found for %s", 
domain.SafeString()) + } + + return ips, nil +} From 0371f529cae15c2429c8d18fb92a5cabaf77109f Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 2 Aug 2024 18:48:12 +0200 Subject: [PATCH 38/53] Add sonar badge (#2381) --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 98dd18d37..370445412 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,9 @@

+ + + From 059fc7c3a2182756ac8830353841b2e3747e22f8 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 3 Aug 2024 20:15:19 +0200 Subject: [PATCH 39/53] Use docker compose command (#2382) replace calls to docker-compose with docker compose --- .github/workflows/test-infrastructure-files.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index abdd18ceb..52b8ee3e2 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -151,10 +151,10 @@ jobs: - name: run docker compose up working-directory: infrastructure_files/artifacts run: | - docker-compose up -d + docker compose up -d sleep 5 - docker-compose ps - docker-compose logs --tail=20 + docker compose ps + docker compose logs --tail=20 - name: test running containers run: | @@ -207,7 +207,7 @@ jobs: - name: Postgres run cleanup run: | - docker-compose down --volumes --rmi all + docker compose down --volumes --rmi all rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env - name: run script with Zitadel CockroachDB From 6b930271fd87b1dd27f2074736194928a3b22161 Mon Sep 17 00:00:00 2001 From: Evgenii Date: Sun, 4 Aug 2024 21:13:08 +0100 Subject: [PATCH 40/53] change default config location on freebsd (#2388) --- client/cmd/root.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client/cmd/root.go b/client/cmd/root.go index 010ffb25a..db02ff5ea 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -93,12 +93,15 @@ func init() { oldDefaultConfigPathDir = "/etc/wiretrustee/" oldDefaultLogFileDir = "/var/log/wiretrustee/" - if runtime.GOOS == "windows" { + switch runtime.GOOS { + case "windows": defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\" defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\" 
oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" + case "freebsd": + defaultConfigPathDir = "/var/db/netbird/" } defaultConfigPath = defaultConfigPathDir + "config.json" From d56dfae9b883c57f150201bfaf913525f785a3fe Mon Sep 17 00:00:00 2001 From: Evgenii Date: Sun, 4 Aug 2024 21:31:43 +0100 Subject: [PATCH 41/53] Offer only Device Code Flow on FreeBSD (#2389) --- client/internal/auth/oauth.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 7467584a3..c9f10ca86 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -69,6 +69,11 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl return authenticateWithDeviceCodeFlow(ctx, config) } + // On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384) + if runtime.GOOS == "freebsd" { + return authenticateWithDeviceCodeFlow(ctx, config) + } + pkceFlow, err := authenticateWithPKCEFlow(ctx, config) if err != nil { // fallback to device code flow From 1802e51213df139f942527522e5cea828a99217e Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 5 Aug 2024 11:03:14 +0200 Subject: [PATCH 42/53] Fix windows binary version (#2390) --- .github/workflows/release.yml | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 89eb2ded9..30f24e92e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -79,15 +79,8 @@ jobs: - name: Install goversioninfo run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e - - name: Generate windows syso 386 - run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -file-version ${{ 
steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_386.syso - - name: Generate windows syso arm - run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm.syso - - name: Generate windows syso arm64 - run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso - name: Generate windows syso amd64 - run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso - + run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso - name: Run GoReleaser uses: goreleaser/goreleaser-action@v4 with: @@ -170,7 +163,7 @@ jobs: - name: Install goversioninfo run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e - name: Generate windows syso amd64 - run: goversioninfo -64 -icon client/ui/netbird.ico 
-manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso + run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso - name: Run GoReleaser uses: goreleaser/goreleaser-action@v4 From 855fba8fac230b54f585479eb84b4f9500996001 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 6 Aug 2024 22:30:19 +0200 Subject: [PATCH 43/53] On iOS add error handling for getRouteselector (#2394) --- client/ios/NetBirdSDK/client.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index d80072c78..779c27a4d 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -271,7 +271,14 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { } routesMap := engine.GetClientRoutesWithNetID() - routeSelector := engine.GetRouteManager().GetRouteSelector() + routeManager := engine.GetRouteManager() + if routeManager == nil { + return nil, fmt.Errorf("could not get route manager") + } + routeSelector := routeManager.GetRouteSelector() + if routeSelector == nil { + return nil, fmt.Errorf("could not get route selector") + } var routes []*selectRoute for id, rt := range routesMap { From 54d896846b15fd462414e3bb7b8e1b9f083741be Mon Sep 17 00:00:00 
2001 From: Maycon Santos Date: Wed, 7 Aug 2024 10:22:12 +0200 Subject: [PATCH 44/53] Skip network map check if not regular user (#2402) when getting all peers we don't need to calculate network map when not a regular user --- management/server/peer.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 998a9e53b..05fb11236 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -65,12 +65,14 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID peers := make([]*nbpeer.Peer, 0) peersMap := make(map[string]*nbpeer.Peer) - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + regularUser := !user.HasAdminPower() && !user.IsServiceUser + + if regularUser && account.Settings.RegularUsersViewBlocked { return peers, nil } for _, peer := range account.Peers { - if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != peer.UserID { + if regularUser && user.Id != peer.UserID { // only display peers that belong to the current user if the current user is not an admin continue } @@ -79,6 +81,10 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID peersMap[peer.ID] = p } + if !regularUser { + return peers, nil + } + // fetch all the peers that have access to the user's peers for _, peer := range peers { aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) From ac0d5ff9f3252bdc31514aa2e16ad59916b13a0e Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 7 Aug 2024 10:52:31 +0200 Subject: [PATCH 45/53] [management] Improve mgmt sync performance (#2363) --- client/cmd/testutil_test.go | 8 +- client/internal/engine_test.go | 7 +- client/server/server_test.go | 7 +- management/client/client_test.go | 9 +- management/cmd/management.go | 2 +- management/server/account.go | 88 ++++- 
management/server/account_test.go | 21 +- management/server/dns.go | 144 ++++--- management/server/dns_test.go | 158 +++++++- management/server/grpcserver.go | 81 ++-- management/server/http/peers_handler.go | 21 +- management/server/management_proto_test.go | 7 +- management/server/management_test.go | 10 +- management/server/nameserver_test.go | 7 +- management/server/peer.go | 46 ++- management/server/peer/peer.go | 3 +- management/server/peer/peer_test.go | 31 ++ management/server/peer_test.go | 362 ++++++++++++++++++ management/server/policy.go | 23 +- management/server/route_test.go | 7 +- .../telemetry/accountmanager_metrics.go | 69 ++++ management/server/telemetry/app_metrics.go | 68 ++-- 22 files changed, 1005 insertions(+), 174 deletions(-) create mode 100644 management/server/peer/peer_test.go create mode 100644 management/server/telemetry/accountmanager_metrics.go diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index b4c2791d8..984aa6df7 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -11,6 +11,7 @@ import ( "go.opentelemetry.io/otel" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" @@ -71,6 +72,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) { t.Helper() + lis, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) @@ -88,7 +90,11 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste return nil, nil } iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + require.NoError(t, err) + + accountManager, err := 
mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 6c6f79d07..e0f85d211 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -36,6 +36,7 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" signal "github.com/netbirdio/netbird/signal/client" "github.com/netbirdio/netbird/signal/proto" @@ -1069,7 +1070,11 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) return nil, "", err } ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + require.NoError(t, err) + + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index b19e4615f..6a3de774c 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -19,6 +19,7 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" ) @@ -120,7 +121,11 @@ func startManagement(t *testing.T, signalAddr 
string, counter *int) (*grpc.Serve return nil, "", err } ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + require.NoError(t, err) + + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics) if err != nil { return nil, "", err } diff --git a/management/client/client_test.go b/management/client/client_test.go index 2774f2b59..cec3e77f2 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -9,7 +9,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/client/system" @@ -71,7 +74,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + require.NoError(t, err) + + accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index b87c386c6..a0176c548 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -190,7 +190,7 @@ var ( return fmt.Errorf("failed to initialize integrated peer 
validator: %v", err) } accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, - dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator) + dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics) if err != nil { return fmt.Errorf("failed to build default manager: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 5d3ee6dc1..e99e0e7f3 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -18,6 +18,8 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" + "github.com/hashicorp/go-multierror" + "github.com/miekg/dns" gocache "github.com/patrickmn/go-cache" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -37,6 +39,7 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" ) @@ -170,6 +173,8 @@ type DefaultAccountManager struct { userDeleteFromIDPEnabled bool integratedPeerValidator integrated_validator.IntegratedValidator + + metrics telemetry.AppMetrics } // Settings represents Account settings structure that can be modified via API and Dashboard @@ -401,8 +406,16 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group { return a.Groups[groupID] } -// GetPeerNetworkMap returns a group by ID if exists, nil otherwise -func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { +// GetPeerNetworkMap returns the networkmap for the given peer ID. 
+func (a *Account) GetPeerNetworkMap( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + validatedPeersMap map[string]struct{}, + metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + start := time.Now() + peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ @@ -438,7 +451,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin if dnsManagementStatus { var zones []nbdns.CustomZone - peersCustomZone := getPeersCustomZone(ctx, a, dnsDomain) + if peersCustomZone.Domain != "" { zones = append(zones, peersCustomZone) } @@ -446,7 +459,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) } - return &NetworkMap{ + nm := &NetworkMap{ Peers: peersToConnect, Network: a.Network.Copy(), Routes: routesUpdate, @@ -454,6 +467,60 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin OfflinePeers: expiredPeers, FirewallRules: firewallRules, } + + if metrics != nil { + objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + } + + return nm +} + +func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone { + var merr *multierror.Error + + if dnsDomain == "" { + log.WithContext(ctx).Error("no dns domain is set, returning empty zone") + return nbdns.CustomZone{} + } + + customZone := nbdns.CustomZone{ + Domain: dns.Fqdn(dnsDomain), + Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)), + } + + domainSuffix := "." 
+ dnsDomain + + var sb strings.Builder + for _, peer := range a.Peers { + if peer.DNSLabel == "" { + merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name)) + continue + } + + sb.Grow(len(peer.DNSLabel) + len(domainSuffix)) + sb.WriteString(peer.DNSLabel) + sb.WriteString(domainSuffix) + + customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ + Name: sb.String(), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: defaultTTL, + RData: peer.IP.String(), + }) + + sb.Reset() + } + + go func() { + if merr != nil { + log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr) + } + }() + + return customZone } // GetExpiredPeers returns peers that have been expired @@ -871,10 +938,18 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { } // BuildManager creates a new DefaultAccountManager with a provided Store -func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, - singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation, +func BuildManager( + ctx context.Context, + store Store, + peersUpdateManager *PeersUpdateManager, + idpManager idp.Manager, + singleAccountModeDomain string, + dnsDomain string, + eventStore activity.Store, + geo *geolocation.Geolocation, userDeleteFromIDPEnabled bool, integratedPeerValidator integrated_validator.IntegratedValidator, + metrics telemetry.AppMetrics, ) (*DefaultAccountManager, error) { am := &DefaultAccountManager{ Store: store, @@ -889,6 +964,7 @@ func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpd peerLoginExpiry: NewDefaultScheduler(), userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, integratedPeerValidator: integratedPeerValidator, + metrics: metrics, } allAccounts := store.GetAllAccounts(ctx) // enable single account mode only if configured by user and number of existing 
accounts is not grater than 1 diff --git a/management/server/account_test.go b/management/server/account_test.go index 45b4fbd6f..03b5fa83e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" ) @@ -410,7 +411,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { validatedPeers[p] = struct{}{} } - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers) + customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -2293,7 +2295,13 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { }) } -func createManager(t *testing.T) (*DefaultAccountManager, error) { +type TB interface { + Cleanup(func()) + Helper() + TempDir() string +} + +func createManager(t TB) (*DefaultAccountManager, error) { t.Helper() store, err := createStore(t) @@ -2302,7 +2310,12 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { } eventStore := &activity.InMemoryEventStore{} - manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}) + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + return nil, err + } + + manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) if err 
!= nil { return nil, err } @@ -2310,7 +2323,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { return manager, nil } -func createStore(t *testing.T) (Store, error) { +func createStore(t TB) (Store, error) { t.Helper() dataDir := t.TempDir() store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) diff --git a/management/server/dns.go b/management/server/dns.go index 08732ad78..1d156c90a 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -4,8 +4,8 @@ import ( "context" "fmt" "strconv" + "sync" - "github.com/miekg/dns" log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" @@ -17,6 +17,50 @@ import ( const defaultTTL = 300 +// DNSConfigCache is a thread-safe cache for DNS configuration components +type DNSConfigCache struct { + CustomZones sync.Map + NameServerGroups sync.Map +} + +// GetCustomZone retrieves a cached custom zone +func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) { + if c == nil { + return nil, false + } + if value, ok := c.CustomZones.Load(key); ok { + return value.(*proto.CustomZone), true + } + return nil, false +} + +// SetCustomZone stores a custom zone in the cache +func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) { + if c == nil { + return + } + c.CustomZones.Store(key, value) +} + +// GetNameServerGroup retrieves a cached name server group +func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { + if c == nil { + return nil, false + } + if value, ok := c.NameServerGroups.Load(key); ok { + return value.(*proto.NameServerGroup), true + } + return nil, false +} + +// SetNameServerGroup stores a name server group in the cache +func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) { + if c == nil { + return + } + c.NameServerGroups.Store(key, value) +} + type lookupMap map[string]struct{} // DNSSettings defines dns settings at the account level @@ 
-113,69 +157,73 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID return nil } -func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig { - protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable} +// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache +func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { + protoUpdate := &proto.DNSConfig{ + ServiceEnable: update.ServiceEnable, + CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), + NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), + } for _, zone := range update.CustomZones { - protoZone := &proto.CustomZone{Domain: zone.Domain} - for _, record := range zone.Records { - protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ - Name: record.Name, - Type: int64(record.Type), - Class: record.Class, - TTL: int64(record.TTL), - RData: record.RData, - }) + cacheKey := zone.Domain + if cachedZone, exists := cache.GetCustomZone(cacheKey); exists { + protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone) + } else { + protoZone := convertToProtoCustomZone(zone) + cache.SetCustomZone(cacheKey, protoZone) + protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) } - protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) } for _, nsGroup := range update.NameServerGroups { - protoGroup := &proto.NameServerGroup{ - Primary: nsGroup.Primary, - Domains: nsGroup.Domains, - SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, + cacheKey := nsGroup.ID + if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists { + protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup) + } else { + protoGroup := convertToProtoNameServerGroup(nsGroup) + cache.SetNameServerGroup(cacheKey, protoGroup) + protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup) } - for _, ns := range 
nsGroup.NameServers { - protoNS := &proto.NameServer{ - IP: ns.IP.String(), - Port: int64(ns.Port), - NSType: int64(ns.NSType), - } - protoGroup.NameServers = append(protoGroup.NameServers, protoNS) - } - protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup) } return protoUpdate } -func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone { - if dnsDomain == "" { - log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone") - return nbdns.CustomZone{} +// Helper function to convert nbdns.CustomZone to proto.CustomZone +func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone { + protoZone := &proto.CustomZone{ + Domain: zone.Domain, + Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), } - - customZone := nbdns.CustomZone{ - Domain: dns.Fqdn(dnsDomain), - } - - for _, peer := range account.Peers { - if peer.DNSLabel == "" { - log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name) - continue - } - - customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ - Name: dns.Fqdn(peer.DNSLabel + "." 
+ dnsDomain), - Type: int(dns.TypeA), - Class: nbdns.DefaultClass, - TTL: defaultTTL, - RData: peer.IP.String(), + for _, record := range zone.Records { + protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ + Name: record.Name, + Type: int64(record.Type), + Class: record.Class, + TTL: int64(record.TTL), + RData: record.RData, }) } + return protoZone +} - return customZone +// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup +func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup { + protoGroup := &proto.NameServerGroup{ + Primary: nsGroup.Primary, + Domains: nsGroup.Domains, + SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, + NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)), + } + for _, ns := range nsGroup.NameServers { + protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{ + IP: ns.IP.String(), + Port: int64(ns.Port), + NSType: int64(ns.NSType), + }) + } + return protoGroup } func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index c6758036f..e033c1a21 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -2,9 +2,14 @@ package server import ( "context" + "fmt" "net/netip" + "reflect" "testing" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/dns" @@ -195,7 +200,11 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + require.NoError(t, err) + + return BuildManager(context.Background(), 
store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics) } func createDNSStore(t *testing.T) (Store, error) { @@ -320,3 +329,150 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro return am.Store.GetAccount(context.Background(), account.Id) } + +func generateTestData(size int) nbdns.Config { + config := nbdns.Config{ + ServiceEnable: true, + CustomZones: make([]nbdns.CustomZone, size), + NameServerGroups: make([]*nbdns.NameServerGroup, size), + } + + for i := 0; i < size; i++ { + config.CustomZones[i] = nbdns.CustomZone{ + Domain: fmt.Sprintf("domain%d.com", i), + Records: []nbdns.SimpleRecord{ + { + Name: fmt.Sprintf("record%d", i), + Type: 1, + Class: "IN", + TTL: 3600, + RData: "192.168.1.1", + }, + }, + } + + config.NameServerGroups[i] = &nbdns.NameServerGroup{ + ID: fmt.Sprintf("group%d", i), + Primary: i == 0, + Domains: []string{fmt.Sprintf("domain%d.com", i)}, + SearchDomainsEnabled: true, + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + Port: 53, + NSType: 1, + }, + }, + } + } + + return config +} + +func BenchmarkToProtocolDNSConfig(b *testing.B) { + sizes := []int{10, 100, 1000} + + for _, size := range sizes { + testData := generateTestData(size) + + b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) { + cache := &DNSConfigCache{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + toProtocolDNSConfig(testData, cache) + } + }) + + b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + cache := &DNSConfigCache{} + toProtocolDNSConfig(testData, cache) + } + }) + } +} + +func TestToProtocolDNSConfigWithCache(t *testing.T) { + var cache DNSConfigCache + + // Create two different configs + config1 := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "example.com", + Records: []nbdns.SimpleRecord{ + {Name: "www", Type: 1, Class: 
"IN", TTL: 300, RData: "192.168.1.1"}, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + ID: "group1", + Name: "Group 1", + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.8.8"), Port: 53}, + }, + }, + }, + } + + config2 := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "example.org", + Records: []nbdns.SimpleRecord{ + {Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"}, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + ID: "group2", + Name: "Group 2", + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.4.4"), Port: 53}, + }, + }, + }, + } + + // First run with config1 + result1 := toProtocolDNSConfig(config1, &cache) + + // Second run with config2 + result2 := toProtocolDNSConfig(config2, &cache) + + // Third run with config1 again + result3 := toProtocolDNSConfig(config1, &cache) + + // Verify that result1 and result3 are identical + if !reflect.DeepEqual(result1, result3) { + t.Errorf("Results are not identical when run with the same input. 
Expected %v, got %v", result1, result3) + } + + // Verify that result2 is different from result1 and result3 + if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) { + t.Errorf("Results should be different for different inputs") + } + + // Verify that the cache contains elements from both configs + if _, exists := cache.GetCustomZone("example.com"); !exists { + t.Errorf("Cache should contain custom zone for example.com") + } + + if _, exists := cache.GetCustomZone("example.org"); !exists { + t.Errorf("Cache should contain custom zone for example.org") + } + + if _, exists := cache.GetNameServerGroup("group1"); !exists { + t.Errorf("Cache should contain name server group 'group1'") + } + + if _, exists := cache.GetNameServerGroup("group2"); !exists { + t.Errorf("Cache should contain name server group 'group2'") + } +} diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index f71a45d99..7738abe5e 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -533,53 +533,46 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe } } -func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { - remotePeers := []*proto.RemotePeerConfig{} - for _, rPeer := range peers { - fqdn := rPeer.FQDN(dnsName) - remotePeers = append(remotePeers, &proto.RemotePeerConfig{ - WgPubKey: rPeer.Key, - AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)}, - SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, - Fqdn: fqdn, - }) - } - return remotePeers -} - -func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse { - wtConfig := toWiretrusteeConfig(config, turnCredentials) - - pConfig := toPeerConfig(peer, networkMap.Network, dnsName) - - remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName) - - 
routesUpdate := toProtocolRoutes(networkMap.Routes) - - dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig) - - offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName) - - firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) - - return &proto.SyncResponse{ - WiretrusteeConfig: wtConfig, - PeerConfig: pConfig, - RemotePeers: remotePeers, - RemotePeersIsEmpty: len(remotePeers) == 0, +func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse { + response := &proto.SyncResponse{ + WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName), NetworkMap: &proto.NetworkMap{ - Serial: networkMap.Network.CurrentSerial(), - PeerConfig: pConfig, - RemotePeers: remotePeers, - OfflinePeers: offlinePeers, - RemotePeersIsEmpty: len(remotePeers) == 0, - Routes: routesUpdate, - DNSConfig: dnsUpdate, - FirewallRules: firewallRules, - FirewallRulesIsEmpty: len(firewallRules) == 0, + Serial: networkMap.Network.CurrentSerial(), + Routes: toProtocolRoutes(networkMap.Routes), + DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache), }, Checks: toProtocolChecks(ctx, checks), } + + response.NetworkMap.PeerConfig = response.PeerConfig + + allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) + allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName) + response.RemotePeers = allPeers + response.NetworkMap.RemotePeers = allPeers + response.RemotePeersIsEmpty = len(allPeers) == 0 + response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty + + response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) + + firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) + response.NetworkMap.FirewallRules = firewallRules + 
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 + + return response +} + +func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { + for _, rPeer := range peers { + dst = append(dst, &proto.RemotePeerConfig{ + WgPubKey: rPeer.Key, + AllowedIps: []string{rPeer.IP.String() + "/32"}, + SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, + Fqdn: rPeer.FQDN(dnsName), + }) + } + return dst } // IsHealthy indicates whether the service is healthy @@ -597,7 +590,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p } else { turnCredentials = nil } - plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks) + plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 1fb18669c..913d424d1 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -71,7 +71,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee return } - netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers) + customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain()) + netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) _, valid := validPeers[peer.ID] @@ -115,7 +116,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, util.WriteError(ctx, fmt.Errorf("internal error"), w) return } - netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers) + + customZone := 
account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain()) + netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) _, valid := validPeers[peer.ID] @@ -194,9 +197,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { } groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) - accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID) - - respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers)) + respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } validPeersMap, err := h.accountManager.GetValidatedPeers(account) @@ -210,16 +211,6 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, respBody) } -func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) { - validatedPeersMap, err := h.accountManager.GetValidatedPeers(account) - if err != nil { - return 0, err - } - - netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap) - return len(netMap.Peers) + len(netMap.OfflinePeers), nil -} - func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { for _, peer := range respBody { _, ok := approvedPeersMap[peer.Id] diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 2c9d43948..fe1e36d47 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/formatter" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/telemetry" 
"github.com/netbirdio/netbird/util" ) @@ -419,8 +420,12 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccoun ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + require.NoError(t, err) + accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", - eventStore, nil, false, MocIntegratedValidator{}) + eventStore, nil, false, MocIntegratedValidator{}, metrics) + if err != nil { return nil, nil, "", err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 092567607..62e7f5a05 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -26,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" ) @@ -541,8 +542,13 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", - eventStore, nil, false, MocIntegratedValidator{}) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + log.Fatalf("failed creating metrics: %v", err) + } + + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) if err != nil { log.Fatalf("failed creating a manager: %v", err) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 
dd7935fee..5f8545243 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" ) const ( @@ -762,7 +763,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + require.NoError(t, err) + + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) } func createNSStore(t *testing.T) (Store, error) { diff --git a/management/server/peer.go b/management/server/peer.go index 05fb11236..7afe6ee0d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "strings" + "sync" "time" "github.com/rs/xid" @@ -322,7 +323,8 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin if err != nil { return nil, err } - return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil + customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil } // GetPeerNetwork returns the Network for a given peer @@ -535,7 +537,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } postureChecks := am.getPeerPostureChecks(account, peer) - networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap) + customZone := account.GetPeersCustomZone(ctx, 
am.dnsDomain) + networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) return newPeer, networkMap, postureChecks, nil } @@ -591,7 +594,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } postureChecks = am.getPeerPostureChecks(account, peer) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil + customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil } // LoginPeer logs in or registers a peer. @@ -738,7 +742,8 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is } postureChecks = am.getPeerPostureChecks(account, peer) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil + customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil } func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error { @@ -914,22 +919,45 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. 
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { + start := time.Now() + defer func() { + if am.metrics != nil { + am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start)) + } + }() + peers := account.GetPeers() approvedPeersMap, err := am.GetValidatedPeers(account) if err != nil { - log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err) + log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err) return } + + var wg sync.WaitGroup + semaphore := make(chan struct{}, 10) + + dnsCache := &DNSConfigCache{} + customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + for _, peer := range peers { if !am.peersUpdateManager.HasChannel(peer.ID) { log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) continue } - postureChecks := am.getPeerPostureChecks(account, peer) - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap) - update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) + wg.Add(1) + semaphore <- struct{}{} + go func(p *nbpeer.Peer) { + defer wg.Done() + defer func() { <-semaphore }() + + postureChecks := am.getPeerPostureChecks(account, p) + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) + update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) + am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) + }(peer) } + + wg.Wait() } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 4f808a79e..3d9ba18e9 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -1,7 +1,6 @@ package peer 
import ( - "fmt" "net" "net/netip" "slices" @@ -241,7 +240,7 @@ func (p *Peer) FQDN(dnsDomain string) string { if dnsDomain == "" { return "" } - return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain) + return p.DNSLabel + "." + dnsDomain } // EventMeta returns activity event meta related to the peer diff --git a/management/server/peer/peer_test.go b/management/server/peer/peer_test.go new file mode 100644 index 000000000..7b94f68c6 --- /dev/null +++ b/management/server/peer/peer_test.go @@ -0,0 +1,31 @@ +package peer + +import ( + "fmt" + "testing" +) + +// FQDNOld is the original implementation for benchmarking purposes +func (p *Peer) FQDNOld(dnsDomain string) string { + if dnsDomain == "" { + return "" + } + return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain) +} + +func BenchmarkFQDN(b *testing.B) { + p := &Peer{DNSLabel: "test-peer"} + dnsDomain := "example.com" + + b.Run("Old", func(b *testing.B) { + for i := 0; i < b.N; i++ { + p.FQDNOld(dnsDomain) + } + }) + + b.Run("New", func(b *testing.B) { + for i := 0; i < b.N; i++ { + p.FQDN(dnsDomain) + } + }) +} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 407877296..918436515 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -2,15 +2,26 @@ package server import ( "context" + "fmt" + "io" + "net" + "net/netip" + "os" "testing" "time" "github.com/rs/xid" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/management/proto" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + nbroute "github.com/netbirdio/netbird/route" ) func TestPeer_LoginExpired(t *testing.T) { @@ -633,3 +644,354 @@ func TestDefaultAccountManager_GetPeers(t 
*testing.T) { } } + +func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) { + b.Helper() + + manager, err := createManager(b) + if err != nil { + return nil, "", "", err + } + + accountID := "test_account" + adminUser := "account_creator" + regularUser := "regular_user" + + account := newAccountWithId(context.Background(), accountID, adminUser, "") + account.Users[regularUser] = &User{ + Id: regularUser, + Role: UserRoleUser, + } + + // Create peers + for i := 0; i < peers; i++ { + peerKey, _ := wgtypes.GeneratePrivateKey() + peer := &nbpeer.Peer{ + ID: fmt.Sprintf("peer-%d", i), + DNSLabel: fmt.Sprintf("peer-%d", i), + Key: peerKey.PublicKey().String(), + IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), + Status: &nbpeer.PeerStatus{}, + UserID: regularUser, + } + account.Peers[peer.ID] = peer + } + + // Create groups and policies + account.Policies = make([]*Policy, 0, groups) + for i := 0; i < groups; i++ { + groupID := fmt.Sprintf("group-%d", i) + group := &nbgroup.Group{ + ID: groupID, + Name: fmt.Sprintf("Group %d", i), + } + for j := 0; j < peers/groups; j++ { + peerIndex := i*(peers/groups) + j + group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) + } + account.Groups[groupID] = group + + // Create a policy for this group + policy := &Policy{ + ID: fmt.Sprintf("policy-%d", i), + Name: fmt.Sprintf("Policy for Group %d", i), + Enabled: true, + Rules: []*PolicyRule{ + { + ID: fmt.Sprintf("rule-%d", i), + Name: fmt.Sprintf("Rule for Group %d", i), + Enabled: true, + Sources: []string{groupID}, + Destinations: []string{groupID}, + Bidirectional: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + }, + }, + } + account.Policies = append(account.Policies, policy) + } + + account.PostureChecks = []*posture.Checks{ + { + ID: "PostureChecksAll", + Name: "All", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: 
"0.0.1", + }, + }, + }, + } + + err = manager.Store.SaveAccount(context.Background(), account) + if err != nil { + return nil, "", "", err + } + + return manager, accountID, regularUser, nil +} + +func BenchmarkGetPeers(b *testing.B) { + benchCases := []struct { + name string + peers int + groups int + }{ + {"Small", 50, 5}, + {"Medium", 500, 10}, + {"Large", 5000, 20}, + {"Small single", 50, 1}, + {"Medium single", 500, 1}, + {"Large 5", 5000, 5}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups) + if err != nil { + b.Fatalf("Failed to setup test account manager: %v", err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := manager.GetPeers(context.Background(), accountID, userID) + if err != nil { + b.Fatalf("GetPeers failed: %v", err) + } + } + }) + } +} + +func BenchmarkUpdateAccountPeers(b *testing.B) { + benchCases := []struct { + name string + peers int + groups int + }{ + {"Small", 50, 5}, + {"Medium", 500, 10}, + {"Large", 5000, 20}, + {"Small single", 50, 1}, + {"Medium single", 500, 1}, + {"Large 5", 5000, 5}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + if err != nil { + b.Fatalf("Failed to setup test account manager: %v", err) + } + + ctx := context.Background() + + account, err := manager.Store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + + peerChannels := make(map[string]chan *UpdateMessage) + + for peerID := range account.Peers { + peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + } + + manager.peersUpdateManager.peerChannels = peerChannels + + b.ResetTimer() + start := time.Now() + + for i := 0; i < b.N; i++ { + 
manager.updateAccountPeers(ctx, account) + } + + duration := time.Since(start) + b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op") + b.ReportMetric(0, "ns/op") + }) + } +} + +func TestToSyncResponse(t *testing.T) { + _, ipnet, err := net.ParseCIDR("192.168.1.0/24") + if err != nil { + t.Fatal(err) + } + domainList, err := domain.FromStringList([]string{"example.com"}) + if err != nil { + t.Fatal(err) + } + + config := &Config{ + Signal: &Host{ + Proto: "https", + URI: "signal.uri", + Username: "", + Password: "", + }, + Stuns: []*Host{{URI: "stun.uri", Proto: UDP}}, + TURNConfig: &TURNConfig{ + Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}}, + }, + } + peer := &nbpeer.Peer{ + IP: net.ParseIP("192.168.1.1"), + SSHEnabled: true, + Key: "peer-key", + DNSLabel: "peer1", + SSHKey: "peer1-ssh-key", + } + turnCredentials := &TURNCredentials{ + Username: "turn-user", + Password: "turn-pass", + } + networkMap := &NetworkMap{ + Network: &Network{Net: *ipnet, Serial: 1000}, + Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, + OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, + Routes: []*nbroute.Route{ + { + ID: "route1", + Network: netip.MustParsePrefix("10.0.0.0/24"), + Domains: domainList, + KeepRoute: true, + NetID: "route1", + Peer: "peer1", + NetworkType: 1, + Masquerade: true, + Metric: 9999, + Enabled: true, + }, + }, + DNSConfig: nbdns.Config{ + ServiceEnable: true, + NameServerGroups: []*nbdns.NameServerGroup{ + { + NameServers: []nbdns.NameServer{{ + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + Primary: true, + Domains: []string{"example.com"}, + Enabled: true, + SearchDomainsEnabled: true, + }, + { + ID: "ns1", + NameServers: []nbdns.NameServer{{ + IP: 
netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + Groups: []string{"group1"}, + Primary: true, + Domains: []string{"example.com"}, + Enabled: true, + SearchDomainsEnabled: true, + }, + }, + CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}}, + }, + FirewallRules: []*FirewallRule{ + {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, + }, + } + dnsName := "example.com" + checks := []*posture.Checks{ + { + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}}, + }, + }, + }, + } + dnsCache := &DNSConfigCache{} + + response := toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache) + + assert.NotNil(t, response) + // assert peer config + assert.Equal(t, "192.168.1.1/24", response.PeerConfig.Address) + assert.Equal(t, "peer1.example.com", response.PeerConfig.Fqdn) + assert.Equal(t, true, response.PeerConfig.SshConfig.SshEnabled) + // assert wiretrustee config + assert.Equal(t, "signal.uri", response.WiretrusteeConfig.Signal.Uri) + assert.Equal(t, proto.HostConfig_HTTPS, response.WiretrusteeConfig.Signal.GetProtocol()) + assert.Equal(t, "stun.uri", response.WiretrusteeConfig.Stuns[0].Uri) + assert.Equal(t, "turn.uri", response.WiretrusteeConfig.Turns[0].HostConfig.GetUri()) + assert.Equal(t, "turn-user", response.WiretrusteeConfig.Turns[0].User) + assert.Equal(t, "turn-pass", response.WiretrusteeConfig.Turns[0].Password) + // assert RemotePeers + assert.Equal(t, 1, len(response.RemotePeers)) + assert.Equal(t, "192.168.1.2/32", response.RemotePeers[0].AllowedIps[0]) + assert.Equal(t, "peer2-key", response.RemotePeers[0].WgPubKey) + assert.Equal(t, "peer2.example.com", 
response.RemotePeers[0].GetFqdn()) + assert.Equal(t, false, response.RemotePeers[0].GetSshConfig().GetSshEnabled()) + assert.Equal(t, []byte("peer2-ssh-key"), response.RemotePeers[0].GetSshConfig().GetSshPubKey()) + // assert network map + assert.Equal(t, uint64(1000), response.NetworkMap.Serial) + assert.Equal(t, "192.168.1.1/24", response.NetworkMap.PeerConfig.Address) + assert.Equal(t, "peer1.example.com", response.NetworkMap.PeerConfig.Fqdn) + assert.Equal(t, true, response.NetworkMap.PeerConfig.SshConfig.SshEnabled) + // assert network map RemotePeers + assert.Equal(t, 1, len(response.NetworkMap.RemotePeers)) + assert.Equal(t, "192.168.1.2/32", response.NetworkMap.RemotePeers[0].AllowedIps[0]) + assert.Equal(t, "peer2-key", response.NetworkMap.RemotePeers[0].WgPubKey) + assert.Equal(t, "peer2.example.com", response.NetworkMap.RemotePeers[0].GetFqdn()) + assert.Equal(t, []byte("peer2-ssh-key"), response.NetworkMap.RemotePeers[0].GetSshConfig().GetSshPubKey()) + // assert network map OfflinePeers + assert.Equal(t, 1, len(response.NetworkMap.OfflinePeers)) + assert.Equal(t, "192.168.1.3/32", response.NetworkMap.OfflinePeers[0].AllowedIps[0]) + assert.Equal(t, "peer3-key", response.NetworkMap.OfflinePeers[0].WgPubKey) + assert.Equal(t, "peer3.example.com", response.NetworkMap.OfflinePeers[0].GetFqdn()) + assert.Equal(t, []byte("peer3-ssh-key"), response.NetworkMap.OfflinePeers[0].GetSshConfig().GetSshPubKey()) + // assert network map Routes + assert.Equal(t, 1, len(response.NetworkMap.Routes)) + assert.Equal(t, "10.0.0.0/24", response.NetworkMap.Routes[0].Network) + assert.Equal(t, "route1", response.NetworkMap.Routes[0].ID) + assert.Equal(t, "peer1", response.NetworkMap.Routes[0].Peer) + assert.Equal(t, "example.com", response.NetworkMap.Routes[0].Domains[0]) + assert.Equal(t, true, response.NetworkMap.Routes[0].KeepRoute) + assert.Equal(t, true, response.NetworkMap.Routes[0].Masquerade) + assert.Equal(t, int64(9999), response.NetworkMap.Routes[0].Metric) + 
assert.Equal(t, int64(1), response.NetworkMap.Routes[0].NetworkType) + assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID) + // assert network map DNSConfig + assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable) + assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones)) + assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups)) + // assert network map DNSConfig.CustomZones + assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Domain) + assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones[0].Records)) + assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Name) + assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Type) + assert.Equal(t, "IN", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Class) + assert.Equal(t, int64(60), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].TTL) + assert.Equal(t, "100.64.0.1", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData) + // assert network map DNSConfig.NameServerGroups + assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].Primary) + assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].SearchDomainsEnabled) + assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.NameServerGroups[0].Domains[0]) + assert.Equal(t, "8.8.8.8", response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetIP()) + assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetNSType()) + assert.Equal(t, int64(53), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetPort()) + // assert network map Firewall + assert.Equal(t, 1, len(response.NetworkMap.FirewallRules)) + assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP) + assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction) + assert.Equal(t, proto.FirewallRule_ACCEPT, 
response.NetworkMap.FirewallRules[0].Action) + assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol) + assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port) + // assert posture checks + assert.Equal(t, 1, len(response.Checks)) + assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0]) +} diff --git a/management/server/policy.go b/management/server/policy.go index 30614ed2d..aaf9b6e72 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -213,7 +213,6 @@ type FirewallRule struct { // // This function returns the list of peers and firewall rules that are applicable to a given peer. func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { - generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) for _, policy := range a.Policies { if !policy.Enabled { @@ -225,8 +224,8 @@ func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, continue } - sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap) + sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) if rule.Bidirectional { if peerInSources { @@ -290,8 +289,8 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, fr.PeerIP = "0.0.0.0" } - ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) + - fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")) + ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + + fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") if _, ok := 
rulesExists[ruleID]; ok { continue } @@ -491,23 +490,23 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { // // Important: Posture checks are applicable only to source group peers, // for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { +func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { peerInGroups := false filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) for _, g := range groups { - group, ok := account.Groups[g] + group, ok := a.Groups[g] if !ok { continue } for _, p := range group.Peers { - peer, ok := account.Peers[p] + peer, ok := a.Peers[p] if !ok || peer == nil { continue } // validate the peer based on policy posture checks applied - isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) + isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) if !isValid { continue } @@ -535,7 +534,7 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture } for _, postureChecksID := range sourcePostureChecksID { - postureChecks := getPostureChecks(a, postureChecksID) + postureChecks := a.getPostureChecks(postureChecksID) if postureChecks == nil { continue } @@ -553,8 +552,8 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture return true } -func getPostureChecks(account *Account, postureChecksID string) *posture.Checks { - for _, postureChecks := range account.PostureChecks { +func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { + for _, postureChecks := range a.PostureChecks { if postureChecks.ID == postureChecksID { return 
postureChecks } diff --git a/management/server/route_test.go b/management/server/route_test.go index 8b168a79f..47dc4d078 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" ) @@ -1233,7 +1234,11 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + require.NoError(t, err) + + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) } func createRouterStore(t *testing.T) (Store, error) { diff --git a/management/server/telemetry/accountmanager_metrics.go b/management/server/telemetry/accountmanager_metrics.go new file mode 100644 index 000000000..e4bb4e3c3 --- /dev/null +++ b/management/server/telemetry/accountmanager_metrics.go @@ -0,0 +1,69 @@ +package telemetry + +import ( + "context" + "time" + + "go.opentelemetry.io/otel/metric" +) + +// AccountManagerMetrics represents all metrics related to the AccountManager +type AccountManagerMetrics struct { + ctx context.Context + updateAccountPeersDurationMs metric.Float64Histogram + getPeerNetworkMapDurationMs metric.Float64Histogram + networkMapObjectCount metric.Int64Histogram +} + +// NewAccountManagerMetrics creates an instance of AccountManagerMetrics +func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*AccountManagerMetrics, error) { + 
updateAccountPeersDurationMs, err := meter.Float64Histogram("management.account.update.account.peers.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithExplicitBucketBoundaries( + 0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 30000, + )) + if err != nil { + return nil, err + } + + getPeerNetworkMapDurationMs, err := meter.Float64Histogram("management.account.get.peer.network.map.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithExplicitBucketBoundaries( + 0.1, 0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000, + )) + if err != nil { + return nil, err + } + + networkMapObjectCount, err := meter.Int64Histogram("management.account.network.map.object.count", + metric.WithUnit("objects"), + metric.WithExplicitBucketBoundaries( + 50, 100, 200, 500, 1000, 2500, 5000, 10000, + )) + if err != nil { + return nil, err + } + + return &AccountManagerMetrics{ + ctx: ctx, + getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, + updateAccountPeersDurationMs: updateAccountPeersDurationMs, + networkMapObjectCount: networkMapObjectCount, + }, nil + +} + +// CountUpdateAccountPeersDuration counts the duration of updating account peers +func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) { + metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6) +} + +// CountGetPeerNetworkMapDuration counts the duration of getting the peer network map +func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration time.Duration) { + metrics.getPeerNetworkMapDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6) +} + +// CountNetworkMapObjects counts the number of network map objects +func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) { + metrics.networkMapObjectCount.Record(metrics.ctx, count) +} diff --git a/management/server/telemetry/app_metrics.go b/management/server/telemetry/app_metrics.go index d88e18d8a..09deb8127 
100644 --- a/management/server/telemetry/app_metrics.go +++ b/management/server/telemetry/app_metrics.go @@ -20,14 +20,15 @@ const defaultEndpoint = "/metrics" // MockAppMetrics mocks the AppMetrics interface type MockAppMetrics struct { - GetMeterFunc func() metric2.Meter - CloseFunc func() error - ExposeFunc func(ctx context.Context, port int, endpoint string) error - IDPMetricsFunc func() *IDPMetrics - HTTPMiddlewareFunc func() *HTTPMiddleware - GRPCMetricsFunc func() *GRPCMetrics - StoreMetricsFunc func() *StoreMetrics - UpdateChannelMetricsFunc func() *UpdateChannelMetrics + GetMeterFunc func() metric2.Meter + CloseFunc func() error + ExposeFunc func(ctx context.Context, port int, endpoint string) error + IDPMetricsFunc func() *IDPMetrics + HTTPMiddlewareFunc func() *HTTPMiddleware + GRPCMetricsFunc func() *GRPCMetrics + StoreMetricsFunc func() *StoreMetrics + UpdateChannelMetricsFunc func() *UpdateChannelMetrics + AddAccountManagerMetricsFunc func() *AccountManagerMetrics } // GetMeter mocks the GetMeter function of the AppMetrics interface @@ -94,6 +95,14 @@ func (mock *MockAppMetrics) UpdateChannelMetrics() *UpdateChannelMetrics { return nil } +// AccountManagerMetrics mocks the MockAppMetrics function of the AccountManagerMetrics interface +func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics { + if mock.AddAccountManagerMetricsFunc != nil { + return mock.AddAccountManagerMetricsFunc() + } + return nil +} + // AppMetrics is metrics interface type AppMetrics interface { GetMeter() metric2.Meter @@ -104,19 +113,21 @@ type AppMetrics interface { GRPCMetrics() *GRPCMetrics StoreMetrics() *StoreMetrics UpdateChannelMetrics() *UpdateChannelMetrics + AccountManagerMetrics() *AccountManagerMetrics } // defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/ type defaultAppMetrics struct { // Meter can be used by different application parts to create counters and measure things - Meter metric2.Meter - 
listener net.Listener - ctx context.Context - idpMetrics *IDPMetrics - httpMiddleware *HTTPMiddleware - grpcMetrics *GRPCMetrics - storeMetrics *StoreMetrics - updateChannelMetrics *UpdateChannelMetrics + Meter metric2.Meter + listener net.Listener + ctx context.Context + idpMetrics *IDPMetrics + httpMiddleware *HTTPMiddleware + grpcMetrics *GRPCMetrics + storeMetrics *StoreMetrics + updateChannelMetrics *UpdateChannelMetrics + accountManagerMetrics *AccountManagerMetrics } // IDPMetrics returns metrics for the idp package @@ -144,6 +155,11 @@ func (appMetrics *defaultAppMetrics) UpdateChannelMetrics() *UpdateChannelMetric return appMetrics.updateChannelMetrics } +// AccountManagerMetrics returns metrics for the account manager +func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetrics { + return appMetrics.accountManagerMetrics +} + // Close stop application metrics HTTP handler and closes listener. func (appMetrics *defaultAppMetrics) Close() error { if appMetrics.listener == nil { @@ -220,13 +236,19 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) { return nil, err } + accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter) + if err != nil { + return nil, err + } + return &defaultAppMetrics{ - Meter: meter, - ctx: ctx, - idpMetrics: idpMetrics, - httpMiddleware: middleware, - grpcMetrics: grpcMetrics, - storeMetrics: storeMetrics, - updateChannelMetrics: updateChannelMetrics, + Meter: meter, + ctx: ctx, + idpMetrics: idpMetrics, + httpMiddleware: middleware, + grpcMetrics: grpcMetrics, + storeMetrics: storeMetrics, + updateChannelMetrics: updateChannelMetrics, + accountManagerMetrics: accountManagerMetrics, }, nil } From bcce1bf18474e68744d81a6b05b8a32e901ae771 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 7 Aug 2024 15:40:43 +0300 Subject: [PATCH 46/53] Update dependencies and switch systray library (#2309) * Update dependencies and switch systray library This commit updates the project's 
dependencies and switches from the 'getlantern/systray' library to the 'fyne.io/systray' library. It also removes some unused dependencies, improving the maintainability and performance of the project. This change in the system tray library is an upgrade which offers more extensive features and better support. * Remove legacy_appindicator tag from .goreleaser_ui.yaml --- .goreleaser_ui.yaml | 2 -- client/ui/client_ui.go | 2 +- go.mod | 10 +--------- go.sum | 21 ++------------------- 4 files changed, 4 insertions(+), 31 deletions(-) diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml index b13085e86..fd92b5328 100644 --- a/.goreleaser_ui.yaml +++ b/.goreleaser_ui.yaml @@ -11,8 +11,6 @@ builds: - amd64 ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - tags: - - legacy_appindicator mod_timestamp: '{{ .CommitTimestamp }}' - id: netbird-ui-windows diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 58004dd4a..d046bab5f 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -22,8 +22,8 @@ import ( "fyne.io/fyne/v2/app" "fyne.io/fyne/v2/dialog" "fyne.io/fyne/v2/widget" + "fyne.io/systray" "github.com/cenkalti/backoff/v4" - "github.com/getlantern/systray" log "github.com/sirupsen/logrus" "github.com/skratchdot/open-golang/open" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" diff --git a/go.mod b/go.mod index 7f242c203..25e726769 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( require ( fyne.io/fyne/v2 v2.1.4 + fyne.io/systray v1.11.0 github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/c-robinson/iplib v1.0.3 github.com/cilium/ebpf v0.15.0 @@ -38,7 +39,6 @@ require ( github.com/creack/pty v1.1.18 github.com/eko/gocache/v3 v3.1.1 github.com/fsnotify/fsnotify v1.6.0 - github.com/getlantern/systray v1.2.1 github.com/gliderlabs/ssh v0.3.4 github.com/godbus/dbus/v5 v5.1.0 github.com/golang/mock v1.6.0 @@ -120,19 
+120,12 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fredbi/uri v0.0.0-20181227131451-3dcfdacbaaf3 // indirect - github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect - github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect - github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect - github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect - github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect - github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect - github.com/go-stack/stack v1.8.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect @@ -164,7 +157,6 @@ require ( github.com/nxadm/tail v1.4.8 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect - github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect github.com/pegasus-kv/thrift v0.13.0 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect github.com/pion/mdns v0.0.12 // indirect diff --git a/go.sum b/go.sum index 1251a6fd7..2e2cadd43 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= fyne.io/fyne/v2 v2.1.4 h1:bt1+28++kAzRzPB0GM2EuSV4cnl8rXNX4cjfd8G06Rc= fyne.io/fyne/v2 v2.1.4/go.mod 
h1:p+E/Dh+wPW8JwR2DVcsZ9iXgR9ZKde80+Y+40Is54AQ= +fyne.io/systray v1.11.0 h1:D9HISlxSkx+jHSniMBR6fCFOUjk1x/OOOJLa9lJYAKg= +fyne.io/systray v1.11.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= @@ -111,18 +113,6 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= -github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 h1:NRUJuo3v3WGC/g5YiyF790gut6oQr5f3FBI88Wv0dx4= -github.com/getlantern/context v0.0.0-20190109183933-c447772a6520/go.mod h1:L+mq6/vvYHKjCX2oez0CgEAJmbq1fbb/oNJIWQkBybY= -github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 h1:6uJ+sZ/e03gkbqZ0kUG6mfKoqDb4XMAzMIwlajq19So= -github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7/go.mod h1:l+xpFBrCtDLpK9qNjxs+cHU6+BAdlBaxHqikB6Lku3A= -github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 h1:guBYzEaLz0Vfc/jv0czrr2z7qyzTOGC9hiQ0VC+hKjk= -github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7/go.mod h1:zx/1xUUeYPy3Pcmet8OSXLbF47l+3y6hIPpyLWoR9oc= -github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 h1:micT5vkcr9tOVk1FiH8SWKID8ultN44Z+yzd2y/Vyb0= -github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7/go.mod h1:dD3CgOrwlzca8ed61CsZouQS5h5jIzkK9ZWrTcf0s+o= -github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 h1:XYzSdCbkzOC0FDNrgJqGRo8PCMFOBFL9py72DRs7bmc= -github.com/getlantern/hidden 
v0.0.0-20190325191715-f02dbb02be55/go.mod h1:6mmzY2kW1TOOrVy+r41Za2MxXM+hhqTtY3oBKd2AgFA= -github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f h1:wrYrQttPS8FHIRSlsrcuKazukx/xqO/PpLZzZXsF+EA= -github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f/go.mod h1:D5ao98qkA6pxftxoqzibIBBrLSUli+kYnJqrgBf9cIA= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do= @@ -151,8 +141,6 @@ github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7 github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= -github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= @@ -337,8 +325,6 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513- github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 
h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= -github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= @@ -368,8 +354,6 @@ github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQ github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs= github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY= -github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c h1:rp5dCmg/yLR3mgFuSOe4oEnDDmGLROTvMragMUXpTQw= -github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c/go.mod h1:X07ZCGwUbLaax7L0S3Tw4hpejzu63ZrrQiUe6W0hcy0= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pegasus-kv/thrift v0.13.0 h1:4ESwaNoHImfbHa9RUGJiJZ4hrxorihZHk5aarYwY8d4= @@ -609,7 +593,6 @@ golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= From 0911163146d2894680394435c48babb1a1529421 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Thu, 8 Aug 2024 18:01:38 +0300 Subject: [PATCH 47/53] Add batch delete for groups and users (#2370) * Refactor user deletion logic and introduce batch delete * Prevent self-deletion for users * Add delete multiple groups * Refactor group deletion with validation * Fix tests * Add bulk delete functions for Users and Groups in account manager interface and mocks * Add tests for DeleteGroups method in group management * Add tests for DeleteUsers method in users management --- management/server/account.go | 2 + management/server/group.go | 239 +++++++++++------- management/server/group_test.go | 145 ++++++++++- management/server/mock_server/account_mock.go | 18 ++ management/server/user.go | 146 ++++++++--- management/server/user_test.go | 151 +++++++++++ 6 files changed, 575 insertions(+), 126 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index e99e0e7f3..ca53ebad0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -68,6 +68,7 @@ type AccountManager interface { SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error + DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error ListSetupKeys(ctx context.Context, accountID, 
userID string) ([]*SetupKey, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) @@ -101,6 +102,7 @@ type AccountManager interface { SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error + DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error diff --git a/management/server/group.go b/management/server/group.go index 37a6fc305..49720f347 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -2,8 +2,12 @@ package server import ( "context" + "errors" "fmt" + "slices" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/route" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -243,7 +247,7 @@ func difference(a, b []string) []string { return diff } -// DeleteGroup object of the peers +// DeleteGroup object of the peers. 
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) defer unlock() @@ -253,96 +257,14 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use return err } - g, ok := account.Groups[groupID] + group, ok := account.Groups[groupID] if !ok { return nil } - // disable a deleting integration group if the initiator is not an admin service user - if g.Issued == nbgroup.GroupIssuedIntegration { - executingUser := account.Users[userId] - if executingUser == nil { - return status.Errorf(status.NotFound, "user not found") - } - if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { - return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") - } + if err = validateDeleteGroup(account, group, userId); err != nil { + return err } - - // check route links - for _, r := range account.Routes { - for _, g := range r.Groups { - if g == groupID { - return &GroupLinkError{"route", string(r.NetID)} - } - } - for _, g := range r.PeerGroups { - if g == groupID { - return &GroupLinkError{"route", string(r.NetID)} - } - } - } - - // check DNS links - for _, dns := range account.NameServerGroups { - for _, g := range dns.Groups { - if g == groupID { - return &GroupLinkError{"name server groups", dns.Name} - } - } - } - - // check ACL links - for _, policy := range account.Policies { - for _, rule := range policy.Rules { - for _, src := range rule.Sources { - if src == groupID { - return &GroupLinkError{"policy", policy.Name} - } - } - - for _, dst := range rule.Destinations { - if dst == groupID { - return &GroupLinkError{"policy", policy.Name} - } - } - } - } - - // check setup key links - for _, setupKey := range account.SetupKeys { - for _, grp := range setupKey.AutoGroups { - if grp == groupID { - return &GroupLinkError{"setup key", setupKey.Name} - } - } - } - - // check user 
links - for _, user := range account.Users { - for _, grp := range user.AutoGroups { - if grp == groupID { - return &GroupLinkError{"user", user.Id} - } - } - } - - // check DisabledManagementGroups - for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups { - if disabledMgmGrp == groupID { - return &GroupLinkError{"disabled DNS management groups", g.Name} - } - } - - // check integrated peer validator groups - if account.Settings.Extra != nil { - for _, integratedPeerValidatorGroups := range account.Settings.Extra.IntegratedValidatorGroups { - if groupID == integratedPeerValidatorGroups { - return &GroupLinkError{"integrated validator", g.Name} - } - } - } - delete(account.Groups, groupID) account.Network.IncSerial() @@ -350,13 +272,57 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use return err } - am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta()) + am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) am.updateAccountPeers(ctx, account) return nil } +// DeleteGroups deletes groups from an account. +// Note: This function does not acquire the global lock. +// It is the caller's responsibility to ensure proper locking is in place before invoking this method. +// +// If an error occurs while deleting a group, the function skips it and continues deleting other groups. +// Errors are collected and returned at the end. 
+func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { + account, err := am.Store.GetAccount(ctx, accountId) + if err != nil { + return err + } + + var allErrors error + + deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs)) + for _, groupID := range groupIDs { + group, ok := account.Groups[groupID] + if !ok { + continue + } + + if err := validateDeleteGroup(account, group, userId); err != nil { + allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) + continue + } + + delete(account.Groups, groupID) + deletedGroups = append(deletedGroups, group) + } + + account.Network.IncSerial() + if err = am.Store.SaveAccount(ctx, account); err != nil { + return err + } + + for _, g := range deletedGroups { + am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) + } + + am.updateAccountPeers(ctx, account) + + return allErrors +} + // ListGroups objects of the peers func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -440,3 +406,102 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return nil } + +func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { + // disable a deleting integration group if the initiator is not an admin service user + if group.Issued == nbgroup.GroupIssuedIntegration { + executingUser := account.Users[userID] + if executingUser == nil { + return status.Errorf(status.NotFound, "user not found") + } + if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { + return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") + } + } + + if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { + return &GroupLinkError{"route", 
string(linkedRoute.NetID)} + } + + if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { + return &GroupLinkError{"name server groups", linkedDns.Name} + } + + if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { + return &GroupLinkError{"policy", linkedPolicy.Name} + } + + if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { + return &GroupLinkError{"setup key", linkedSetupKey.Name} + } + + if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { + return &GroupLinkError{"user", linkedUser.Id} + } + + if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { + return &GroupLinkError{"disabled DNS management groups", group.Name} + } + + if account.Settings.Extra != nil { + if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { + return &GroupLinkError{"integrated validator", group.Name} + } + } + + return nil +} + +// isGroupLinkedToRoute checks if a group is linked to any route in the account. +func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { + for _, r := range routes { + if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { + return true, r + } + } + return false, nil +} + +// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. +func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { + for _, policy := range policies { + for _, rule := range policy.Rules { + if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { + return true, policy + } + } + } + return false, nil +} + +// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. 
+func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { + for _, dns := range nameServerGroups { + for _, g := range dns.Groups { + if g == groupID { + return true, dns + } + } + } + return false, nil +} + +// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. +func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { + for _, setupKey := range setupKeys { + if slices.Contains(setupKey.AutoGroups, groupID) { + return true, setupKey + } + } + return false, nil +} + +// isGroupLinkedToUser checks if a group is linked to any user in the account. +func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { + for _, user := range users { + if slices.Contains(user.AutoGroups, groupID) { + return true, user + } + } + return false, nil +} diff --git a/management/server/group_test.go b/management/server/group_test.go index 373d72964..89b68ad6c 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -3,12 +3,14 @@ package server import ( "context" "errors" + "fmt" "testing" nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" + "github.com/stretchr/testify/assert" ) const ( @@ -21,7 +23,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestGroupAccount(am) + _, account, err := initTestGroupAccount(am) if err != nil { t.Error("failed to init testing account") } @@ -56,7 +58,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestGroupAccount(am) + _, account, err := initTestGroupAccount(am) if err != nil { t.Error("failed to init testing account") } @@ -132,7 +134,136 @@ func 
TestDefaultAccountManager_DeleteGroup(t *testing.T) { } } -func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { +func TestDefaultAccountManager_DeleteGroups(t *testing.T) { + am, err := createManager(t) + assert.NoError(t, err, "Failed to create account manager") + + manager, account, err := initTestGroupAccount(am) + assert.NoError(t, err, "Failed to init testing account") + + groups := make([]*nbgroup.Group, 10) + for i := 0; i < 10; i++ { + groups[i] = &nbgroup.Group{ + ID: fmt.Sprintf("group-%d", i+1), + AccountID: account.Id, + Name: fmt.Sprintf("group-%d", i+1), + Issued: nbgroup.GroupIssuedAPI, + } + } + + err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups) + assert.NoError(t, err, "Failed to save test groups") + + testCases := []struct { + name string + groupIDs []string + expectedReasons []string + expectedDeleted []string + expectedNotDeleted []string + }{ + { + name: "route", + groupIDs: []string{"grp-for-route"}, + expectedReasons: []string{"route"}, + }, + { + name: "route with peer groups", + groupIDs: []string{"grp-for-route2"}, + expectedReasons: []string{"route"}, + }, + { + name: "name server groups", + groupIDs: []string{"grp-for-name-server-grp"}, + expectedReasons: []string{"name server groups"}, + }, + { + name: "policy", + groupIDs: []string{"grp-for-policies"}, + expectedReasons: []string{"policy"}, + }, + { + name: "setup keys", + groupIDs: []string{"grp-for-keys"}, + expectedReasons: []string{"setup key"}, + }, + { + name: "users", + groupIDs: []string{"grp-for-users"}, + expectedReasons: []string{"user"}, + }, + { + name: "integration", + groupIDs: []string{"grp-for-integration"}, + expectedReasons: []string{"only service users with admin power can delete integration group"}, + }, + { + name: "successfully delete multiple groups", + groupIDs: []string{"group-1", "group-2"}, + expectedDeleted: []string{"group-1", "group-2"}, + }, + { + name: "delete non-existent group", + groupIDs: 
[]string{"non-existent-group"}, + expectedDeleted: []string{"non-existent-group"}, + }, + { + name: "delete multiple groups with mixed results", + groupIDs: []string{"group-3", "grp-for-policies", "group-4", "grp-for-users"}, + expectedReasons: []string{"policy", "user"}, + expectedDeleted: []string{"group-3", "group-4"}, + expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"}, + }, + { + name: "delete groups with multiple errors", + groupIDs: []string{"grp-for-policies", "grp-for-users"}, + expectedReasons: []string{"policy", "user"}, + expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, tc.groupIDs) + if len(tc.expectedReasons) > 0 { + assert.Error(t, err) + var foundExpectedErrors int + + wrappedErr, ok := err.(interface{ Unwrap() []error }) + assert.Equal(t, ok, true) + + for _, e := range wrappedErr.Unwrap() { + var sErr *status.Error + if errors.As(e, &sErr) { + assert.Contains(t, tc.expectedReasons, sErr.Message, "unexpected error message") + foundExpectedErrors++ + } + + var gErr *GroupLinkError + if errors.As(e, &gErr) { + assert.Contains(t, tc.expectedReasons, gErr.Resource, "unexpected error resource") + foundExpectedErrors++ + } + } + assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found") + } else { + assert.NoError(t, err) + } + + for _, groupID := range tc.expectedDeleted { + _, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID) + assert.Error(t, err, "group should have been deleted: %s", groupID) + } + + for _, groupID := range tc.expectedNotDeleted { + group, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID) + assert.NoError(t, err, "group should not have been deleted: %s", groupID) + assert.NotNil(t, group, "group should exist: %s", groupID) + } + }) + } +} + +func 
initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) { accountID := "testingAcc" domain := "example.com" @@ -236,7 +367,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { err := am.Store.SaveAccount(context.Background(), account) if err != nil { - return nil, err + return nil, nil, err } _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) @@ -247,5 +378,9 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers) _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration) - return am.Store.GetAccount(context.Background(), account.Id) + acc, err := am.Store.GetAccount(context.Background(), account.Id) + if err != nil { + return nil, nil, err + } + return am, acc, nil } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index a66bdee2b..495325252 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -42,6 +42,7 @@ type MockAccountManager struct { SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error + DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error @@ -67,6 +68,7 @@ type MockAccountManager struct { SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, 
error) SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error + DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) @@ -326,6 +328,14 @@ func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented") } +// DeleteGroups mock implementation of DeleteGroups from server.AccountManager interface +func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { + if am.DeleteGroupsFunc != nil { + return am.DeleteGroupsFunc(ctx, accountId, userId, groupIDs) + } + return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") +} + // ListGroups mock implementation of ListGroups from server.AccountManager interface func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) { if am.ListGroupsFunc != nil { @@ -528,6 +538,14 @@ func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented") } +// DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface +func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID 
string, initiatorUserID string, targetUserIDs []string) error { + if am.DeleteRegularUsersFunc != nil { + return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs) + } + return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented") +} + func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { if am.InviteUserFunc != nil { return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID) diff --git a/management/server/user.go b/management/server/user.go index b8afcda3a..8fd0a1627 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -2,6 +2,7 @@ package server import ( "context" + "errors" "fmt" "strings" "time" @@ -472,51 +473,18 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init } func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error { - tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) - if err != nil { - log.WithContext(ctx).Errorf("failed to resolve email address: %s", err) - return err - } - - if !isNil(am.idpManager) { - // Delete if the user already exists in the IdP.Necessary in cases where a user account - // was created where a user account was provisioned but the user did not sign in - _, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id}) - if err == nil { - err = am.deleteUserFromIDP(ctx, targetUserID, account.Id) - if err != nil { - log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID) - return err - } - } else { - log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err) - } - } - - err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) + meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, 
targetUserID) if err != nil { return err } - u, err := account.FindUser(targetUserID) - if err != nil { - log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err) - } - - var tuCreatedAt time.Time - if u != nil { - tuCreatedAt = u.CreatedAt - } - delete(account.Users, targetUserID) err = am.Store.SaveAccount(ctx, account) if err != nil { return err } - meta := map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt} am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) - am.updateAccountPeers(ctx, account) return nil @@ -1190,6 +1158,116 @@ func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(ctx context.Context return "", "", fmt.Errorf("user info not found for user: %s", targetId) } +// DeleteRegularUsers deletes regular users from an account. +// Note: This function does not acquire the global lock. +// It is the caller's responsibility to ensure proper locking is in place before invoking this method. +// +// If an error occurs while deleting the user, the function skips it and continues deleting other users. +// Errors are collected and returned at the end. 
+func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error { + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return err + } + + executingUser := account.Users[initiatorUserID] + if executingUser == nil { + return status.Errorf(status.NotFound, "user not found") + } + if !executingUser.HasAdminPower() { + return status.Errorf(status.PermissionDenied, "only users with admin power can delete users") + } + + var allErrors error + + deletedUsersMeta := make(map[string]map[string]any) + for _, targetUserID := range targetUserIDs { + if initiatorUserID == targetUserID { + allErrors = errors.Join(allErrors, errors.New("self deletion is not allowed")) + continue + } + + targetUser := account.Users[targetUserID] + if targetUser == nil { + allErrors = errors.Join(allErrors, fmt.Errorf("target user: %s not found", targetUserID)) + continue + } + + if targetUser.Role == UserRoleOwner { + allErrors = errors.Join(allErrors, fmt.Errorf("unable to delete a user: %s with owner role", targetUserID)) + continue + } + + // disable deleting integration user if the initiator is not admin service user + if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser { + allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user")) + continue + } + + meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) + if err != nil { + allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err)) + continue + } + + delete(account.Users, targetUserID) + deletedUsersMeta[targetUserID] = meta + } + + err = am.Store.SaveAccount(ctx, account) + if err != nil { + return fmt.Errorf("failed to delete users: %w", err) + } + + am.updateAccountPeers(ctx, account) + + for targetUserID, meta := range deletedUsersMeta { + am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, 
activity.UserDeleted, meta) + } + + return allErrors +} + +func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) { + tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) + if err != nil { + log.WithContext(ctx).Errorf("failed to resolve email address: %s", err) + return nil, err + } + + if !isNil(am.idpManager) { + // Delete if the user already exists in the IdP. Necessary in cases where a user account + // was created where a user account was provisioned but the user did not sign in + _, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id}) + if err == nil { + err = am.deleteUserFromIDP(ctx, targetUserID, account.Id) + if err != nil { + log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID) + return nil, err + } + } else { + log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err) + } + } + + err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) + if err != nil { + return nil, err + } + + u, err := account.FindUser(targetUserID) + if err != nil { + log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err) + } + + var tuCreatedAt time.Time + if u != nil { + tuCreatedAt = u.CreatedAt + } + + return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil +} + func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { for _, user := range userData { if user.ID == userID { diff --git a/management/server/user_test.go b/management/server/user_test.go index 99d2792df..272060276 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -662,6 +662,157 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } +func TestUser_DeleteUser_RegularUsers(t *testing.T) { + 
store := newStore(t) + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") + + targetId := "user2" + account.Users[targetId] = &User{ + Id: targetId, + IsServiceUser: true, + ServiceUserName: "user2username", + } + targetId = "user3" + account.Users[targetId] = &User{ + Id: targetId, + IsServiceUser: false, + Issued: UserIssuedAPI, + } + targetId = "user4" + account.Users[targetId] = &User{ + Id: targetId, + IsServiceUser: false, + Issued: UserIssuedIntegration, + } + + targetId = "user5" + account.Users[targetId] = &User{ + Id: targetId, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleOwner, + } + account.Users["user6"] = &User{ + Id: "user6", + IsServiceUser: false, + Issued: UserIssuedAPI, + } + account.Users["user7"] = &User{ + Id: "user7", + IsServiceUser: false, + Issued: UserIssuedAPI, + } + account.Users["user8"] = &User{ + Id: "user8", + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleAdmin, + } + account.Users["user9"] = &User{ + Id: "user9", + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleAdmin, + } + + err := store.SaveAccount(context.Background(), account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + integratedPeerValidator: MocIntegratedValidator{}, + } + + testCases := []struct { + name string + userIDs []string + expectedReasons []string + expectedDeleted []string + expectedNotDeleted []string + }{ + { + name: "Delete service user successfully ", + userIDs: []string{"user2"}, + expectedDeleted: []string{"user2"}, + }, + { + name: "Delete regular user successfully", + userIDs: []string{"user3"}, + expectedDeleted: []string{"user3"}, + }, + { + name: "Delete integration regular user permission denied", + userIDs: []string{"user4"}, + expectedReasons: []string{"only integration service user can delete this user"}, 
+ expectedNotDeleted: []string{"user4"}, + }, + { + name: "Delete user with owner role should return permission denied", + userIDs: []string{"user5"}, + expectedReasons: []string{"unable to delete a user: user5 with owner role"}, + expectedNotDeleted: []string{"user5"}, + }, + { + name: "Delete multiple users with mixed results", + userIDs: []string{"user5", "user5", "user6", "user7"}, + expectedReasons: []string{"only integration service user can delete this user", "unable to delete a user: user5 with owner role"}, + expectedDeleted: []string{"user6", "user7"}, + expectedNotDeleted: []string{"user4", "user5"}, + }, + { + name: "Delete non-existent user", + userIDs: []string{"non-existent-user"}, + expectedReasons: []string{"target user: non-existent-user not found"}, + expectedNotDeleted: []string{}, + }, + { + name: "Delete multiple regular users successfully", + userIDs: []string{"user8", "user9"}, + expectedDeleted: []string{"user8", "user9"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs) + if len(tc.expectedReasons) > 0 { + assert.Error(t, err) + var foundExpectedErrors int + + wrappedErr, ok := err.(interface{ Unwrap() []error }) + assert.Equal(t, ok, true) + + for _, e := range wrappedErr.Unwrap() { + assert.Contains(t, tc.expectedReasons, e.Error(), "unexpected error message") + foundExpectedErrors++ + } + + assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found") + } else { + assert.NoError(t, err) + } + + acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + assert.NoError(t, err) + + for _, id := range tc.expectedDeleted { + _, exists := acc.Users[id] + assert.False(t, exists, "user should have been deleted: %s", id) + } + + for _, id := range tc.expectedNotDeleted { + user, exists := acc.Users[id] + assert.True(t, exists, "user should not have been 
deleted: %s", id) + assert.NotNil(t, user, "user should exist: %s", id) + } + }) + } +} + func TestDefaultAccountManager_GetUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) From 18cef8280ad7dc0eebf38a57aad7e6a3988a8825 Mon Sep 17 00:00:00 2001 From: David Merris Date: Fri, 9 Aug 2024 11:32:09 -0400 Subject: [PATCH 48/53] [client] Allow setup keys to be provided in a file (#2337) Adds a flag and a bit of logic to allow a setup key to be passed in using a file. The flag should be exclusive with the standard --setup-key flag. --- client/cmd/login.go | 7 +++++++ client/cmd/root.go | 8 ++++++++ client/cmd/up.go | 7 +++++++ client/cmd/up_daemon_test.go | 31 +++++++++++++++++++++++++++++++ 4 files changed, 53 insertions(+) diff --git a/client/cmd/login.go b/client/cmd/login.go index 14c973d91..512fbb081 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -39,6 +39,13 @@ var loginCmd = &cobra.Command{ ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName) } + if setupKeyPath != "" && setupKey == "" { + setupKey, err = getSetupKeyFromFile(setupKeyPath) + if err != nil { + return err + } + } + // workaround to run without service if logFile == "console" { err = handleRebrand(cmd) diff --git a/client/cmd/root.go b/client/cmd/root.go index db02ff5ea..b6d6694ee 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -56,6 +56,7 @@ var ( managementURL string adminURL string setupKey string + setupKeyPath string hostName string preSharedKey string natExternalIPs []string @@ -128,6 +129,8 @@ func init() { rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. 
If syslog is specified the log will be sent to syslog daemon.") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") + rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.") + rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output") @@ -253,6 +256,11 @@ var CLIBackOffSettings = &backoff.ExponentialBackOff{ Clock: backoff.SystemClock, } +func getSetupKeyFromFile(setupKeyPath string) (string, error) { + data, err := os.ReadFile(setupKeyPath) + return string(data), err +} + func handleRebrand(cmd *cobra.Command) error { var err error if logFile == defaultLogFile { diff --git a/client/cmd/up.go b/client/cmd/up.go index f69e9eb27..0eaf7bc0d 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -73,6 +73,13 @@ func upFunc(cmd *cobra.Command, args []string) error { ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName) } + if setupKeyPath != "" && setupKey == "" { + setupKey, err = getSetupKeyFromFile(setupKeyPath) + if err != nil { + return err + } + } + if foregroundMode { return runInForegroundMode(ctx, cmd) } diff --git a/client/cmd/up_daemon_test.go b/client/cmd/up_daemon_test.go index 0295d2b21..daf8d0628 100644 --- a/client/cmd/up_daemon_test.go +++ b/client/cmd/up_daemon_test.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "os" "testing" 
"time" @@ -40,6 +41,36 @@ func TestUpDaemon(t *testing.T) { return } + // Test the setup-key-file flag. + tempFile, err := os.CreateTemp("", "setup-key") + if err != nil { + t.Errorf("could not create temp file, got error %v", err) + return + } + defer os.Remove(tempFile.Name()) + if _, err := tempFile.Write([]byte("A2C8E62B-38F5-4553-B31E-DD66C696CEBB")); err != nil { + t.Errorf("could not write to temp file, got error %v", err) + return + } + if err := tempFile.Close(); err != nil { + t.Errorf("unable to close file, got error %v", err) + } + rootCmd.SetArgs([]string{ + "login", + "--daemon-addr", "tcp://" + cliAddr, + "--setup-key-file", tempFile.Name(), + "--log-file", "", + }) + if err := rootCmd.Execute(); err != nil { + t.Errorf("expected no error while running up command, got %v", err) + return + } + time.Sleep(time.Second * 3) + if status, err := state.Status(); err != nil && status != internal.StatusIdle { + t.Errorf("wrong status after login: %s, %v", internal.StatusIdle, err) + return + } + rootCmd.SetArgs([]string{ "up", "--daemon-addr", "tcp://" + cliAddr, From 12f9d12a11ba4dab74dc6af5a2bab8c647986759 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 9 Aug 2024 19:17:28 +0200 Subject: [PATCH 49/53] [misc] Update bug-issue-report.md to include netbird debug cmd (#2413) --- .github/ISSUE_TEMPLATE/bug-issue-report.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/ISSUE_TEMPLATE/bug-issue-report.md b/.github/ISSUE_TEMPLATE/bug-issue-report.md index 0aa6cd77e..8971972a2 100644 --- a/.github/ISSUE_TEMPLATE/bug-issue-report.md +++ b/.github/ISSUE_TEMPLATE/bug-issue-report.md @@ -35,6 +35,11 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan If applicable, add the `netbird status -d' command output. +**Do you face any client issues on desktop?** + +Please provide the file created by `netbird debug for 1m -AS`. 
+We advise reviewing the anonymized files for any remaining PII. + **Screenshots** If applicable, add screenshots to help explain your problem. From af1b42e538e74fe01b3f4398f95b4521d8ac3298 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 9 Aug 2024 20:38:58 +0200 Subject: [PATCH 50/53] [client] Parse data from setup key (#2411) refactor functions and variable assignment --- client/cmd/login.go | 12 +++++------- client/cmd/root.go | 12 +++++++++++- client/cmd/up.go | 21 ++++++++++++--------- 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index 512fbb081..c7dd0fda1 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -39,11 +39,9 @@ var loginCmd = &cobra.Command{ ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName) } - if setupKeyPath != "" && setupKey == "" { - setupKey, err = getSetupKeyFromFile(setupKeyPath) - if err != nil { - return err - } + providedSetupKey, err := getSetupKey() + if err != nil { + return err } // workaround to run without service @@ -69,7 +67,7 @@ var loginCmd = &cobra.Command{ config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) - err = foregroundLogin(ctx, cmd, config, setupKey) + err = foregroundLogin(ctx, cmd, config, providedSetupKey) if err != nil { return fmt.Errorf("foreground login failed: %v", err) } @@ -88,7 +86,7 @@ var loginCmd = &cobra.Command{ client := proto.NewDaemonServiceClient(conn) loginRequest := proto.LoginRequest{ - SetupKey: setupKey, + SetupKey: providedSetupKey, ManagementUrl: managementURL, IsLinuxDesktopClient: isLinuxRunningDesktop(), Hostname: hostName, diff --git a/client/cmd/root.go b/client/cmd/root.go index b6d6694ee..8dae6e273 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -256,9 +256,19 @@ var CLIBackOffSettings = &backoff.ExponentialBackOff{ Clock: backoff.SystemClock, } +func getSetupKey() (string, error) { + if setupKeyPath != "" && setupKey == "" { + return 
getSetupKeyFromFile(setupKeyPath) + } + return setupKey, nil +} + func getSetupKeyFromFile(setupKeyPath string) (string, error) { data, err := os.ReadFile(setupKeyPath) - return string(data), err + if err != nil { + return "", fmt.Errorf("failed to read setup key file: %v", err) + } + return strings.TrimSpace(string(data)), nil } func handleRebrand(cmd *cobra.Command) error { diff --git a/client/cmd/up.go b/client/cmd/up.go index 0eaf7bc0d..2ed6e41d2 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -73,13 +73,6 @@ func upFunc(cmd *cobra.Command, args []string) error { ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName) } - if setupKeyPath != "" && setupKey == "" { - setupKey, err = getSetupKeyFromFile(setupKeyPath) - if err != nil { - return err - } - } - if foregroundMode { return runInForegroundMode(ctx, cmd) } @@ -154,6 +147,11 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { ic.DNSRouteInterval = &dnsRouteInterval } + providedSetupKey, err := getSetupKey() + if err != nil { + return err + } + config, err := internal.UpdateOrCreateConfig(ic) if err != nil { return fmt.Errorf("get config file: %v", err) @@ -161,7 +159,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) - err = foregroundLogin(ctx, cmd, config, setupKey) + err = foregroundLogin(ctx, cmd, config, providedSetupKey) if err != nil { return fmt.Errorf("foreground login failed: %v", err) } @@ -206,8 +204,13 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { return nil } + providedSetupKey, err := getSetupKey() + if err != nil { + return err + } + loginRequest := proto.LoginRequest{ - SetupKey: setupKey, + SetupKey: providedSetupKey, ManagementUrl: managementURL, AdminURL: adminURL, NatExternalIPs: natExternalIPs, From 15eb752a7d574b33506cce9d90f4ff0b3700b827 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> 
Date: Sun, 11 Aug 2024 15:01:04 +0200 Subject: [PATCH 51/53] [misc] Update bug-issue-report.md to include anon flag (#2412) --- .github/ISSUE_TEMPLATE/bug-issue-report.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug-issue-report.md b/.github/ISSUE_TEMPLATE/bug-issue-report.md index 8971972a2..789c61974 100644 --- a/.github/ISSUE_TEMPLATE/bug-issue-report.md +++ b/.github/ISSUE_TEMPLATE/bug-issue-report.md @@ -31,9 +31,9 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan `netbird version` -**NetBird status -d output:** +**NetBird status -dA output:** -If applicable, add the `netbird status -d' command output. +If applicable, add the `netbird status -dA' command output. **Do you face any client issues on desktop?** From 539480a713c621917fd1de348ada8cc0ffe41e8f Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 12 Aug 2024 13:48:05 +0300 Subject: [PATCH 52/53] [management] Prevent removal of All group from peers during user groups propagation (#2410) * Prevent removal of "All" group from peers * Prevent adding "All" group to users and setup keys * Refactor setup key group validation --- management/server/account.go | 2 +- management/server/setupkey.go | 23 +++++++++++++++--- management/server/setupkey_test.go | 39 ++++++++++++++++++++++++++---- management/server/user.go | 6 ++++- 4 files changed, 59 insertions(+), 11 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index ca53ebad0..972272746 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -922,7 +922,7 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { for _, gid := range groups { group, ok := a.Groups[gid] - if !ok { + if !ok || group.Name == "All" { continue } update := make([]string, 0, len(group.Peers)) diff --git a/management/server/setupkey.go 
b/management/server/setupkey.go index 8ef91755c..859f1b0b9 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -223,10 +223,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, err } - for _, group := range autoGroups { - if _, ok := account.Groups[group]; !ok { - return nil, status.Errorf(status.NotFound, "group %s doesn't exist", group) - } + if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil { + return nil, err } setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral) @@ -279,6 +277,10 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.Errorf(status.NotFound, "setup key not found") } + if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil { + return nil, err + } + // only auto groups, revoked status, and name can be updated for now newKey := oldKey.Copy() newKey.Name = keyToSave.Name @@ -399,3 +401,16 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return foundKey, nil } + +func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { + for _, group := range autoGroups { + g, ok := account.Groups[group] + if !ok { + return status.Errorf(status.NotFound, "group %s doesn't exist", group) + } + if g.Name == "All" { + return status.Errorf(status.InvalidArgument, "can't add All group to the setup key") + } + } + return nil +} diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 034f4e2d6..aa5075b02 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -26,10 +26,17 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: "group_1", - Name: "group_name_1", - Peers: []string{}, + err = manager.SaveGroups(context.Background(), 
account.Id, userID, []*nbgroup.Group{ + { + ID: "group_1", + Name: "group_name_1", + Peers: []string{}, + }, + { + ID: "group_2", + Name: "group_name_2", + Peers: []string{}, + }, }) if err != nil { t.Fatal(err) @@ -70,6 +77,19 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { assert.NotEmpty(t, ev.Meta["key"]) assert.Equal(t, userID, ev.InitiatorID) assert.Equal(t, key.Id, ev.TargetID) + + groupAll, err := account.GetGroupAll() + assert.NoError(t, err) + + // saving setup key with All group assigned to auto groups should return error + autoGroups = append(autoGroups, groupAll.ID) + _, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + Id: key.Id, + Name: newKeyName, + Revoked: revoked, + AutoGroups: autoGroups, + }, userID) + assert.Error(t, err, "should not save setup key with All group assigned in auto groups") } func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { @@ -102,6 +122,9 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } + groupAll, err := account.GetGroupAll() + assert.NoError(t, err) + type testCase struct { name string @@ -134,8 +157,14 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { expectedGroups: []string{"FAKE"}, expectedFailure: true, } + testCase3 := testCase{ + name: "Create Setup Key should fail because of All group", + expectedKeyName: "my-test-key", + expectedGroups: []string{groupAll.ID}, + expectedFailure: true, + } - for _, tCase := range []testCase{testCase1, testCase2} { + for _, tCase := range []testCase{testCase1, testCase2, testCase3} { t.Run(tCase.name, func(t *testing.T) { key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) diff --git a/management/server/user.go b/management/server/user.go index 8fd0a1627..727bc5c6b 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -944,10 +944,14 @@ func 
validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) } for _, newGroupID := range update.AutoGroups { - if _, ok := account.Groups[newGroupID]; !ok { + group, ok := account.Groups[newGroupID] + if !ok { return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", newGroupID, update.Id) } + if group.Name == "All" { + return status.Errorf(status.InvalidArgument, "can't add All group to the user") + } } return nil From 34114d4a55f400d1f74b8eed6d7574714f5c1007 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 12 Aug 2024 18:01:07 +0300 Subject: [PATCH 53/53] Fix peers update by including NetworkMap and posture Checks --- management/server/peer.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/management/server/peer.go b/management/server/peer.go index 7afe6ee0d..8c241c186 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -269,6 +269,8 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou FirewallRulesIsEmpty: true, }, }, + NetworkMap: &NetworkMap{}, + Checks: []*posture.Checks{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) @@ -955,7 +957,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account postureChecks := am.getPeerPostureChecks(account, p) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) - am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) + am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap, Checks: postureChecks}) }(peer) }