From 20f5f0063511bf8d93a8e38206d6e7d5fc87c5ac Mon Sep 17 00:00:00 2001 From: Hakan Sariman Date: Fri, 17 Oct 2025 10:03:07 +0300 Subject: [PATCH] [client] Add unit tests for engine synchronization and Info flag copying - Introduced tests for the Engine's handleSync method to verify behavior when SkipNetworkMapUpdate is true and when NetworkMap is nil. - Added a test for the Info struct to ensure correct copying of flag values from one instance to another, while preserving unrelated fields. --- client/internal/engine_sync_test.go | 79 +++++++++++++++++++++++++++ client/system/info_test.go | 67 +++++++++++++++++++++-- shared/management/client/grpc_test.go | 26 +++++++++ 3 files changed, 168 insertions(+), 4 deletions(-) create mode 100644 client/internal/engine_sync_test.go create mode 100644 shared/management/client/grpc_test.go diff --git a/client/internal/engine_sync_test.go b/client/internal/engine_sync_test.go new file mode 100644 index 000000000..ad2a97157 --- /dev/null +++ b/client/internal/engine_sync_test.go @@ -0,0 +1,79 @@ +package internal + +import ( + "context" + "testing" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/shared/management/client" + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" +) + +// Ensures handleSync exits early when SkipNetworkMapUpdate is true +func TestEngine_HandleSync_SkipNetworkMapUpdate(t *testing.T) { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{ + WgIfaceName: "utun199", + WgAddr: "100.70.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + MTU: iface.DefaultMTU, + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + engine.ctx = ctx + + // Precondition + if engine.networkSerial != 0 { + t.Fatalf("unexpected initial serial: %d", engine.networkSerial) + } + + resp := &mgmtProto.SyncResponse{ + NetworkMap: &mgmtProto.NetworkMap{Serial: 42}, + SkipNetworkMapUpdate: true, + } + + if err := engine.handleSync(resp); err != nil { + t.Fatalf("handleSync returned error: %v", err) + } + + if engine.networkSerial != 0 { + t.Fatalf("networkSerial changed despite SkipNetworkMapUpdate; got %d, want 0", engine.networkSerial) + } +} + +// Ensures handleSync exits early when NetworkMap is nil +func TestEngine_HandleSync_NilNetworkMap(t *testing.T) { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{ + WgIfaceName: "utun198", + WgAddr: "100.70.0.2/24", + WgPrivateKey: key, + WgPort: 33101, + MTU: iface.DefaultMTU, + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + engine.ctx = ctx + + resp := &mgmtProto.SyncResponse{NetworkMap: nil} + + if err := engine.handleSync(resp); err != nil { + t.Fatalf("handleSync returned error: %v", err) + } +} + + diff --git a/client/system/info_test.go b/client/system/info_test.go index 27821f3c5..fdd2895ac 100644 --- a/client/system/info_test.go +++ b/client/system/info_test.go @@ -1,13 +1,72 @@ package system import ( - "context" - "testing" + "context" + "testing" - "github.com/stretchr/testify/assert" - "google.golang.org/grpc/metadata" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" ) +func TestInfo_CopyFlagsFrom(t *testing.T) { + origin := &Info{} + serverSSHAllowed := true + origin.SetFlags( + true, // RosenpassEnabled + false, // RosenpassPermissive + &serverSSHAllowed, + true, // DisableClientRoutes + false, // DisableServerRoutes + true, // DisableDNS + false, // DisableFirewall + true, // BlockLANAccess + false, // BlockInbound + true, // LazyConnectionEnabled + ) + + got := &Info{} + got.CopyFlagsFrom(origin) + + if got.RosenpassEnabled != true { + t.Fatalf("RosenpassEnabled not copied: got %v", got.RosenpassEnabled) + } + if got.RosenpassPermissive != false { + t.Fatalf("RosenpassPermissive not copied: got %v", got.RosenpassPermissive) + } + if got.ServerSSHAllowed != true { + t.Fatalf("ServerSSHAllowed not copied: got %v", got.ServerSSHAllowed) + } + if got.DisableClientRoutes != true { + t.Fatalf("DisableClientRoutes not copied: got %v", got.DisableClientRoutes) + } + if got.DisableServerRoutes != false { + t.Fatalf("DisableServerRoutes not copied: got %v", got.DisableServerRoutes) + } + if got.DisableDNS != true { + t.Fatalf("DisableDNS not copied: got %v", got.DisableDNS) + } + if got.DisableFirewall != false { + t.Fatalf("DisableFirewall not copied: got %v", got.DisableFirewall) + } + if got.BlockLANAccess != true { + t.Fatalf("BlockLANAccess not copied: got %v", got.BlockLANAccess) + } + if got.BlockInbound != false { + t.Fatalf("BlockInbound not copied: got %v", got.BlockInbound) + } + if got.LazyConnectionEnabled != true { + t.Fatalf("LazyConnectionEnabled not copied: got %v", got.LazyConnectionEnabled) + } + + // ensure CopyFlagsFrom does not touch unrelated fields + origin.Hostname = "host-a" + got.Hostname = "host-b" + got.CopyFlagsFrom(origin) + if got.Hostname != "host-b" { + t.Fatalf("CopyFlagsFrom should not overwrite non-flag fields, got Hostname=%q", got.Hostname) + } +} + func Test_LocalWTVersion(t *testing.T) { got := GetInfo(context.TODO()) want := "development" diff --git a/shared/management/client/grpc_test.go b/shared/management/client/grpc_test.go new file mode 100644 index 000000000..fccd2179e --- /dev/null +++ b/shared/management/client/grpc_test.go @@ -0,0 +1,26 @@ +package client + +import ( + "testing" +) + +func TestGrpcClient_LastNetworkMapSerial_SetGet(t *testing.T) { + c := &GrpcClient{} + + if got := c.getLastNetworkMapSerial(); got != 0 { + t.Fatalf("initial serial should be 0, got %d", got) + } + + c.setLastNetworkMapSerial(123) + if got := c.getLastNetworkMapSerial(); got != 123 { + t.Fatalf("serial after set should be 123, got %d", got) + } + + // overwrite should work + c.setLastNetworkMapSerial(5) + if got := c.getLastNetworkMapSerial(); got != 5 { + t.Fatalf("serial after overwrite should be 5, got %d", got) + } +} + +