diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 14778e105..2d6f4dbf9 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "sync" - "sync/atomic" "time" "google.golang.org/grpc/codes" @@ -46,7 +45,8 @@ type GrpcClient struct { connStateCallback ConnStateNotifier connStateCallbackLock sync.RWMutex // lastNetworkMapSerial stores last seen network map serial to optimize sync - lastNetworkMapSerial uint64 + lastNetworkMapSerial uint64 + lastNetworkMapSerialMu sync.Mutex } // NewClient creates a new client to Management service @@ -220,7 +220,7 @@ func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, err } // update last seen serial - atomic.StoreUint64(&c.lastNetworkMapSerial, decryptedResp.GetNetworkMap().GetSerial()) + c.setLastNetworkMapSerial(decryptedResp.GetNetworkMap().GetSerial()) return decryptedResp.GetNetworkMap(), nil } @@ -235,7 +235,7 @@ func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.K recomputed.Files = sysInfo.Files } } - req := &proto.SyncRequest{Meta: infoToMetaData(recomputed), NetworkMapSerial: atomic.LoadUint64(&c.lastNetworkMapSerial)} + req := &proto.SyncRequest{Meta: infoToMetaData(recomputed), NetworkMapSerial: c.getLastNetworkMapSerial()} myPrivateKey := c.key myPublicKey := myPrivateKey.PublicKey() @@ -275,7 +275,7 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se // track latest network map serial if present if decryptedResp.GetNetworkMap() != nil { - atomic.StoreUint64(&c.lastNetworkMapSerial, decryptedResp.GetNetworkMap().GetSerial()) + c.setLastNetworkMapSerial(decryptedResp.GetNetworkMap().GetSerial()) } if err := msgHandler(decryptedResp); err != nil { @@ -602,3 +602,18 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta { }, } } + +// setLastNetworkMapSerial updates the cached last seen network map serial in a 32-bit safe manner +func (c *GrpcClient) setLastNetworkMapSerial(serial uint64) { + c.lastNetworkMapSerialMu.Lock() + c.lastNetworkMapSerial = serial + c.lastNetworkMapSerialMu.Unlock() +} + +// getLastNetworkMapSerial returns the cached last seen network map serial in a 32-bit safe manner +func (c *GrpcClient) getLastNetworkMapSerial() uint64 { + c.lastNetworkMapSerialMu.Lock() + v := c.lastNetworkMapSerial + c.lastNetworkMapSerialMu.Unlock() + return v +}