From e3765417454338081d62f3afc74e463b07f3ed02 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 28 Mar 2023 17:10:15 +0200 Subject: [PATCH] Add grpc keep alive for management service --- management/client/grpc.go | 14 +++ management/cmd/management.go | 7 +- management/server/keep_alive.go | 171 ++++++++++++++++++++++++++++++++ 3 files changed, 191 insertions(+), 1 deletion(-) create mode 100644 management/server/keep_alive.go diff --git a/management/client/grpc.go b/management/client/grpc.go index d2ca8c088..254c2693f 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -9,6 +9,7 @@ import ( "time" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" gstatus "google.golang.org/grpc/status" log "github.com/sirupsen/logrus" @@ -24,6 +25,7 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/version" ) // ConnStateNotifier is a wrapper interface of the status recorders @@ -67,6 +69,9 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE realClient := proto.NewManagementServiceClient(conn) + md := metadata.Pairs("version", version.NetbirdVersion()) + ctx = metadata.NewOutgoingContext(ctx, md) + return &GrpcClient{ key: ourPrivateKey, realClient: realClient, @@ -131,6 +136,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error ctx, cancelStream := context.WithCancel(c.ctx) defer cancelStream() + stream, err := c.connectToStream(ctx, *serverPubKey) if err != nil { log.Debugf("failed to open Management Service stream: %s", err) @@ -246,6 +252,10 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se return err } + if c.isKeepAliveMsg(update.Body) { + continue + } + log.Debugf("got an update message from Management Service") decryptedResp := &proto.SyncResponse{} err = encryption.DecryptMessage(serverPubKey, c.key, 
update.Body, decryptedResp) @@ -386,6 +396,10 @@ func (c *GrpcClient) notifyConnected() { c.connStateCallback.MarkManagementConnected() } +func (c *GrpcClient) isKeepAliveMsg(body []byte) bool { + return len(body) == 0 +} + func infoToMetaData(info *system.Info) *proto.PeerSystemMeta { if info == nil { return nil diff --git a/management/cmd/management.go b/management/cmd/management.go index eee8c5c57..e15525e4e 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -202,11 +202,16 @@ var ( return fmt.Errorf("failed creating HTTP API handler: %v", err) } - gRPCAPIHandler := grpc.NewServer(gRPCOpts...) srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics) if err != nil { return fmt.Errorf("failed creating gRPC API handler: %v", err) } + + ka := server.NewKeepAlive() + sInterc := grpc.StreamInterceptor(ka.StreamInterceptor()) + uInterc := grpc.UnaryInterceptor(ka.UnaryInterceptor()) + gRPCOpts = append(gRPCOpts, sInterc, uInterc) + gRPCAPIHandler := grpc.NewServer(gRPCOpts...) 
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv) installationID, err := getInstallationID(store)
diff --git a/management/server/keep_alive.go b/management/server/keep_alive.go
new file mode 100644
index 000000000..4e9ff5406
--- /dev/null
+++ b/management/server/keep_alive.go
@@ -0,0 +1,248 @@
+package server
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	log "github.com/sirupsen/logrus"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/metadata"
+
+	"github.com/netbirdio/netbird/management/proto"
+)
+
+const (
+	// reversProxyHeaderKey is the metadata key whose first value is used as the
+	// peer address, which in turn keys the registry of monitored streams.
+	reversProxyHeaderKey = "x-netbird-peer"
+
+	// grpcVersionHeaderKey is the metadata key that must be present for a call
+	// to be treated as keepalive-capable.
+	grpcVersionHeaderKey = "version"
+
+	// keepAliveInterval is how long a monitored stream may stay idle before a
+	// keepalive message is sent on it.
+	keepAliveInterval = 30 * time.Second
+)
+
+// ioMonitor wraps a grpc.ServerStream to record when it was last active and
+// to serialize writes to it.
+type ioMonitor struct {
+	grpc.ServerStream
+
+	// mu guards lastSeen.
+	mu       sync.Mutex
+	lastSeen time.Time
+
+	// sendMu serializes SendMsg. Both the RPC handler and the keepalive
+	// goroutine write to the stream, and gRPC does not allow concurrent
+	// SendMsg calls on the same stream.
+	sendMu sync.Mutex
+}
+
+// SendMsg marks the stream as active and forwards m to the wrapped stream,
+// one caller at a time.
+func (l *ioMonitor) SendMsg(m interface{}) error {
+	l.updateLastSeen()
+
+	l.sendMu.Lock()
+	defer l.sendMu.Unlock()
+	return l.ServerStream.SendMsg(m)
+}
+
+// updateLastSeen records the current time as the last activity on the stream.
+func (l *ioMonitor) updateLastSeen() {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	l.lastSeen = time.Now()
+}
+
+// getLastSeen returns the time of the last recorded activity on the stream.
+func (l *ioMonitor) getLastSeen() time.Time {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	return l.lastSeen
+}
+
+// KeepAlive sends an empty message on monitored server streams that have been
+// idle for keepAliveInterval. Streams are registered by StreamInterceptor and
+// keyed by the peer address taken from the request metadata.
+type KeepAlive struct {
+	sync.RWMutex
+	ticker   *time.Ticker
+	done     chan struct{}
+	stopOnce sync.Once
+	streams  map[string]*ioMonitor
+}
+
+// NewKeepAlive creates a KeepAlive and starts its background goroutine, which
+// checks the monitored streams once per second. Call Stop to release it.
+func NewKeepAlive() *KeepAlive {
+	ka := &KeepAlive{
+		ticker:  time.NewTicker(1 * time.Second),
+		done:    make(chan struct{}),
+		streams: make(map[string]*ioMonitor),
+	}
+	go ka.start()
+	return ka
+}
+
+// Stop stops the ticker and terminates the background goroutine. It is safe
+// to call more than once.
+func (k *KeepAlive) Stop() {
+	k.stopOnce.Do(func() {
+		k.ticker.Stop()
+		close(k.done)
+	})
+}
+
+// StreamInterceptor returns a stream interceptor that monitors every
+// keepalive-capable stream for as long as its handler runs. Other streams are
+// passed to the handler unchanged.
+func (k *KeepAlive) StreamInterceptor() grpc.StreamServerInterceptor {
+	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+		address, supported := k.keepAliveIsSupported(stream.Context())
+		if !supported {
+			return handler(srv, stream)
+		}
+
+		m := &ioMonitor{
+			ServerStream: stream,
+			lastSeen:     time.Now(),
+		}
+
+		k.addIoMonitor(address, m)
+		// Stop tracking the stream once its handler has returned; otherwise
+		// the entry would stay registered after the RPC is over.
+		defer k.removeIoMonitor(address, m)
+
+		return handler(srv, m)
+	}
+}
+
+// UnaryInterceptor returns a unary interceptor that counts a unary call from
+// a keepalive-capable peer as activity on that peer's monitored stream.
+func (k *KeepAlive) UnaryInterceptor() grpc.UnaryServerInterceptor {
+	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+		address, supported := k.keepAliveIsSupported(ctx)
+		if supported {
+			k.updateLastSeen(address)
+		}
+		return handler(ctx, req)
+	}
+}
+
+// start runs the periodic keepalive check until Stop is called.
+func (k *KeepAlive) start() {
+	for {
+		select {
+		case <-k.done:
+			return
+		case t := <-k.ticker.C:
+			k.checkKeepAlive(t)
+		}
+	}
+}
+
+// checkKeepAlive sends a keepalive on every stream that has been idle for at
+// least keepAliveInterval and drops the streams on which sending fails.
+func (k *KeepAlive) checkKeepAlive(now time.Time) {
+	type target struct {
+		addr string
+		m    *ioMonitor
+	}
+
+	// Pick the idle streams under the lock but send after releasing it:
+	// SendMsg can block, and holding the lock meanwhile would stall every
+	// interceptor call.
+	var idle []target
+	k.RLock()
+	for addr, m := range k.streams {
+		if !k.isKeepAliveOutDated(now, m) {
+			idle = append(idle, target{addr: addr, m: m})
+		}
+	}
+	k.RUnlock()
+
+	for _, t := range idle {
+		log.Debugf("send keepalive for: %s", t.addr)
+		if err := k.sendKeepAlive(t.m); err != nil {
+			log.Debugf("stop keepalive for: %s", t.addr)
+			k.removeIoMonitor(t.addr, t.m)
+		}
+	}
+}
+
+// keepAliveIsSupported reports whether the call behind ctx is
+// keepalive-capable, i.e. whether its metadata carries both a peer address
+// and a version entry. On success it also returns the peer address.
+func (k *KeepAlive) keepAliveIsSupported(ctx context.Context) (string, bool) {
+	md, ok := metadata.FromIncomingContext(ctx)
+	if !ok {
+		log.Warnf("metadata not found")
+		return "", false
+	}
+
+	peerAddress := k.addressFromHeader(md)
+	if peerAddress == "" {
+		log.Debugf("peer is not using reverse proxy")
+		return "", false
+	}
+
+	if len(md.Get(grpcVersionHeaderKey)) == 0 {
+		log.Debugf("version info not found")
+		return "", false
+	}
+	return peerAddress, true
+}
+
+// addIoMonitor registers m under address, replacing any previous entry.
+func (k *KeepAlive) addIoMonitor(address string, m *ioMonitor) {
+	k.Lock()
+	defer k.Unlock()
+	k.streams[address] = m
+}
+
+// removeIoMonitor unregisters m. The entry is left alone when address has
+// meanwhile been taken over by another stream.
+func (k *KeepAlive) removeIoMonitor(address string, m *ioMonitor) {
+	k.Lock()
+	defer k.Unlock()
+	if k.streams[address] == m {
+		delete(k.streams, address)
+	}
+}
+
+// sendKeepAlive sends an empty message on the monitored stream.
+func (k *KeepAlive) sendKeepAlive(m *ioMonitor) error {
+	return m.SendMsg(&proto.Empty{})
+}
+
+// updateLastSeen marks the stream registered under address, if any, as active.
+func (k *KeepAlive) updateLastSeen(address string) {
+	k.RLock()
+	m, ok := k.streams[address]
+	k.RUnlock()
+	if !ok {
+		return
+	}
+	m.updateLastSeen()
+}
+
+// addressFromHeader returns the first reversProxyHeaderKey value of md, or ""
+// when the key is absent.
+func (k *KeepAlive) addressFromHeader(md metadata.MD) string {
+	peer := md.Get(reversProxyHeaderKey)
+	if len(peer) == 0 {
+		return ""
+	}
+	return peer[0]
+}
+
+// isKeepAliveOutDated reports whether m was active less than
+// keepAliveInterval before now. Despite the name, true means that no
+// keepalive is due yet and the stream should be skipped.
+func (k *KeepAlive) isKeepAliveOutDated(now time.Time, m *ioMonitor) bool {
+	return now.Sub(m.getLastSeen()) < keepAliveInterval
+}