mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
287 lines
6.9 KiB
Go
287 lines
6.9 KiB
Go
package grpc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/keepalive"
|
|
|
|
pb "github.com/netbirdio/netbird/proxy/pkg/grpc/proto"
|
|
)
|
|
|
|
// StreamHandler handles incoming messages from control service
|
|
type StreamHandler interface {
|
|
HandleControlEvent(ctx context.Context, event *pb.ControlEvent) error
|
|
HandleControlCommand(ctx context.Context, command *pb.ControlCommand) error
|
|
HandleControlConfig(ctx context.Context, config *pb.ControlConfig) error
|
|
HandleExposedServiceEvent(ctx context.Context, event *pb.ExposedServiceEvent) error
|
|
}
|
|
|
|
// Server represents the gRPC server running on the proxy
|
|
type Server struct {
|
|
pb.UnimplementedProxyServiceServer
|
|
|
|
listenAddr string
|
|
grpcServer *grpc.Server
|
|
handler StreamHandler
|
|
|
|
mu sync.RWMutex
|
|
streams map[string]*StreamContext
|
|
isRunning bool
|
|
}
|
|
|
|
// StreamContext holds the context for each active stream
|
|
type StreamContext struct {
|
|
stream pb.ProxyService_StreamServer
|
|
sendChan chan *pb.ProxyMessage
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
controlID string // ID of the connected control service
|
|
}
|
|
|
|
// Config holds gRPC server configuration
|
|
type Config struct {
|
|
ListenAddr string
|
|
Handler StreamHandler
|
|
}
|
|
|
|
// NewServer creates a new gRPC server
|
|
func NewServer(config Config) *Server {
|
|
return &Server{
|
|
listenAddr: config.ListenAddr,
|
|
handler: config.Handler,
|
|
streams: make(map[string]*StreamContext),
|
|
}
|
|
}
|
|
|
|
// Start starts the gRPC server
|
|
func (s *Server) Start() error {
|
|
s.mu.Lock()
|
|
if s.isRunning {
|
|
s.mu.Unlock()
|
|
return fmt.Errorf("gRPC server already running")
|
|
}
|
|
s.isRunning = true
|
|
s.mu.Unlock()
|
|
|
|
lis, err := net.Listen("tcp", s.listenAddr)
|
|
if err != nil {
|
|
s.mu.Lock()
|
|
s.isRunning = false
|
|
s.mu.Unlock()
|
|
return fmt.Errorf("failed to listen: %w", err)
|
|
}
|
|
|
|
// Configure gRPC server with keepalive
|
|
s.grpcServer = grpc.NewServer(
|
|
grpc.KeepaliveParams(keepalive.ServerParameters{
|
|
Time: 30 * time.Second,
|
|
Timeout: 10 * time.Second,
|
|
}),
|
|
grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
|
|
MinTime: 10 * time.Second,
|
|
PermitWithoutStream: true,
|
|
}),
|
|
)
|
|
|
|
pb.RegisterProxyServiceServer(s.grpcServer, s)
|
|
|
|
log.Infof("gRPC server listening on %s", s.listenAddr)
|
|
|
|
if err := s.grpcServer.Serve(lis); err != nil {
|
|
s.mu.Lock()
|
|
s.isRunning = false
|
|
s.mu.Unlock()
|
|
return fmt.Errorf("failed to serve: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stop gracefully stops the gRPC server
|
|
func (s *Server) Stop(ctx context.Context) error {
|
|
s.mu.Lock()
|
|
if !s.isRunning {
|
|
s.mu.Unlock()
|
|
return fmt.Errorf("gRPC server not running")
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
log.Info("Stopping gRPC server...")
|
|
|
|
// Cancel all active streams
|
|
s.mu.Lock()
|
|
for _, streamCtx := range s.streams {
|
|
streamCtx.cancel()
|
|
close(streamCtx.sendChan)
|
|
}
|
|
s.streams = make(map[string]*StreamContext)
|
|
s.mu.Unlock()
|
|
|
|
// Graceful stop with timeout
|
|
stopped := make(chan struct{})
|
|
go func() {
|
|
s.grpcServer.GracefulStop()
|
|
close(stopped)
|
|
}()
|
|
|
|
select {
|
|
case <-stopped:
|
|
log.Info("gRPC server stopped gracefully")
|
|
case <-ctx.Done():
|
|
log.Warn("gRPC server graceful stop timeout, forcing stop")
|
|
s.grpcServer.Stop()
|
|
}
|
|
|
|
s.mu.Lock()
|
|
s.isRunning = false
|
|
s.mu.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stream implements the bidirectional streaming RPC
|
|
// The control service connects as client, proxy is server
|
|
// Control service sends ControlMessage, Proxy sends ProxyMessage
|
|
func (s *Server) Stream(stream pb.ProxyService_StreamServer) error {
|
|
ctx, cancel := context.WithCancel(stream.Context())
|
|
defer cancel()
|
|
|
|
controlID := fmt.Sprintf("control-%d", time.Now().Unix())
|
|
|
|
// Create stream context
|
|
streamCtx := &StreamContext{
|
|
stream: stream,
|
|
sendChan: make(chan *pb.ProxyMessage, 100),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
controlID: controlID,
|
|
}
|
|
|
|
// Register stream
|
|
s.mu.Lock()
|
|
s.streams[controlID] = streamCtx
|
|
s.mu.Unlock()
|
|
|
|
log.Infof("Control service connected: %s", controlID)
|
|
|
|
// Start goroutine to send ProxyMessages to control service
|
|
sendDone := make(chan error, 1)
|
|
go s.sendLoop(streamCtx, sendDone)
|
|
|
|
// Start goroutine to receive ControlMessages from control service
|
|
recvDone := make(chan error, 1)
|
|
go s.receiveLoop(streamCtx, recvDone)
|
|
|
|
// Wait for either send or receive to complete
|
|
select {
|
|
case err := <-sendDone:
|
|
log.Infof("Control service %s send loop ended: %v", controlID, err)
|
|
return err
|
|
case err := <-recvDone:
|
|
log.Infof("Control service %s receive loop ended: %v", controlID, err)
|
|
return err
|
|
case <-ctx.Done():
|
|
log.Infof("Control service %s context done: %v", controlID, ctx.Err())
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
|
|
// sendLoop handles sending ProxyMessages to the control service
|
|
func (s *Server) sendLoop(streamCtx *StreamContext, done chan<- error) {
|
|
for {
|
|
select {
|
|
case msg, ok := <-streamCtx.sendChan:
|
|
if !ok {
|
|
done <- nil
|
|
return
|
|
}
|
|
|
|
// Send ProxyMessage to control service
|
|
if err := streamCtx.stream.Send(msg); err != nil {
|
|
log.Errorf("Failed to send message to control service: %v", err)
|
|
done <- err
|
|
return
|
|
}
|
|
|
|
case <-streamCtx.ctx.Done():
|
|
done <- streamCtx.ctx.Err()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// receiveLoop handles receiving ControlMessages from the control service
|
|
func (s *Server) receiveLoop(streamCtx *StreamContext, done chan<- error) {
|
|
for {
|
|
// Receive ControlMessage from control service (client)
|
|
controlMsg, err := streamCtx.stream.Recv()
|
|
if err != nil {
|
|
log.Debugf("Stream receive error: %v", err)
|
|
done <- err
|
|
return
|
|
}
|
|
|
|
// Handle different ControlMessage types
|
|
switch m := controlMsg.Message.(type) {
|
|
case *pb.ControlMessage_Event:
|
|
if s.handler != nil {
|
|
if err := s.handler.HandleControlEvent(streamCtx.ctx, m.Event); err != nil {
|
|
log.Errorf("Failed to handle control event: %v", err)
|
|
}
|
|
}
|
|
|
|
case *pb.ControlMessage_Command:
|
|
if s.handler != nil {
|
|
if err := s.handler.HandleControlCommand(streamCtx.ctx, m.Command); err != nil {
|
|
log.Errorf("Failed to handle control command: %v", err)
|
|
}
|
|
}
|
|
|
|
case *pb.ControlMessage_Config:
|
|
if s.handler != nil {
|
|
if err := s.handler.HandleControlConfig(streamCtx.ctx, m.Config); err != nil {
|
|
log.Errorf("Failed to handle control config: %v", err)
|
|
}
|
|
}
|
|
|
|
case *pb.ControlMessage_ExposedService:
|
|
if s.handler != nil {
|
|
if err := s.handler.HandleExposedServiceEvent(streamCtx.ctx, m.ExposedService); err != nil {
|
|
log.Errorf("Failed to handle exposed service event: %v", err)
|
|
}
|
|
}
|
|
|
|
default:
|
|
log.Warnf("Received unknown control message type: %T", m)
|
|
}
|
|
}
|
|
}
|
|
|
|
// SendProxyMessage sends a ProxyMessage to all connected control services
|
|
func (s *Server) SendProxyMessage(msg *pb.ProxyMessage) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
for _, streamCtx := range s.streams {
|
|
select {
|
|
case streamCtx.sendChan <- msg:
|
|
// Message queued successfully
|
|
default:
|
|
log.Warn("Send channel full, dropping message")
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetActiveStreams returns the number of active streams
|
|
func (s *Server) GetActiveStreams() int {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return len(s.streams)
|
|
}
|