diff --git a/management/server/keep_alive.go b/keepalive/keep_alive.go similarity index 76% rename from management/server/keep_alive.go rename to keepalive/keep_alive.go index 4e9ff5406..a1182bd45 100644 --- a/management/server/keep_alive.go +++ b/keepalive/keep_alive.go @@ -1,4 +1,4 @@ -package server +package keepalive import ( "context" @@ -8,8 +8,6 @@ import ( log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/metadata" - - "github.com/netbirdio/netbird/management/proto" ) const ( @@ -18,44 +16,20 @@ const ( keepAliveInterval = 30 * time.Second ) -type ioMonitor struct { - grpc.ServerStream - mu sync.Mutex - lastSeen time.Time -} - -func (l *ioMonitor) SendMsg(m interface{}) error { - l.updateLastSeen() - return l.ServerStream.SendMsg(m) -} - -func (l *ioMonitor) updateLastSeen() { - l.mu.Lock() - defer l.mu.Unlock() - l.lastSeen = time.Now() -} - -func (l *ioMonitor) getLastSeen() time.Time { - l.mu.Lock() - t := l.lastSeen - l.mu.Unlock() - return t -} - type KeepAlive struct { sync.RWMutex - ticker *time.Ticker - done chan struct{} - streams map[string]*ioMonitor + ticker *time.Ticker + done chan struct{} + streams map[string]*ioMonitor + keepAliveMsg interface{} } -// todo: write free resources function - -func NewKeepAlive() *KeepAlive { +func NewKeepAlive(keepAliveMsg interface{}) *KeepAlive { ka := &KeepAlive{ - ticker: time.NewTicker(1 * time.Second), - done: make(chan struct{}), - streams: make(map[string]*ioMonitor), + ticker: time.NewTicker(1 * time.Second), + done: make(chan struct{}), + streams: make(map[string]*ioMonitor), + keepAliveMsg: keepAliveMsg, } go ka.start() return ka @@ -69,8 +43,8 @@ func (k *KeepAlive) StreamInterceptor() grpc.StreamServerInterceptor { } m := &ioMonitor{ - stream, sync.Mutex{}, + stream, time.Now(), } @@ -90,6 +64,15 @@ func (k *KeepAlive) UnaryInterceptor() grpc.UnaryServerInterceptor { } } +func (k *KeepAlive) Stop() { + select { + case k.done <- struct{}{}: + k.ticker.Stop() + return + default: + } +} + func (k *KeepAlive) start() { for { select { @@ -108,7 +91,6 @@ func (k *KeepAlive) checkKeepAlive(now time.Time) { if k.isKeepAliveOutDated(now, m) { continue } - log.Debugf("send keepalive for: %s", addr) err := k.sendKeepAlive(m) if err != nil { log.Debugf("stop keepalive for: %s", addr) @@ -144,8 +126,7 @@ func (k *KeepAlive) addIoMonitor(address string, m *ioMonitor) { } func (k *KeepAlive) sendKeepAlive(m *ioMonitor) error { - msg := &proto.Empty{} - return m.SendMsg(msg) + return m.sendMsg(k.keepAliveMsg) } func (k *KeepAlive) updateLastSeen(address string) { diff --git a/keepalive/monitor.go b/keepalive/monitor.go new file mode 100644 index 000000000..3fe50c6f6 --- /dev/null +++ b/keepalive/monitor.go @@ -0,0 +1,32 @@ +package keepalive + +import ( + "sync" + "time" + + "google.golang.org/grpc" +) + +type ioMonitor struct { + mu sync.Mutex + grpc.ServerStream + lastSeen time.Time +} + +func (l *ioMonitor) sendMsg(m interface{}) error { + l.updateLastSeen() + return l.ServerStream.SendMsg(m) +} + +func (l *ioMonitor) updateLastSeen() { + l.mu.Lock() + defer l.mu.Unlock() + l.lastSeen = time.Now() +} + +func (l *ioMonitor) getLastSeen() time.Time { + l.mu.Lock() + t := l.lastSeen + l.mu.Unlock() + return t +} diff --git a/management/cmd/management.go b/management/cmd/management.go index e15525e4e..b5e1360ce 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -40,6 +40,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/netbird/encryption" + grpcKeepAlive "github.com/netbirdio/netbird/keepalive" mgmtProto "github.com/netbirdio/netbird/management/proto" ) @@ -207,7 +208,8 @@ var ( return fmt.Errorf("failed creating gRPC API handler: %v", err) } - ka := server.NewKeepAlive() + ka := grpcKeepAlive.NewKeepAlive(&mgmtProto.Empty{}) + defer ka.Stop() sInterc := grpc.StreamInterceptor(ka.StreamInterceptor()) uInterc := grpc.UnaryInterceptor(ka.UnaryInterceptor()) gRPCOpts = append(gRPCOpts, sInterc, uInterc) diff --git a/management/proto/management.proto b/management/proto/management.proto index e1c0fee37..46c86aeed 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -329,3 +329,5 @@ message FirewallRule { ICMP = 4; } } + +message KeepAlive {} diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 9b52fb52d..e2b6e65c1 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -4,7 +4,6 @@ import ( "errors" "flag" "fmt" - "golang.org/x/crypto/acme/autocert" "io" "io/fs" "net" @@ -14,15 +13,18 @@ import ( "strings" "time" - "github.com/netbirdio/netbird/encryption" - "github.com/netbirdio/netbird/signal/proto" - "github.com/netbirdio/netbird/signal/server" - "github.com/netbirdio/netbird/util" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "golang.org/x/crypto/acme/autocert" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" + + "github.com/netbirdio/netbird/encryption" + appKeepAlive "github.com/netbirdio/netbird/keepalive" + "github.com/netbirdio/netbird/signal/proto" + "github.com/netbirdio/netbird/signal/server" + "github.com/netbirdio/netbird/util" ) var ( @@ -93,6 +95,13 @@ var ( } opts = append(opts, signalKaep, signalKasp) + + ka := appKeepAlive.NewKeepAlive(&proto.KeepAlive{}) + defer ka.Stop() + sInterc := grpc.StreamInterceptor(ka.StreamInterceptor()) + uInterc := grpc.UnaryInterceptor(ka.UnaryInterceptor()) + opts = append(opts, sInterc, uInterc) + grpcServer := grpc.NewServer(opts...) proto.RegisterSignalExchangeServer(grpcServer, server.NewServer()) diff --git a/signal/proto/signalexchange.proto b/signal/proto/signalexchange.proto index 18c918d97..36e896fde 100644 --- a/signal/proto/signalexchange.proto +++ b/signal/proto/signalexchange.proto @@ -62,4 +62,6 @@ message Body { // Mode indicates a connection mode message Mode { optional bool direct = 1; -} \ No newline at end of file +} + +message KeepAlive {} \ No newline at end of file