mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
[management] Use DI containers for server bootstrapping (#4343)
This commit is contained in:
204
management/internals/server/boot.go
Normal file
204
management/internals/server/boot.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package server
|
||||
|
||||
// @note this file includes all the lower level dependencies, db, http and grpc BaseServer, metrics, logger, etc.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}
|
||||
|
||||
kasp = keepalive.ServerParameters{
|
||||
MaxConnectionIdle: 15 * time.Second,
|
||||
MaxConnectionAgeGrace: 5 * time.Second,
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
)
|
||||
|
||||
func (s *BaseServer) Metrics() telemetry.AppMetrics {
|
||||
return Create(s, func() telemetry.AppMetrics {
|
||||
appMetrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
if err != nil {
|
||||
log.Fatalf("error while creating app metrics: %s", err)
|
||||
}
|
||||
return appMetrics
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) Store() store.Store {
|
||||
return Create(s, func() store.Store {
|
||||
store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create store: %v", err)
|
||||
}
|
||||
|
||||
return store
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) EventStore() activity.Store {
|
||||
return Create(s, func() activity.Store {
|
||||
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize integration metrics: %v", err)
|
||||
}
|
||||
|
||||
eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize event store: %v", err)
|
||||
}
|
||||
|
||||
if s.config.DataStoreEncryptionKey != key {
|
||||
log.WithContext(context.Background()).Infof("update config with activity store key")
|
||||
s.config.DataStoreEncryptionKey = key
|
||||
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to update config with activity store: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return eventStore
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
return httpAPIHandler
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
return Create(s, func() *grpc.Server {
|
||||
trustedPeers := s.config.ReverseProxy.TrustedPeers
|
||||
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
||||
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
|
||||
log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
|
||||
trustedPeers = defaultTrustedPeers
|
||||
}
|
||||
trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies
|
||||
trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount
|
||||
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
|
||||
log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
|
||||
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
|
||||
}
|
||||
realipOpts := []realip.Option{
|
||||
realip.WithTrustedPeers(trustedPeers),
|
||||
realip.WithTrustedProxies(trustedHTTPProxies),
|
||||
realip.WithTrustedProxiesCount(trustedProxiesCount),
|
||||
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
|
||||
}
|
||||
gRPCOpts := []grpc.ServerOption{
|
||||
grpc.KeepaliveEnforcementPolicy(kaep),
|
||||
grpc.KeepaliveParams(kasp),
|
||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
|
||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
|
||||
}
|
||||
|
||||
if s.config.HttpConfig.LetsEncryptDomain != "" {
|
||||
certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create certificate manager: %v", err)
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
||||
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
||||
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
|
||||
tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
|
||||
if err != nil {
|
||||
log.Fatalf("cannot load TLS credentials: %v", err)
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(tlsConfig)
|
||||
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
||||
}
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||
|
||||
return gRPCAPIHandler
|
||||
})
|
||||
}
|
||||
|
||||
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
||||
// Load server's certificate and private key
|
||||
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// NewDefaultAppMetrics the credentials and return it
|
||||
config := &tls.Config{
|
||||
Certificates: []tls.Certificate{serverCert},
|
||||
ClientAuth: tls.NoClientCert,
|
||||
NextProtos: []string{
|
||||
"h2", "http/1.1", // enable HTTP/2
|
||||
},
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func unaryInterceptor(
|
||||
ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (interface{}, error) {
|
||||
reqID := uuid.New().String()
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
func streamInterceptor(
|
||||
srv interface{},
|
||||
ss grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler,
|
||||
) error {
|
||||
reqID := uuid.New().String()
|
||||
wrapped := grpcMiddleware.WrapServerStream(ss)
|
||||
//nolint
|
||||
ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource)
|
||||
//nolint
|
||||
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
|
||||
return handler(srv, wrapped)
|
||||
}
|
||||
192
management/internals/server/config/config.go
Normal file
192
management/internals/server/config/config.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
type (
|
||||
// Protocol type
|
||||
Protocol string
|
||||
|
||||
// Provider authorization flow type
|
||||
Provider string
|
||||
)
|
||||
|
||||
const (
|
||||
UDP Protocol = "udp"
|
||||
DTLS Protocol = "dtls"
|
||||
TCP Protocol = "tcp"
|
||||
HTTP Protocol = "http"
|
||||
HTTPS Protocol = "https"
|
||||
NONE Provider = "none"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultDeviceAuthFlowScope defines the bare minimum scope to request in the device authorization flow
|
||||
DefaultDeviceAuthFlowScope string = "openid"
|
||||
)
|
||||
|
||||
var MgmtConfigPath string
|
||||
|
||||
// Config of the Management service
|
||||
type Config struct {
|
||||
Stuns []*Host
|
||||
TURNConfig *TURNConfig
|
||||
Relay *Relay
|
||||
Signal *Host
|
||||
|
||||
Datadir string
|
||||
DataStoreEncryptionKey string
|
||||
|
||||
HttpConfig *HttpServerConfig
|
||||
|
||||
IdpManagerConfig *idp.Config
|
||||
|
||||
DeviceAuthorizationFlow *DeviceAuthorizationFlow
|
||||
|
||||
PKCEAuthorizationFlow *PKCEAuthorizationFlow
|
||||
|
||||
StoreConfig StoreConfig
|
||||
|
||||
ReverseProxy ReverseProxy
|
||||
|
||||
// disable default all-to-all policy
|
||||
DisableDefaultPolicy bool
|
||||
}
|
||||
|
||||
// GetAuthAudiences returns the audience from the http config and device authorization flow config
|
||||
func (c Config) GetAuthAudiences() []string {
|
||||
audiences := []string{c.HttpConfig.AuthAudience}
|
||||
|
||||
if c.HttpConfig.ExtraAuthAudience != "" {
|
||||
audiences = append(audiences, c.HttpConfig.ExtraAuthAudience)
|
||||
}
|
||||
|
||||
if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" {
|
||||
audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience)
|
||||
}
|
||||
|
||||
return audiences
|
||||
}
|
||||
|
||||
// TURNConfig is a config of the TURNCredentialsManager
|
||||
type TURNConfig struct {
|
||||
TimeBasedCredentials bool
|
||||
CredentialsTTL util.Duration
|
||||
Secret string
|
||||
Turns []*Host
|
||||
}
|
||||
|
||||
// Relay configuration type
|
||||
type Relay struct {
|
||||
Addresses []string
|
||||
CredentialsTTL util.Duration
|
||||
Secret string
|
||||
}
|
||||
|
||||
// HttpServerConfig is a config of the HTTP Management service server
|
||||
type HttpServerConfig struct {
|
||||
LetsEncryptDomain string
|
||||
// CertFile is the location of the certificate
|
||||
CertFile string
|
||||
// CertKey is the location of the certificate private key
|
||||
CertKey string
|
||||
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
|
||||
AuthAudience string
|
||||
// AuthIssuer identifies principal that issued the JWT
|
||||
AuthIssuer string
|
||||
// AuthUserIDClaim is the name of the claim that used as user ID
|
||||
AuthUserIDClaim string
|
||||
// AuthKeysLocation is a location of JWT key set containing the public keys used to verify JWT
|
||||
AuthKeysLocation string
|
||||
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration
|
||||
OIDCConfigEndpoint string
|
||||
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
|
||||
IdpSignKeyRefreshEnabled bool
|
||||
// Extra audience
|
||||
ExtraAuthAudience string
|
||||
}
|
||||
|
||||
// Host represents a Netbird host (e.g. STUN, TURN, Signal)
|
||||
type Host struct {
|
||||
Proto Protocol
|
||||
// URI e.g. turns://stun.netbird.io:4430 or signal.netbird.io:10000
|
||||
URI string
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||
// that can be used by the client to login initiate a Oauth 2.0 device authorization grant flow
|
||||
// see https://datatracker.ietf.org/doc/html/rfc8628
|
||||
type DeviceAuthorizationFlow struct {
|
||||
Provider string
|
||||
ProviderConfig ProviderConfig
|
||||
}
|
||||
|
||||
// PKCEAuthorizationFlow represents Authorization Code Flow information
|
||||
// that can be used by the client to login initiate a Oauth 2.0 authorization code grant flow
|
||||
// with Proof Key for Code Exchange (PKCE). See https://datatracker.ietf.org/doc/html/rfc7636
|
||||
type PKCEAuthorizationFlow struct {
|
||||
ProviderConfig ProviderConfig
|
||||
}
|
||||
|
||||
// ProviderConfig has all attributes needed to initiate a device/pkce authorization flow
|
||||
type ProviderConfig struct {
|
||||
// ClientID An IDP application client id
|
||||
ClientID string
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use TokenEndpoint and DeviceAuthEndpoint
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||
AuthorizationEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// RedirectURL handles authorization code from IDP manager
|
||||
RedirectURLs []string
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
}
|
||||
|
||||
// StoreConfig contains Store configuration
|
||||
type StoreConfig struct {
|
||||
Engine types.Engine
|
||||
}
|
||||
|
||||
// ReverseProxy contains reverse proxy configuration in front of management.
|
||||
type ReverseProxy struct {
|
||||
// TrustedHTTPProxies represents a list of trusted HTTP proxies by their IP prefixes.
|
||||
// When extracting the real IP address from request headers, the middleware will verify
|
||||
// if the peer's address falls within one of these trusted IP prefixes.
|
||||
TrustedHTTPProxies []netip.Prefix
|
||||
|
||||
// TrustedHTTPProxiesCount specifies the count of trusted HTTP proxies between the internet
|
||||
// and the server. When using the trusted proxy count method to extract the real IP address,
|
||||
// the middleware will search the X-Forwarded-For IP list from the rightmost by this count
|
||||
// minus one.
|
||||
TrustedHTTPProxiesCount uint
|
||||
|
||||
// TrustedPeers represents a list of trusted peers by their IP prefixes.
|
||||
// These peers are considered trustworthy by the gRPC server operator,
|
||||
// and the middleware will attempt to extract the real IP address from
|
||||
// request headers if the peer's address falls within one of these
|
||||
// trusted IP prefixes.
|
||||
TrustedPeers []netip.Prefix
|
||||
}
|
||||
55
management/internals/server/container.go
Normal file
55
management/internals/server/container.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package server
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Create a dependency and add it to the BaseServer's container. A string key identifier will be based on its type definition.
|
||||
func Create[T any](s Server, createFunc func() T) T {
|
||||
result, _ := maybeCreate(s, createFunc)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// CreateNamed is the same as Create but will suffix the dependency string key identifier with a custom name.
|
||||
// Useful if you want to have multiple named instances of the same object type.
|
||||
func CreateNamed[T any](s Server, name string, createFunc func() T) T {
|
||||
result, _ := maybeCreateNamed(s, name, createFunc)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Inject lets you override a specific service from outside the BaseServer itself.
|
||||
// This is useful for tests
|
||||
func Inject[T any](c Server, thing T) {
|
||||
_, _ = maybeCreate(c, func() T {
|
||||
return thing
|
||||
})
|
||||
}
|
||||
|
||||
// InjectNamed is like Inject() but with a custom name.
|
||||
func InjectNamed[T any](c Server, name string, thing T) {
|
||||
_, _ = maybeCreateKeyed(c, name, func() T {
|
||||
return thing
|
||||
})
|
||||
}
|
||||
|
||||
func maybeCreate[T any](s Server, createFunc func() T) (result T, isNew bool) {
|
||||
key := fmt.Sprintf("%T", (*T)(nil))[1:]
|
||||
return maybeCreateKeyed(s, key, createFunc)
|
||||
}
|
||||
|
||||
func maybeCreateNamed[T any](s Server, name string, createFunc func() T) (result T, isNew bool) {
|
||||
key := fmt.Sprintf("%T:%s", (*T)(nil), name)[1:]
|
||||
return maybeCreateKeyed(s, key, createFunc)
|
||||
}
|
||||
|
||||
func maybeCreateKeyed[T any](s Server, key string, createFunc func() T) (result T, isNew bool) {
|
||||
if t, ok := s.GetContainer(key); ok {
|
||||
return t.(T), false
|
||||
}
|
||||
|
||||
t := createFunc()
|
||||
|
||||
s.SetContainer(key, t)
|
||||
|
||||
return t, true
|
||||
}
|
||||
59
management/internals/server/controllers.go
Normal file
59
management/internals/server/controllers.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
)
|
||||
|
||||
func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
|
||||
return Create(s, func() *server.PeersUpdateManager {
|
||||
return server.NewPeersUpdateManager(s.Metrics())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
||||
return Create(s, func() integrated_validator.IntegratedValidator {
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore())
|
||||
if err != nil {
|
||||
log.Errorf("failed to create integrated peer validator: %v", err)
|
||||
}
|
||||
return integratedPeerValidator
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ProxyController() port_forwarding.Controller {
|
||||
return Create(s, func() port_forwarding.Controller {
|
||||
return integrations.NewController(s.Store())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager {
|
||||
return Create(s, func() *server.TimeBasedAuthSecretsManager {
|
||||
return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AuthManager() auth.Manager {
|
||||
return Create(s, func() auth.Manager {
|
||||
return auth.NewManager(s.Store(),
|
||||
s.config.HttpConfig.AuthIssuer,
|
||||
s.config.HttpConfig.AuthAudience,
|
||||
s.config.HttpConfig.AuthKeysLocation,
|
||||
s.config.HttpConfig.AuthUserIDClaim,
|
||||
s.config.GetAuthAudiences(),
|
||||
s.config.HttpConfig.IdpSignKeyRefreshEnabled)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) EphemeralManager() *server.EphemeralManager {
|
||||
return Create(s, func() *server.EphemeralManager {
|
||||
return server.NewEphemeralManager(s.Store(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
108
management/internals/server/modules.go
Normal file
108
management/internals/server/modules.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
)
|
||||
|
||||
func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
||||
return Create(s, func() geolocation.Geolocation {
|
||||
geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate)
|
||||
if err != nil {
|
||||
log.Warnf("could not initialize geolocation service. proceeding without geolocation support: %v", err)
|
||||
} else {
|
||||
log.Infof("geolocation service has been initialized from %s", s.config.Datadir)
|
||||
}
|
||||
|
||||
return geo
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
||||
return Create(s, func() permissions.Manager {
|
||||
return integrations.InitPermissionsManager(s.Store())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) UsersManager() users.Manager {
|
||||
return Create(s, func() users.Manager {
|
||||
return users.NewManager(s.Store())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) SettingsManager() settings.Manager {
|
||||
return Create(s, func() settings.Manager {
|
||||
extraSettingsManager := integrations.NewManager(s.EventStore())
|
||||
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) PeersManager() peers.Manager {
|
||||
return Create(s, func() peers.Manager {
|
||||
return peers.NewManager(s.Store(), s.PermissionsManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AccountManager() account.Manager {
|
||||
return Create(s, func() account.Manager {
|
||||
accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain,
|
||||
s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create account manager: %v", err)
|
||||
}
|
||||
return accountManager
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return Create(s, func() idp.Manager {
|
||||
var idpManager idp.Manager
|
||||
var err error
|
||||
if s.config.IdpManagerConfig != nil {
|
||||
idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create IDP manager: %v", err)
|
||||
}
|
||||
}
|
||||
return idpManager
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) GroupsManager() groups.Manager {
|
||||
return Create(s, func() groups.Manager {
|
||||
return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ResourcesManager() resources.Manager {
|
||||
return Create(s, func() resources.Manager {
|
||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RoutesManager() routers.Manager {
|
||||
return Create(s, func() routers.Manager {
|
||||
return routers.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) NetworksManager() networks.Manager {
|
||||
return Create(s, func() networks.Manager {
|
||||
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
340
management/internals/server/server.go
Normal file
340
management/internals/server/server.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/metrics"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
||||
// It is used for backward compatibility now.
|
||||
const ManagementLegacyPort = 33073
|
||||
|
||||
type Server interface {
|
||||
Start(ctx context.Context) error
|
||||
Stop() error
|
||||
Errors() <-chan error
|
||||
GetContainer(key string) (any, bool)
|
||||
SetContainer(key string, container any)
|
||||
}
|
||||
|
||||
// Server holds the HTTP BaseServer instance.
|
||||
// Add any additional fields you need, such as database connections, config, etc.
|
||||
type BaseServer struct {
|
||||
// config holds the server configuration
|
||||
config *nbconfig.Config
|
||||
// container of dependencies, each dependency is identified by a unique string.
|
||||
container map[string]any
|
||||
// AfterInit is a function that will be called after the server is initialized
|
||||
afterInit []func(s *BaseServer)
|
||||
|
||||
disableMetrics bool
|
||||
dnsDomain string
|
||||
disableGeoliteUpdate bool
|
||||
userDeleteFromIDPEnabled bool
|
||||
mgmtSingleAccModeDomain string
|
||||
mgmtMetricsPort int
|
||||
mgmtPort int
|
||||
|
||||
listener net.Listener
|
||||
certManager *autocert.Manager
|
||||
update *version.Update
|
||||
|
||||
errCh chan error
|
||||
wg sync.WaitGroup
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewServer initializes and configures a new Server instance
|
||||
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
|
||||
return &BaseServer{
|
||||
config: config,
|
||||
container: make(map[string]any),
|
||||
dnsDomain: dnsDomain,
|
||||
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
|
||||
disableMetrics: disableMetrics,
|
||||
disableGeoliteUpdate: disableGeoliteUpdate,
|
||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
mgmtPort: mgmtPort,
|
||||
mgmtMetricsPort: mgmtMetricsPort,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BaseServer) AfterInit(fn func(s *BaseServer)) {
|
||||
s.afterInit = append(s.afterInit, fn)
|
||||
}
|
||||
|
||||
// Start begins listening for HTTP requests on the configured address
|
||||
func (s *BaseServer) Start(ctx context.Context) error {
|
||||
srvCtx, cancel := context.WithCancel(ctx)
|
||||
s.cancel = cancel
|
||||
s.errCh = make(chan error, 4)
|
||||
|
||||
s.PeersManager()
|
||||
|
||||
for _, fn := range s.afterInit {
|
||||
if fn != nil {
|
||||
fn(s)
|
||||
}
|
||||
}
|
||||
|
||||
err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to expose metrics: %v", err)
|
||||
}
|
||||
s.EphemeralManager().LoadInitialPeers(srvCtx)
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
tlsEnabled := false
|
||||
if s.config.HttpConfig.LetsEncryptDomain != "" {
|
||||
s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
|
||||
}
|
||||
tlsEnabled = true
|
||||
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
|
||||
tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
|
||||
if err != nil {
|
||||
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
|
||||
return err
|
||||
}
|
||||
tlsEnabled = true
|
||||
}
|
||||
|
||||
installationID, err := getInstallationID(srvCtx, s.Store())
|
||||
if err != nil {
|
||||
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !s.disableMetrics {
|
||||
idpManager := "disabled"
|
||||
if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" {
|
||||
idpManager = s.config.IdpManagerConfig.ManagerType
|
||||
}
|
||||
metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager)
|
||||
go metricsWorker.Run(srvCtx)
|
||||
}
|
||||
|
||||
var compatListener net.Listener
|
||||
if s.mgmtPort != ManagementLegacyPort {
|
||||
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
|
||||
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
|
||||
compatListener, err = s.serveGRPC(srvCtx, s.GRPCServer(), ManagementLegacyPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||
}
|
||||
|
||||
rootHandler := handlerFunc(s.GRPCServer(), s.APIHandler())
|
||||
switch {
|
||||
case s.certManager != nil:
|
||||
// a call to certManager.Listener() always creates a new listener so we do it once
|
||||
cml := s.certManager.Listener()
|
||||
if s.mgmtPort == 443 {
|
||||
// CertManager, HTTP and gRPC API all on the same port
|
||||
rootHandler = s.certManager.HTTPHandler(rootHandler)
|
||||
s.listener = cml
|
||||
} else {
|
||||
s.listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", s.mgmtPort), s.certManager.TLSConfig())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating TLS listener on port %d: %v", s.mgmtPort, err)
|
||||
}
|
||||
log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
|
||||
s.serveHTTP(ctx, cml, s.certManager.HTTPHandler(nil))
|
||||
}
|
||||
case tlsConfig != nil:
|
||||
s.listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", s.mgmtPort), tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating TLS listener on port %d: %v", s.mgmtPort, err)
|
||||
}
|
||||
default:
|
||||
s.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", s.mgmtPort))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating TCP listener on port %d: %v", s.mgmtPort, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
|
||||
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
|
||||
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
|
||||
|
||||
s.update = version.NewUpdate("nb/management")
|
||||
s.update.SetDaemonVersion(version.NetbirdVersion())
|
||||
s.update.SetOnUpdateListener(func() {
|
||||
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop attempts a graceful shutdown, waiting up to 5 seconds for active connections to finish
|
||||
func (s *BaseServer) Stop() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
s.IntegratedValidator().Stop(ctx)
|
||||
if s.GeoLocationManager() != nil {
|
||||
_ = s.GeoLocationManager().Stop()
|
||||
}
|
||||
s.EphemeralManager().Stop()
|
||||
_ = s.Metrics().Close()
|
||||
if s.listener != nil {
|
||||
_ = s.listener.Close()
|
||||
}
|
||||
if s.certManager != nil {
|
||||
_ = s.certManager.Listener().Close()
|
||||
}
|
||||
s.GRPCServer().Stop()
|
||||
_ = s.Store().Close(ctx)
|
||||
_ = s.EventStore().Close(ctx)
|
||||
if s.update != nil {
|
||||
s.update.StopWatch()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-s.Errors():
|
||||
log.WithContext(ctx).Infof("stopped Management Service")
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the server stops
|
||||
func (s *BaseServer) Errors() <-chan error {
|
||||
return s.errCh
|
||||
}
|
||||
|
||||
// GetContainer retrieves a dependency from the BaseServer's container by its key
|
||||
func (s *BaseServer) GetContainer(key string) (any, bool) {
|
||||
container, exists := s.container[key]
|
||||
return container, exists
|
||||
}
|
||||
|
||||
// SetContainer stores a dependency in the BaseServer's container with the specified key
|
||||
func (s *BaseServer) SetContainer(key string, container any) {
|
||||
if _, exists := s.container[key]; exists {
|
||||
log.Tracef("container with key %s already exists", key)
|
||||
return
|
||||
}
|
||||
s.container[key] = container
|
||||
log.Tracef("container with key %s set successfully", key)
|
||||
}
|
||||
|
||||
func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) error {
|
||||
return util.DirectWriteJson(ctx, path, config)
|
||||
}
|
||||
|
||||
func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
|
||||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")
|
||||
if request.ProtoMajor == 2 && grpcHeader {
|
||||
gRPCHandler.ServeHTTP(writer, request)
|
||||
} else {
|
||||
httpHandler.ServeHTTP(writer, request)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
err := grpcServer.Serve(listener)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case s.errCh <- err:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func (s *BaseServer) serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
err := http.Serve(httpListener, handler)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case s.errCh <- err:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
var err error
|
||||
if tlsEnabled {
|
||||
err = http.Serve(listener, handler)
|
||||
} else {
|
||||
// the following magic is needed to support HTTP2 without TLS
|
||||
// and still share a single port between gRPC and HTTP APIs
|
||||
h1s := &http.Server{
|
||||
Handler: h2c.NewHandler(handler, &http2.Server{}),
|
||||
}
|
||||
err = h1s.Serve(listener)
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case s.errCh <- err:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
|
||||
installationID := store.GetInstallationID()
|
||||
if installationID != "" {
|
||||
return installationID, nil
|
||||
}
|
||||
|
||||
installationID = strings.ToUpper(uuid.New().String())
|
||||
err := store.SaveInstallationID(ctx, installationID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return installationID, nil
|
||||
}
|
||||
Reference in New Issue
Block a user