mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
starting refactor
This commit is contained in:
@@ -0,0 +1,7 @@
|
|||||||
|
package access_control
|
||||||
|
|
||||||
|
type AccessControlManager interface {
|
||||||
|
}
|
||||||
|
|
||||||
|
type DefaultAccessControlManager struct {
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
package access_control
|
||||||
@@ -0,0 +1,100 @@
|
|||||||
|
package access_control
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// TrafficFlowType defines allowed direction of the traffic in the rule
|
||||||
|
type TrafficFlowType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// TrafficFlowBidirect allows traffic to both direction
|
||||||
|
TrafficFlowBidirect TrafficFlowType = iota
|
||||||
|
// TrafficFlowBidirectString allows traffic to both direction
|
||||||
|
TrafficFlowBidirectString = "bidirect"
|
||||||
|
// DefaultRuleName is a name for the Default rule that is created for every account
|
||||||
|
DefaultRuleName = "Default"
|
||||||
|
// DefaultRuleDescription is a description for the Default rule that is created for every account
|
||||||
|
DefaultRuleDescription = "This is a default rule that allows connections between all the resources"
|
||||||
|
// DefaultPolicyName is a name for the Default policy that is created for every account
|
||||||
|
DefaultPolicyName = "Default"
|
||||||
|
// DefaultPolicyDescription is a description for the Default policy that is created for every account
|
||||||
|
DefaultPolicyDescription = "This is a default policy that allows connections between all the resources"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rule of ACL for groups
|
||||||
|
type Rule struct {
|
||||||
|
// ID of the rule
|
||||||
|
ID string
|
||||||
|
|
||||||
|
// AccountID is a reference to Account that this object belongs
|
||||||
|
AccountID string `json:"-" gorm:"index"`
|
||||||
|
|
||||||
|
// Name of the rule visible in the UI
|
||||||
|
Name string
|
||||||
|
|
||||||
|
// Description of the rule visible in the UI
|
||||||
|
Description string
|
||||||
|
|
||||||
|
// Disabled status of rule in the system
|
||||||
|
Disabled bool
|
||||||
|
|
||||||
|
// Source list of groups IDs of peers
|
||||||
|
Source []string `gorm:"serializer:json"`
|
||||||
|
|
||||||
|
// Destination list of groups IDs of peers
|
||||||
|
Destination []string `gorm:"serializer:json"`
|
||||||
|
|
||||||
|
// Flow of the traffic allowed by the rule
|
||||||
|
Flow TrafficFlowType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rule) Copy() *Rule {
|
||||||
|
rule := &Rule{
|
||||||
|
ID: r.ID,
|
||||||
|
Name: r.Name,
|
||||||
|
Description: r.Description,
|
||||||
|
Disabled: r.Disabled,
|
||||||
|
Source: make([]string, len(r.Source)),
|
||||||
|
Destination: make([]string, len(r.Destination)),
|
||||||
|
Flow: r.Flow,
|
||||||
|
}
|
||||||
|
copy(rule.Source, r.Source)
|
||||||
|
copy(rule.Destination, r.Destination)
|
||||||
|
return rule
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventMeta returns activity event meta related to this rule
|
||||||
|
func (r *Rule) EventMeta() map[string]any {
|
||||||
|
return map[string]any{"name": r.Name}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToPolicyRule converts a Rule to a PolicyRule object
|
||||||
|
func (r *Rule) ToPolicyRule() *PolicyRule {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &PolicyRule{
|
||||||
|
ID: r.ID,
|
||||||
|
Name: r.Name,
|
||||||
|
Enabled: !r.Disabled,
|
||||||
|
Description: r.Description,
|
||||||
|
Destinations: r.Destination,
|
||||||
|
Sources: r.Source,
|
||||||
|
Bidirectional: true,
|
||||||
|
Protocol: PolicyRuleProtocolALL,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RuleToPolicy converts a Rule to a Policy query object
|
||||||
|
func RuleToPolicy(rule *Rule) (*Policy, error) {
|
||||||
|
if rule == nil {
|
||||||
|
return nil, fmt.Errorf("rule is empty")
|
||||||
|
}
|
||||||
|
return &Policy{
|
||||||
|
ID: rule.ID,
|
||||||
|
Name: rule.Name,
|
||||||
|
Description: rule.Description,
|
||||||
|
Enabled: !rule.Disabled,
|
||||||
|
Rules: []*PolicyRule{rule.ToPolicyRule()},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package accounts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Settings represents Account settings structure that can be modified via API and Dashboard
|
||||||
|
type Settings struct {
|
||||||
|
// PeerLoginExpirationEnabled globally enables or disables peer login expiration
|
||||||
|
PeerLoginExpirationEnabled bool
|
||||||
|
|
||||||
|
// PeerLoginExpiration is a setting that indicates when peer login expires.
|
||||||
|
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
||||||
|
PeerLoginExpiration time.Duration
|
||||||
|
|
||||||
|
// GroupsPropagationEnabled allows to propagate auto groups from the user to the peer
|
||||||
|
GroupsPropagationEnabled bool
|
||||||
|
|
||||||
|
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
|
||||||
|
// and add it to account groups.
|
||||||
|
JWTGroupsEnabled bool
|
||||||
|
|
||||||
|
// JWTGroupsClaimName from which we extract groups name to add it to account groups
|
||||||
|
JWTGroupsClaimName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy copies the Settings struct
|
||||||
|
func (s *Settings) Copy() *Settings {
|
||||||
|
return &Settings{
|
||||||
|
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||||
|
PeerLoginExpiration: s.PeerLoginExpiration,
|
||||||
|
JWTGroupsEnabled: s.JWTGroupsEnabled,
|
||||||
|
JWTGroupsClaimName: s.JWTGroupsClaimName,
|
||||||
|
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Account represents a unique account of the system
|
||||||
|
type Account struct {
|
||||||
|
// we have to name column to aid as it collides with Network.Id when work with associations
|
||||||
|
Id string `gorm:"primaryKey"`
|
||||||
|
|
||||||
|
// User.Id it was created by
|
||||||
|
CreatedBy string
|
||||||
|
Domain string `gorm:"index"`
|
||||||
|
DomainCategory string
|
||||||
|
IsDomainPrimaryAccount bool
|
||||||
|
SetupKeys map[string]*SetupKey `gorm:"-"`
|
||||||
|
SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
|
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||||
|
Peers map[string]*Peer `gorm:"-"`
|
||||||
|
PeersG []Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
|
Users map[string]*User `gorm:"-"`
|
||||||
|
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
|
Groups map[string]*Group `gorm:"-"`
|
||||||
|
GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
|
Rules map[string]*Rule `gorm:"-"`
|
||||||
|
RulesG []Rule `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
|
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
|
||||||
|
Routes map[string]*route.Route `gorm:"-"`
|
||||||
|
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
|
NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"`
|
||||||
|
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
|
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
||||||
|
// Settings is a dictionary of Account settings
|
||||||
|
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package accounts
|
||||||
|
|
||||||
|
type AccountManager interface {
|
||||||
|
GetAccount(accountID string) (Account, error)
|
||||||
|
GetDNSDomain() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type DefaultAccountManager struct {
|
||||||
|
repository AccountRepository
|
||||||
|
|
||||||
|
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
||||||
|
dnsDomain string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) GetAccount(accountID string) (Account, error) {
|
||||||
|
return am.repository.findAccountByID(accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDNSDomain returns the configured dnsDomain
|
||||||
|
func (am *DefaultAccountManager) GetDNSDomain() string {
|
||||||
|
return am.dnsDomain
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
package accounts
|
||||||
|
|
||||||
|
type AccountRepository interface {
|
||||||
|
findAccountByID(accountID string) (Account, error)
|
||||||
|
}
|
||||||
150
management/server/management_refactor/server/config.go
Normal file
150
management/server/management_refactor/server/config.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// Protocol type
|
||||||
|
Protocol string
|
||||||
|
|
||||||
|
// Provider authorization flow type
|
||||||
|
Provider string
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
UDP Protocol = "udp"
|
||||||
|
DTLS Protocol = "dtls"
|
||||||
|
TCP Protocol = "tcp"
|
||||||
|
HTTP Protocol = "http"
|
||||||
|
HTTPS Protocol = "https"
|
||||||
|
NONE Provider = "none"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultDeviceAuthFlowScope defines the bare minimum scope to request in the device authorization flow
|
||||||
|
DefaultDeviceAuthFlowScope string = "openid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config of the Management service
|
||||||
|
type Config struct {
|
||||||
|
Stuns []*Host
|
||||||
|
TURNConfig *TURNConfig
|
||||||
|
Signal *Host
|
||||||
|
|
||||||
|
Datadir string
|
||||||
|
DataStoreEncryptionKey string
|
||||||
|
|
||||||
|
HttpConfig *HttpServerConfig
|
||||||
|
|
||||||
|
IdpManagerConfig *idp.Config
|
||||||
|
|
||||||
|
DeviceAuthorizationFlow *DeviceAuthorizationFlow
|
||||||
|
|
||||||
|
PKCEAuthorizationFlow *PKCEAuthorizationFlow
|
||||||
|
|
||||||
|
StoreConfig StoreConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthAudiences returns the audience from the http config and device authorization flow config
|
||||||
|
func (c Config) GetAuthAudiences() []string {
|
||||||
|
audiences := []string{c.HttpConfig.AuthAudience}
|
||||||
|
|
||||||
|
if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" {
|
||||||
|
audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience)
|
||||||
|
}
|
||||||
|
|
||||||
|
return audiences
|
||||||
|
}
|
||||||
|
|
||||||
|
// TURNConfig is a config of the TURNCredentialsManager
|
||||||
|
type TURNConfig struct {
|
||||||
|
TimeBasedCredentials bool
|
||||||
|
CredentialsTTL util.Duration
|
||||||
|
Secret string
|
||||||
|
Turns []*Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// HttpServerConfig is a config of the HTTP Management service server
|
||||||
|
type HttpServerConfig struct {
|
||||||
|
LetsEncryptDomain string
|
||||||
|
// CertFile is the location of the certificate
|
||||||
|
CertFile string
|
||||||
|
// CertKey is the location of the certificate private key
|
||||||
|
CertKey string
|
||||||
|
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
|
||||||
|
AuthAudience string
|
||||||
|
// AuthIssuer identifies principal that issued the JWT
|
||||||
|
AuthIssuer string
|
||||||
|
// AuthUserIDClaim is the name of the claim that used as user ID
|
||||||
|
AuthUserIDClaim string
|
||||||
|
// AuthKeysLocation is a location of JWT key set containing the public keys used to verify JWT
|
||||||
|
AuthKeysLocation string
|
||||||
|
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration
|
||||||
|
OIDCConfigEndpoint string
|
||||||
|
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
|
||||||
|
IdpSignKeyRefreshEnabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)
|
||||||
|
type Host struct {
|
||||||
|
Proto Protocol
|
||||||
|
// URI e.g. turns://stun.wiretrustee.com:4430 or signal.wiretrustee.com:10000
|
||||||
|
URI string
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||||
|
// that can be used by the client to login initiate a Oauth 2.0 device authorization grant flow
|
||||||
|
// see https://datatracker.ietf.org/doc/html/rfc8628
|
||||||
|
type DeviceAuthorizationFlow struct {
|
||||||
|
Provider string
|
||||||
|
ProviderConfig ProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// PKCEAuthorizationFlow represents Authorization Code Flow information
|
||||||
|
// that can be used by the client to login initiate a Oauth 2.0 authorization code grant flow
|
||||||
|
// with Proof Key for Code Exchange (PKCE). See https://datatracker.ietf.org/doc/html/rfc7636
|
||||||
|
type PKCEAuthorizationFlow struct {
|
||||||
|
ProviderConfig ProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderConfig has all attributes needed to initiate a device/pkce authorization flow
|
||||||
|
type ProviderConfig struct {
|
||||||
|
// ClientID An IDP application client id
|
||||||
|
ClientID string
|
||||||
|
// ClientSecret An IDP application client secret
|
||||||
|
ClientSecret string
|
||||||
|
// Domain An IDP API domain
|
||||||
|
// Deprecated. Use TokenEndpoint and DeviceAuthEndpoint
|
||||||
|
Domain string
|
||||||
|
// Audience An Audience for to authorization validation
|
||||||
|
Audience string
|
||||||
|
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||||
|
TokenEndpoint string
|
||||||
|
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||||
|
DeviceAuthEndpoint string
|
||||||
|
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||||
|
AuthorizationEndpoint string
|
||||||
|
// Scopes provides the scopes to be included in the token request
|
||||||
|
Scope string
|
||||||
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
|
UseIDToken bool
|
||||||
|
// RedirectURL handles authorization code from IDP manager
|
||||||
|
RedirectURLs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreConfig contains Store configuration
|
||||||
|
type StoreConfig struct {
|
||||||
|
Engine StoreEngine
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateURL validates input http url
|
||||||
|
func validateURL(httpURL string) bool {
|
||||||
|
_, err := url.ParseRequestURI(httpURL)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
1
management/server/management_refactor/server/dns/dns.go
Normal file
1
management/server/management_refactor/server/dns/dns.go
Normal file
@@ -0,0 +1 @@
|
|||||||
|
package dns
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
package events
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EventsManager struct {
|
||||||
|
store activity.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEventsManager(store activity.Store) *EventsManager {
|
||||||
|
return &EventsManager{
|
||||||
|
store: store,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEvents returns a list of activity events of an account
|
||||||
|
func (em *EventsManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
|
||||||
|
events, err := em.store.Get(accountID, 0, 10000, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is a workaround for duplicate activity.UserJoined events that might occur when a user redeems invite.
|
||||||
|
// we will need to find a better way to handle this.
|
||||||
|
filtered := make([]*activity.Event, 0)
|
||||||
|
dups := make(map[string]struct{})
|
||||||
|
for _, event := range events {
|
||||||
|
if event.Activity == activity.UserJoined {
|
||||||
|
key := event.TargetID + event.InitiatorID + event.AccountID + fmt.Sprint(event.Activity)
|
||||||
|
_, duplicate := dups[key]
|
||||||
|
if duplicate {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
dups[key] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filtered = append(filtered, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filtered, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (em *EventsManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity,
|
||||||
|
meta map[string]any) {
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := em.store.Save(&activity.Event{
|
||||||
|
Timestamp: time.Now().UTC(),
|
||||||
|
Activity: activityID,
|
||||||
|
InitiatorID: initiatorID,
|
||||||
|
TargetID: targetID,
|
||||||
|
AccountID: accountID,
|
||||||
|
Meta: meta,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
// todo add metric
|
||||||
|
log.Errorf("received an error while storing an activity event, error: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
package events
|
||||||
|
|
||||||
|
// import (
|
||||||
|
// "testing"
|
||||||
|
// "time"
|
||||||
|
//
|
||||||
|
// "github.com/stretchr/testify/assert"
|
||||||
|
//
|
||||||
|
// "github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// func generateAndStoreEvents(t *testing.T, manager *EventsManager, typ activity.Activity, initiatorID, targetID,
|
||||||
|
// accountID string, count int) {
|
||||||
|
// t.Helper()
|
||||||
|
// for i := 0; i < count; i++ {
|
||||||
|
// _, err := manager.store.Save(&activity.Event{
|
||||||
|
// Timestamp: time.Now().UTC(),
|
||||||
|
// Activity: typ,
|
||||||
|
// InitiatorID: initiatorID,
|
||||||
|
// TargetID: targetID,
|
||||||
|
// AccountID: accountID,
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// t.Fatal(err)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// func TestDefaultAccountManager_GetEvents(t *testing.T) {
|
||||||
|
// manager, err := createManager(t)
|
||||||
|
// if err != nil {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// accountID := "accountID"
|
||||||
|
//
|
||||||
|
// t.Run("get empty events list", func(t *testing.T) {
|
||||||
|
// events, err := manager.GetEvents(accountID, userID)
|
||||||
|
// if err != nil {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// assert.Len(t, events, 0)
|
||||||
|
// _ = manager.eventStore.Close() //nolint
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// t.Run("get events", func(t *testing.T) {
|
||||||
|
// generateAndStoreEvents(t, manager, activity.PeerAddedByUser, userID, "peer", accountID, 10)
|
||||||
|
// events, err := manager.GetEvents(accountID, userID)
|
||||||
|
// if err != nil {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// assert.Len(t, events, 10)
|
||||||
|
// _ = manager.eventStore.Close() //nolint
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// t.Run("get events without duplicates", func(t *testing.T) {
|
||||||
|
// generateAndStoreEvents(t, manager, activity.UserJoined, userID, "", accountID, 10)
|
||||||
|
// events, err := manager.GetEvents(accountID, userID)
|
||||||
|
// if err != nil {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// assert.Len(t, events, 1)
|
||||||
|
// _ = manager.eventStore.Close() //nolint
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// func createManager(t *testing.T) (*EventsManager, error) {
|
||||||
|
// t.Helper()
|
||||||
|
// store, err := createStore(t)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// eventStore := &activity.InMemoryEventStore{}
|
||||||
|
// return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, false)
|
||||||
|
// }
|
||||||
333
management/server/management_refactor/server/groups/group.go
Normal file
333
management/server/management_refactor/server/groups/group.go
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
package groups
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GroupLinkError struct {
|
||||||
|
Resource string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *GroupLinkError) Error() string {
|
||||||
|
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group of the peers for ACL
|
||||||
|
type Group struct {
|
||||||
|
// ID of the group
|
||||||
|
ID string
|
||||||
|
|
||||||
|
// AccountID is a reference to Account that this object belongs
|
||||||
|
AccountID string `json:"-" gorm:"index"`
|
||||||
|
|
||||||
|
// Name visible in the UI
|
||||||
|
Name string
|
||||||
|
|
||||||
|
// Issued of the group
|
||||||
|
Issued string
|
||||||
|
|
||||||
|
// Peers list of the group
|
||||||
|
Peers []string `gorm:"serializer:json"`
|
||||||
|
|
||||||
|
IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventMeta returns activity event meta related to the group
|
||||||
|
func (g *Group) EventMeta() map[string]any {
|
||||||
|
return map[string]any{"name": g.Name}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Group) Copy() *Group {
|
||||||
|
group := &Group{
|
||||||
|
ID: g.ID,
|
||||||
|
Name: g.Name,
|
||||||
|
Issued: g.Issued,
|
||||||
|
Peers: make([]string, len(g.Peers)),
|
||||||
|
IntegrationReference: g.IntegrationReference,
|
||||||
|
}
|
||||||
|
copy(group.Peers, g.Peers)
|
||||||
|
return group
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroup object of the peers
|
||||||
|
func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) {
|
||||||
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
group, ok := account.Groups[groupID]
|
||||||
|
if ok {
|
||||||
|
return group, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveGroup object of the peers
|
||||||
|
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *Group) error {
|
||||||
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
oldGroup, exists := account.Groups[newGroup.ID]
|
||||||
|
account.Groups[newGroup.ID] = newGroup
|
||||||
|
|
||||||
|
account.Network.IncSerial()
|
||||||
|
if err = am.Store.SaveAccount(account); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
// the following snippet tracks the activity and stores the group events in the event store.
|
||||||
|
// It has to happen after all the operations have been successfully performed.
|
||||||
|
addedPeers := make([]string, 0)
|
||||||
|
removedPeers := make([]string, 0)
|
||||||
|
if exists {
|
||||||
|
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||||
|
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||||
|
} else {
|
||||||
|
addedPeers = append(addedPeers, newGroup.Peers...)
|
||||||
|
am.StoreEvent(userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range addedPeers {
|
||||||
|
peer := account.Peers[p]
|
||||||
|
if peer == nil {
|
||||||
|
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
am.StoreEvent(userID, peer.ID, accountID, activity.GroupAddedToPeer,
|
||||||
|
map[string]any{
|
||||||
|
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
|
||||||
|
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range removedPeers {
|
||||||
|
peer := account.Peers[p]
|
||||||
|
if peer == nil {
|
||||||
|
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
am.StoreEvent(userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
|
||||||
|
map[string]any{
|
||||||
|
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
|
||||||
|
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// difference returns the elements in `a` that aren't in `b`.
|
||||||
|
func difference(a, b []string) []string {
|
||||||
|
mb := make(map[string]struct{}, len(b))
|
||||||
|
for _, x := range b {
|
||||||
|
mb[x] = struct{}{}
|
||||||
|
}
|
||||||
|
var diff []string
|
||||||
|
for _, x := range a {
|
||||||
|
if _, found := mb[x]; !found {
|
||||||
|
diff = append(diff, x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return diff
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteGroup object of the peers
|
||||||
|
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||||
|
unlock := am.Store.AcquireAccountLock(accountId)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
g, ok := account.Groups[groupID]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// disable a deleting integration group if the initiator is not an admin service user
|
||||||
|
if g.Issued == GroupIssuedIntegration {
|
||||||
|
executingUser := account.Users[userId]
|
||||||
|
if executingUser == nil {
|
||||||
|
return status.Errorf(status.NotFound, "user not found")
|
||||||
|
}
|
||||||
|
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||||
|
return status.Errorf(status.PermissionDenied, "only admins service user can delete integration group")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check route links
|
||||||
|
for _, r := range account.Routes {
|
||||||
|
for _, g := range r.Groups {
|
||||||
|
if g == groupID {
|
||||||
|
return &GroupLinkError{"route", r.NetID}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check DNS links
|
||||||
|
for _, dns := range account.NameServerGroups {
|
||||||
|
for _, g := range dns.Groups {
|
||||||
|
if g == groupID {
|
||||||
|
return &GroupLinkError{"name server groups", dns.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check ACL links
|
||||||
|
for _, policy := range account.Policies {
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
for _, src := range rule.Sources {
|
||||||
|
if src == groupID {
|
||||||
|
return &GroupLinkError{"policy", policy.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dst := range rule.Destinations {
|
||||||
|
if dst == groupID {
|
||||||
|
return &GroupLinkError{"policy", policy.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check setup key links
|
||||||
|
for _, setupKey := range account.SetupKeys {
|
||||||
|
for _, grp := range setupKey.AutoGroups {
|
||||||
|
if grp == groupID {
|
||||||
|
return &GroupLinkError{"setup key", setupKey.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check user links
|
||||||
|
for _, user := range account.Users {
|
||||||
|
for _, grp := range user.AutoGroups {
|
||||||
|
if grp == groupID {
|
||||||
|
return &GroupLinkError{"user", user.Id}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check DisabledManagementGroups
|
||||||
|
for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups {
|
||||||
|
if disabledMgmGrp == groupID {
|
||||||
|
return &GroupLinkError{"disabled DNS management groups", g.Name}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(account.Groups, groupID)
|
||||||
|
|
||||||
|
account.Network.IncSerial()
|
||||||
|
if err = am.Store.SaveAccount(account); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
am.StoreEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||||
|
|
||||||
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListGroups objects of the peers
|
||||||
|
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) {
|
||||||
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
groups := make([]*Group, 0, len(account.Groups))
|
||||||
|
for _, item := range account.Groups {
|
||||||
|
groups = append(groups, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
return groups, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupAddPeer appends peer to the group
|
||||||
|
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
|
||||||
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
group, ok := account.Groups[groupID]
|
||||||
|
if !ok {
|
||||||
|
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
add := true
|
||||||
|
for _, itemID := range group.Peers {
|
||||||
|
if itemID == peerID {
|
||||||
|
add = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if add {
|
||||||
|
group.Peers = append(group.Peers, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Network.IncSerial()
|
||||||
|
if err = am.Store.SaveAccount(account); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupDeletePeer removes peer from the group
|
||||||
|
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
|
||||||
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
group, ok := account.Groups[groupID]
|
||||||
|
if !ok {
|
||||||
|
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Network.IncSerial()
|
||||||
|
for i, itemID := range group.Peers {
|
||||||
|
if itemID == peerID {
|
||||||
|
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
||||||
|
if err := am.Store.SaveAccount(account); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
607
management/server/management_refactor/server/grpcserver.go
Normal file
607
management/server/management_refactor/server/grpcserver.go
Normal file
@@ -0,0 +1,607 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pb "github.com/golang/protobuf/proto" // nolint
|
||||||
|
"github.com/golang/protobuf/ptypes/timestamp"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
gRPCPeer "google.golang.org/grpc/peer"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/encryption"
|
||||||
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
|
"github.com/netbirdio/netbird/management/server/management_refactor/server/accounts"
|
||||||
|
"github.com/netbirdio/netbird/management/server/management_refactor/server/peers"
|
||||||
|
internalStatus "github.com/netbirdio/netbird/management/server/status"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GRPCServer an instance of a Management gRPC API server
|
||||||
|
type GRPCServer struct {
|
||||||
|
accountManager accounts.AccountManager
|
||||||
|
wgKey wgtypes.Key
|
||||||
|
proto.UnimplementedManagementServiceServer
|
||||||
|
peersUpdateManager *peers.PeersUpdateManager
|
||||||
|
config *Config
|
||||||
|
turnCredentialsManager TURNCredentialsManager
|
||||||
|
jwtValidator *jwtclaims.JWTValidator
|
||||||
|
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
|
appMetrics telemetry.AppMetrics
|
||||||
|
ephemeralManager *peers.EphemeralManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new Management server
|
||||||
|
func NewServer(config *Config, accountManager accounts.AccountManager, peersUpdateManager *peers.PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *peers.EphemeralManager) (*GRPCServer, error) {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var jwtValidator *jwtclaims.JWTValidator
|
||||||
|
|
||||||
|
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
|
||||||
|
jwtValidator, err = jwtclaims.NewJWTValidator(
|
||||||
|
config.HttpConfig.AuthIssuer,
|
||||||
|
config.GetAuthAudiences(),
|
||||||
|
config.HttpConfig.AuthKeysLocation,
|
||||||
|
config.HttpConfig.IdpSignKeyRefreshEnabled,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debug("unable to use http config to create new jwt middleware")
|
||||||
|
}
|
||||||
|
|
||||||
|
if appMetrics != nil {
|
||||||
|
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||||
|
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
||||||
|
return int64(len(peersUpdateManager.PeerChannels))
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var audience, userIDClaim string
|
||||||
|
if config.HttpConfig != nil {
|
||||||
|
audience = config.HttpConfig.AuthAudience
|
||||||
|
userIDClaim = config.HttpConfig.AuthUserIDClaim
|
||||||
|
}
|
||||||
|
jwtClaimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||||
|
jwtclaims.WithAudience(audience),
|
||||||
|
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||||
|
)
|
||||||
|
|
||||||
|
return &GRPCServer{
|
||||||
|
wgKey: key,
|
||||||
|
// peerKey -> event channel
|
||||||
|
peersUpdateManager: peersUpdateManager,
|
||||||
|
accountManager: accountManager,
|
||||||
|
config: config,
|
||||||
|
turnCredentialsManager: turnCredentialsManager,
|
||||||
|
jwtValidator: jwtValidator,
|
||||||
|
jwtClaimsExtractor: jwtClaimsExtractor,
|
||||||
|
appMetrics: appMetrics,
|
||||||
|
ephemeralManager: ephemeralManager,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
|
||||||
|
// todo introduce something more meaningful with the key expiration/rotation
|
||||||
|
if s.appMetrics != nil {
|
||||||
|
s.appMetrics.GRPCMetrics().CountGetKeyRequest()
|
||||||
|
}
|
||||||
|
now := time.Now().Add(24 * time.Hour)
|
||||||
|
secs := int64(now.Second())
|
||||||
|
nanos := int32(now.Nanosecond())
|
||||||
|
expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos}
|
||||||
|
|
||||||
|
return &proto.ServerKeyResponse{
|
||||||
|
Key: s.wgKey.PublicKey().String(),
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||||
|
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||||
|
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||||
|
reqStart := time.Now()
|
||||||
|
if s.appMetrics != nil {
|
||||||
|
s.appMetrics.GRPCMetrics().CountSyncRequest()
|
||||||
|
}
|
||||||
|
p, ok := gRPCPeer.FromContext(srv.Context())
|
||||||
|
if ok {
|
||||||
|
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
syncReq := &proto.SyncRequest{}
|
||||||
|
peerKey, err := s.parseRequest(req, syncReq)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
|
||||||
|
if err != nil {
|
||||||
|
return mapError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.sendInitialSync(peerKey, peer, netMap, srv)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
updates := s.peersUpdateManager.CreateChannel(peer.ID)
|
||||||
|
|
||||||
|
s.ephemeralManager.OnPeerConnected(peer)
|
||||||
|
|
||||||
|
err = s.accountManager.MarkPeerConnected(peerKey.String(), true)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.config.TURNConfig.TimeBasedCredentials {
|
||||||
|
s.turnCredentialsManager.SetupRefresh(peer.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.appMetrics != nil {
|
||||||
|
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||||
|
}
|
||||||
|
|
||||||
|
// keep a connection to the peer and send updates when available
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// condition when there are some updates
|
||||||
|
case update, open := <-updates:
|
||||||
|
|
||||||
|
if s.appMetrics != nil {
|
||||||
|
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !open {
|
||||||
|
log.Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||||
|
s.cancelPeerRoutines(peer)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Debugf("received an update for peer %s", peerKey.String())
|
||||||
|
|
||||||
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||||
|
if err != nil {
|
||||||
|
s.cancelPeerRoutines(peer)
|
||||||
|
return status.Errorf(codes.Internal, "failed processing update message")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||||
|
WgPubKey: s.wgKey.PublicKey().String(),
|
||||||
|
Body: encryptedResp,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
s.cancelPeerRoutines(peer)
|
||||||
|
return status.Errorf(codes.Internal, "failed sending update message")
|
||||||
|
}
|
||||||
|
log.Debugf("sent an update to peer %s", peerKey.String())
|
||||||
|
// condition when client <-> server connection has been terminated
|
||||||
|
case <-srv.Context().Done():
|
||||||
|
// happens when connection drops, e.g. client disconnects
|
||||||
|
log.Debugf("stream of peer %s has been closed", peerKey.String())
|
||||||
|
s.cancelPeerRoutines(peer)
|
||||||
|
return srv.Context().Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GRPCServer) cancelPeerRoutines(peer *Peer) {
|
||||||
|
s.peersUpdateManager.CloseChannel(peer.ID)
|
||||||
|
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
||||||
|
_ = s.accountManager.MarkPeerConnected(peer.Key, false)
|
||||||
|
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
|
||||||
|
if s.jwtValidator == nil {
|
||||||
|
return "", status.Error(codes.Internal, "no jwt validator set")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := s.jwtValidator.ValidateAndParse(jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
|
||||||
|
}
|
||||||
|
claims := s.jwtClaimsExtractor.FromToken(token)
|
||||||
|
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||||
|
_, _, err = s.accountManager.GetAccountFromToken(claims)
|
||||||
|
if err != nil {
|
||||||
|
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims.UserId, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// maps internal internalStatus.Error to gRPC status.Error
|
||||||
|
func mapError(err error) error {
|
||||||
|
if e, ok := internalStatus.FromError(err); ok {
|
||||||
|
switch e.Type() {
|
||||||
|
case internalStatus.PermissionDenied:
|
||||||
|
return status.Errorf(codes.PermissionDenied, e.Message)
|
||||||
|
case internalStatus.Unauthorized:
|
||||||
|
return status.Errorf(codes.PermissionDenied, e.Message)
|
||||||
|
case internalStatus.Unauthenticated:
|
||||||
|
return status.Errorf(codes.PermissionDenied, e.Message)
|
||||||
|
case internalStatus.PreconditionFailed:
|
||||||
|
return status.Errorf(codes.FailedPrecondition, e.Message)
|
||||||
|
case internalStatus.NotFound:
|
||||||
|
return status.Errorf(codes.NotFound, e.Message)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Errorf("got an unhandled error: %s", err)
|
||||||
|
return status.Errorf(codes.Internal, "failed handling request")
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractPeerMeta(loginReq *proto.LoginRequest) PeerSystemMeta {
|
||||||
|
return PeerSystemMeta{
|
||||||
|
Hostname: loginReq.GetMeta().GetHostname(),
|
||||||
|
GoOS: loginReq.GetMeta().GetGoOS(),
|
||||||
|
Kernel: loginReq.GetMeta().GetKernel(),
|
||||||
|
Core: loginReq.GetMeta().GetCore(),
|
||||||
|
Platform: loginReq.GetMeta().GetPlatform(),
|
||||||
|
OS: loginReq.GetMeta().GetOS(),
|
||||||
|
WtVersion: loginReq.GetMeta().GetWiretrusteeVersion(),
|
||||||
|
UIVersion: loginReq.GetMeta().GetUiVersion(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
|
||||||
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
|
||||||
|
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed)
|
||||||
|
if err != nil {
|
||||||
|
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
|
||||||
|
}
|
||||||
|
|
||||||
|
return peerKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login endpoint first checks whether peer is registered under any account
|
||||||
|
// In case it is, the login is successful
|
||||||
|
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
||||||
|
// In case of the successful registration login is also successful
|
||||||
|
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
reqStart := time.Now()
|
||||||
|
defer func() {
|
||||||
|
if s.appMetrics != nil {
|
||||||
|
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if s.appMetrics != nil {
|
||||||
|
s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||||
|
}
|
||||||
|
p, ok := gRPCPeer.FromContext(ctx)
|
||||||
|
if ok {
|
||||||
|
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
loginReq := &proto.LoginRequest{}
|
||||||
|
peerKey, err := s.parseRequest(req, loginReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if loginReq.GetMeta() == nil {
|
||||||
|
msg := status.Errorf(codes.FailedPrecondition,
|
||||||
|
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(),
|
||||||
|
p.Addr.String())
|
||||||
|
log.Warn(msg)
|
||||||
|
return nil, msg
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := ""
|
||||||
|
// JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered,
|
||||||
|
// or it uses a setup key to register.
|
||||||
|
if loginReq.GetJwtToken() != "" {
|
||||||
|
userID, err = s.validateToken(loginReq.GetJwtToken())
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed validating JWT token sent from peer %s", peerKey)
|
||||||
|
return nil, mapError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var sshKey []byte
|
||||||
|
if loginReq.GetPeerKeys() != nil {
|
||||||
|
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{
|
||||||
|
WireGuardPubKey: peerKey.String(),
|
||||||
|
SSHKey: string(sshKey),
|
||||||
|
Meta: extractPeerMeta(loginReq),
|
||||||
|
UserID: userID,
|
||||||
|
SetupKey: loginReq.GetSetupKey(),
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed logging in peer %s", peerKey)
|
||||||
|
return nil, mapError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the login request contains setup key then it is a registration request
|
||||||
|
if loginReq.GetSetupKey() != "" {
|
||||||
|
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if peer has reached this point then it has logged in
|
||||||
|
loginResp := &proto.LoginResponse{
|
||||||
|
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
|
||||||
|
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
|
||||||
|
}
|
||||||
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed encrypting peer %s message", peer.ID)
|
||||||
|
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.EncryptedMessage{
|
||||||
|
WgPubKey: s.wgKey.PublicKey().String(),
|
||||||
|
Body: encryptedResp,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
|
||||||
|
switch configProto {
|
||||||
|
case UDP:
|
||||||
|
return proto.HostConfig_UDP
|
||||||
|
case DTLS:
|
||||||
|
return proto.HostConfig_DTLS
|
||||||
|
case HTTP:
|
||||||
|
return proto.HostConfig_HTTP
|
||||||
|
case HTTPS:
|
||||||
|
return proto.HostConfig_HTTPS
|
||||||
|
case TCP:
|
||||||
|
return proto.HostConfig_TCP
|
||||||
|
default:
|
||||||
|
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
|
||||||
|
if config == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var stuns []*proto.HostConfig
|
||||||
|
for _, stun := range config.Stuns {
|
||||||
|
stuns = append(stuns, &proto.HostConfig{
|
||||||
|
Uri: stun.URI,
|
||||||
|
Protocol: ToResponseProto(stun.Proto),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
var turns []*proto.ProtectedHostConfig
|
||||||
|
for _, turn := range config.TURNConfig.Turns {
|
||||||
|
var username string
|
||||||
|
var password string
|
||||||
|
if turnCredentials != nil {
|
||||||
|
username = turnCredentials.Username
|
||||||
|
password = turnCredentials.Password
|
||||||
|
} else {
|
||||||
|
username = turn.Username
|
||||||
|
password = turn.Password
|
||||||
|
}
|
||||||
|
turns = append(turns, &proto.ProtectedHostConfig{
|
||||||
|
HostConfig: &proto.HostConfig{
|
||||||
|
Uri: turn.URI,
|
||||||
|
Protocol: ToResponseProto(turn.Proto),
|
||||||
|
},
|
||||||
|
User: username,
|
||||||
|
Password: password,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.WiretrusteeConfig{
|
||||||
|
Stuns: stuns,
|
||||||
|
Turns: turns,
|
||||||
|
Signal: &proto.HostConfig{
|
||||||
|
Uri: config.Signal.URI,
|
||||||
|
Protocol: ToResponseProto(config.Signal.Proto),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toPeerConfig(peer *Peer, network *Network, dnsName string) *proto.PeerConfig {
|
||||||
|
netmask, _ := network.Net.Mask.Size()
|
||||||
|
fqdn := peer.FQDN(dnsName)
|
||||||
|
return &proto.PeerConfig{
|
||||||
|
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
|
||||||
|
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
|
||||||
|
Fqdn: fqdn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toRemotePeerConfig(peers []*Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||||
|
remotePeers := []*proto.RemotePeerConfig{}
|
||||||
|
for _, rPeer := range peers {
|
||||||
|
fqdn := rPeer.FQDN(dnsName)
|
||||||
|
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
|
||||||
|
WgPubKey: rPeer.Key,
|
||||||
|
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
|
||||||
|
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||||
|
Fqdn: fqdn,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return remotePeers
|
||||||
|
}
|
||||||
|
|
||||||
|
func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
|
||||||
|
wtConfig := toWiretrusteeConfig(config, turnCredentials)
|
||||||
|
|
||||||
|
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
|
||||||
|
|
||||||
|
remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
|
||||||
|
|
||||||
|
routesUpdate := toProtocolRoutes(networkMap.Routes)
|
||||||
|
|
||||||
|
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
|
||||||
|
|
||||||
|
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
|
||||||
|
|
||||||
|
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||||
|
|
||||||
|
return &proto.SyncResponse{
|
||||||
|
WiretrusteeConfig: wtConfig,
|
||||||
|
PeerConfig: pConfig,
|
||||||
|
RemotePeers: remotePeers,
|
||||||
|
RemotePeersIsEmpty: len(remotePeers) == 0,
|
||||||
|
NetworkMap: &proto.NetworkMap{
|
||||||
|
Serial: networkMap.Network.CurrentSerial(),
|
||||||
|
PeerConfig: pConfig,
|
||||||
|
RemotePeers: remotePeers,
|
||||||
|
OfflinePeers: offlinePeers,
|
||||||
|
RemotePeersIsEmpty: len(remotePeers) == 0,
|
||||||
|
Routes: routesUpdate,
|
||||||
|
DNSConfig: dnsUpdate,
|
||||||
|
FirewallRules: firewallRules,
|
||||||
|
FirewallRulesIsEmpty: len(firewallRules) == 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsHealthy indicates whether the service is healthy
|
||||||
|
func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
|
||||||
|
return &proto.Empty{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||||
|
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
|
||||||
|
// make secret time based TURN credentials optional
|
||||||
|
var turnCredentials *TURNCredentials
|
||||||
|
if s.config.TURNConfig.TimeBasedCredentials {
|
||||||
|
creds := s.turnCredentialsManager.GenerateCredentials()
|
||||||
|
turnCredentials = &creds
|
||||||
|
} else {
|
||||||
|
turnCredentials = nil
|
||||||
|
}
|
||||||
|
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
|
||||||
|
|
||||||
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(codes.Internal, "error handling request")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = srv.Send(&proto.EncryptedMessage{
|
||||||
|
WgPubKey: s.wgKey.PublicKey().String(),
|
||||||
|
Body: encryptedResp,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed sending SyncResponse %v", err)
|
||||||
|
return status.Errorf(codes.Internal, "error handling request")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceAuthorizationFlow returns a device authorization flow information
|
||||||
|
// This is used for initiating an Oauth 2 device authorization grant flow
|
||||||
|
// which will be used by our clients to Login
|
||||||
|
func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
|
if err != nil {
|
||||||
|
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
|
||||||
|
log.Warn(errMSG)
|
||||||
|
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
|
||||||
|
if err != nil {
|
||||||
|
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
||||||
|
log.Warn(errMSG)
|
||||||
|
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(NONE) {
|
||||||
|
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, ok := proto.DeviceAuthorizationFlowProvider_value[strings.ToUpper(s.config.DeviceAuthorizationFlow.Provider)]
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "no provider found in the protocol for %s", s.config.DeviceAuthorizationFlow.Provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
flowInfoResp := &proto.DeviceAuthorizationFlow{
|
||||||
|
Provider: proto.DeviceAuthorizationFlowProvider(provider),
|
||||||
|
ProviderConfig: &proto.ProviderConfig{
|
||||||
|
ClientID: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientID,
|
||||||
|
ClientSecret: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientSecret,
|
||||||
|
Domain: s.config.DeviceAuthorizationFlow.ProviderConfig.Domain,
|
||||||
|
Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience,
|
||||||
|
DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint,
|
||||||
|
TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint,
|
||||||
|
Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope,
|
||||||
|
UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.EncryptedMessage{
|
||||||
|
WgPubKey: s.wgKey.PublicKey().String(),
|
||||||
|
Body: encryptedResp,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
|
||||||
|
// This is used for initiating an Oauth 2 pkce authorization grant flow
|
||||||
|
// which will be used by our clients to Login
|
||||||
|
func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
|
if err != nil {
|
||||||
|
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
|
||||||
|
log.Warn(errMSG)
|
||||||
|
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
|
||||||
|
if err != nil {
|
||||||
|
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
||||||
|
log.Warn(errMSG)
|
||||||
|
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.config.PKCEAuthorizationFlow == nil {
|
||||||
|
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
|
||||||
|
}
|
||||||
|
|
||||||
|
flowInfoResp := &proto.PKCEAuthorizationFlow{
|
||||||
|
ProviderConfig: &proto.ProviderConfig{
|
||||||
|
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
|
||||||
|
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
|
||||||
|
ClientSecret: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientSecret,
|
||||||
|
TokenEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint,
|
||||||
|
AuthorizationEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint,
|
||||||
|
Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope,
|
||||||
|
RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs,
|
||||||
|
UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.EncryptedMessage{
|
||||||
|
WgPubKey: s.wgKey.PublicKey().String(),
|
||||||
|
Body: encryptedResp,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,286 @@
|
|||||||
|
package nameservers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"regexp"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
||||||
|
|
||||||
|
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||||
|
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||||
|
|
||||||
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nsGroup, found := account.NameServerGroups[nsGroupID]
|
||||||
|
if found {
|
||||||
|
return nsGroup.Copy(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateNameServerGroup creates and saves a new nameserver group
|
||||||
|
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||||
|
|
||||||
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
newNSGroup := &nbdns.NameServerGroup{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Name: name,
|
||||||
|
Description: description,
|
||||||
|
NameServers: nameServerList,
|
||||||
|
Groups: groups,
|
||||||
|
Enabled: enabled,
|
||||||
|
Primary: primary,
|
||||||
|
Domains: domains,
|
||||||
|
SearchDomainsEnabled: searchDomainEnabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = validateNameServerGroup(false, newNSGroup, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if account.NameServerGroups == nil {
|
||||||
|
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
|
||||||
|
}
|
||||||
|
|
||||||
|
account.NameServerGroups[newNSGroup.ID] = newNSGroup
|
||||||
|
|
||||||
|
account.Network.IncSerial()
|
||||||
|
err = am.Store.SaveAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
am.StoreEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||||
|
|
||||||
|
return newNSGroup.Copy(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveNameServerGroup saves nameserver group
|
||||||
|
func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||||
|
|
||||||
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
if nsGroupToSave == nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = validateNameServerGroup(true, nsGroupToSave, account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
|
||||||
|
|
||||||
|
account.Network.IncSerial()
|
||||||
|
err = am.Store.SaveAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
am.StoreEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||||
|
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
||||||
|
|
||||||
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
nsGroup := account.NameServerGroups[nsGroupID]
|
||||||
|
if nsGroup == nil {
|
||||||
|
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
|
||||||
|
}
|
||||||
|
delete(account.NameServerGroups, nsGroupID)
|
||||||
|
|
||||||
|
account.Network.IncSerial()
|
||||||
|
err = am.Store.SaveAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
|
am.StoreEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListNameServerGroups returns a list of nameserver groups from account
|
||||||
|
func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||||
|
|
||||||
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
|
||||||
|
for _, item := range account.NameServerGroups {
|
||||||
|
nsGroups = append(nsGroups, item.Copy())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nsGroups, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
||||||
|
nsGroupID := ""
|
||||||
|
if existingGroup {
|
||||||
|
nsGroupID = nameserverGroup.ID
|
||||||
|
_, found := account.NameServerGroups[nsGroupID]
|
||||||
|
if !found {
|
||||||
|
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = validateNSList(nameserverGroup.NameServers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = validateGroups(nameserverGroup.Groups, account.Groups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
|
||||||
|
if !primary && len(domains) == 0 {
|
||||||
|
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
|
||||||
|
" it should be primary or have at least one domain")
|
||||||
|
}
|
||||||
|
if primary && len(domains) != 0 {
|
||||||
|
return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
|
||||||
|
" you should set either primary or domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
if primary && searchDomainsEnabled {
|
||||||
|
return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and search domains is enabled,"+
|
||||||
|
" you should not set search domains for primary nameservers")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range domains {
|
||||||
|
if err := validateDomain(domain); err != nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
|
||||||
|
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
|
||||||
|
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, nsGroup := range nsGroupMap {
|
||||||
|
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
|
||||||
|
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateNSList(list []nbdns.NameServer) error {
|
||||||
|
nsListLenght := len(list)
|
||||||
|
if nsListLenght == 0 || nsListLenght > 2 {
|
||||||
|
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateGroups(list []string, groups map[string]*Group) error {
|
||||||
|
if len(list) == 0 {
|
||||||
|
return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range list {
|
||||||
|
if id == "" {
|
||||||
|
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for groupID := range groups {
|
||||||
|
if id == groupID {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateDomain(domain string) error {
|
||||||
|
domainMatcher := regexp.MustCompile(domainPattern)
|
||||||
|
if !domainMatcher.MatchString(domain) {
|
||||||
|
return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces")
|
||||||
|
}
|
||||||
|
|
||||||
|
labels, valid := dns.IsDomainName(domain)
|
||||||
|
if !valid {
|
||||||
|
return errors.New("invalid domain name")
|
||||||
|
}
|
||||||
|
|
||||||
|
if labels < 2 {
|
||||||
|
return errors.New("domain should consists of a minimum of two labels")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
148
management/server/management_refactor/server/network/network.go
Normal file
148
management/server/management_refactor/server/network/network.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/c-robinson/iplib"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/server/management_refactor/server/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SubnetSize is a size of the subnet of the global network, e.g. 100.77.0.0/16
|
||||||
|
SubnetSize = 16
|
||||||
|
// NetSize is a global network size 100.64.0.0/10
|
||||||
|
NetSize = 10
|
||||||
|
|
||||||
|
// AllowedIPsFormat generates Wireguard AllowedIPs format (e.g. 100.64.30.1/32)
|
||||||
|
AllowedIPsFormat = "%s/32"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NetworkMap struct {
|
||||||
|
Peers []*peers.Peer
|
||||||
|
Network *Network
|
||||||
|
Routes []*route.Route
|
||||||
|
DNSConfig nbdns.Config
|
||||||
|
OfflinePeers []*peers.Peer
|
||||||
|
FirewallRules []*FirewallRule
|
||||||
|
}
|
||||||
|
|
||||||
|
type Network struct {
|
||||||
|
Identifier string `json:"id"`
|
||||||
|
Net net.IPNet `gorm:"serializer:gob"`
|
||||||
|
Dns string
|
||||||
|
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||||
|
// Used to synchronize state to the client apps.
|
||||||
|
Serial uint64
|
||||||
|
|
||||||
|
mu sync.Mutex `json:"-" gorm:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNetwork creates a new Network initializing it with a Serial=0
|
||||||
|
// It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets)
|
||||||
|
func NewNetwork() *Network {
|
||||||
|
|
||||||
|
n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize)
|
||||||
|
sub, _ := n.Subnet(SubnetSize)
|
||||||
|
|
||||||
|
s := rand.NewSource(time.Now().Unix())
|
||||||
|
r := rand.New(s)
|
||||||
|
intn := r.Intn(len(sub))
|
||||||
|
|
||||||
|
return &Network{
|
||||||
|
Identifier: xid.New().String(),
|
||||||
|
Net: sub[intn].IPNet,
|
||||||
|
Dns: "",
|
||||||
|
Serial: 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncSerial increments Serial by 1 reflecting that the network state has been changed
|
||||||
|
func (n *Network) IncSerial() {
|
||||||
|
n.mu.Lock()
|
||||||
|
defer n.mu.Unlock()
|
||||||
|
n.Serial = n.Serial + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CurrentSerial returns the Network.Serial of the network (latest state id)
|
||||||
|
func (n *Network) CurrentSerial() uint64 {
|
||||||
|
n.mu.Lock()
|
||||||
|
defer n.mu.Unlock()
|
||||||
|
return n.Serial
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Network) Copy() *Network {
|
||||||
|
return &Network{
|
||||||
|
Identifier: n.Identifier,
|
||||||
|
Net: n.Net,
|
||||||
|
Dns: n.Dns,
|
||||||
|
Serial: n.Serial,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllocatePeerIP pics an available IP from an net.IPNet.
|
||||||
|
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps
|
||||||
|
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
|
||||||
|
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
|
||||||
|
takenIPMap := make(map[string]struct{})
|
||||||
|
takenIPMap[ipNet.IP.String()] = struct{}{}
|
||||||
|
for _, ip := range takenIps {
|
||||||
|
takenIPMap[ip.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
ips, _ := generateIPs(&ipNet, takenIPMap)
|
||||||
|
|
||||||
|
if len(ips) == 0 {
|
||||||
|
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// pick a random IP
|
||||||
|
s := rand.NewSource(time.Now().Unix())
|
||||||
|
r := rand.New(s)
|
||||||
|
intn := r.Intn(len(ips))
|
||||||
|
|
||||||
|
return ips[intn], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list
|
||||||
|
func generateIPs(ipNet *net.IPNet, exclusions map[string]struct{}) ([]net.IP, int) {
|
||||||
|
|
||||||
|
var ips []net.IP
|
||||||
|
for ip := ipNet.IP.Mask(ipNet.Mask); ipNet.Contains(ip); incIP(ip) {
|
||||||
|
if _, ok := exclusions[ip.String()]; !ok && ip[3] != 0 {
|
||||||
|
ips = append(ips, copyIP(ip))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove network address, broadcast and Fake DNS resolver address
|
||||||
|
lenIPs := len(ips)
|
||||||
|
switch {
|
||||||
|
case lenIPs < 2:
|
||||||
|
return ips, lenIPs
|
||||||
|
case lenIPs < 3:
|
||||||
|
return ips[1 : len(ips)-1], lenIPs - 2
|
||||||
|
default:
|
||||||
|
return ips[1 : len(ips)-2], lenIPs - 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyIP(ip net.IP) net.IP {
|
||||||
|
dup := make(net.IP, len(ip))
|
||||||
|
copy(dup, ip)
|
||||||
|
return dup
|
||||||
|
}
|
||||||
|
|
||||||
|
func incIP(ip net.IP) {
|
||||||
|
for j := len(ip) - 1; j >= 0; j-- {
|
||||||
|
ip[j]++
|
||||||
|
if ip[j] > 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/server/management_refactor/server/access_control"
|
||||||
|
"github.com/netbirdio/netbird/management/server/management_refactor/server/peers"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NetworkManager interface {
|
||||||
|
GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap
|
||||||
|
}
|
||||||
|
|
||||||
|
type DefaultNetworkManager struct {
|
||||||
|
accessControlManager access_control.AccessControlManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nm *DefaultNetworkManager) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
|
||||||
|
aclPeers, firewallRules := getPeerConnectionResources(peerID)
|
||||||
|
// exclude expired peers
|
||||||
|
var peersToConnect []*peers.Peer
|
||||||
|
var expiredPeers []*peers.Peer
|
||||||
|
for _, p := range aclPeers {
|
||||||
|
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
|
||||||
|
if a.Settings.PeerLoginExpirationEnabled && expired {
|
||||||
|
expiredPeers = append(expiredPeers, p)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peersToConnect = append(peersToConnect, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
routesUpdate := a.getRoutesToSync(peerID, peersToConnect)
|
||||||
|
|
||||||
|
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
|
||||||
|
dnsUpdate := nbdns.Config{
|
||||||
|
ServiceEnable: dnsManagementStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
if dnsManagementStatus {
|
||||||
|
var zones []nbdns.CustomZone
|
||||||
|
peersCustomZone := getPeersCustomZone(a, dnsDomain)
|
||||||
|
if peersCustomZone.Domain != "" {
|
||||||
|
zones = append(zones, peersCustomZone)
|
||||||
|
}
|
||||||
|
dnsUpdate.CustomZones = zones
|
||||||
|
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &NetworkMap{
|
||||||
|
Peers: peersToConnect,
|
||||||
|
Network: a.Network.Copy(),
|
||||||
|
Routes: routesUpdate,
|
||||||
|
DNSConfig: dnsUpdate,
|
||||||
|
OfflinePeers: expiredPeers,
|
||||||
|
FirewallRules: firewallRules,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPeerConnectionResources for a given peer
|
||||||
|
//
|
||||||
|
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||||
|
func (nm *DefaultNetworkManager) getPeerConnectionResources(peerID string) ([]*Peer, []*FirewallRule) {
|
||||||
|
generateResources, getAccumulatedResources := a.connResourcesGenerator()
|
||||||
|
for _, policy := range nm.accessControlManager.Policies {
|
||||||
|
if !policy.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
if !rule.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID)
|
||||||
|
destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID)
|
||||||
|
|
||||||
|
if rule.Bidirectional {
|
||||||
|
if peerInSources {
|
||||||
|
generateResources(rule, destinationPeers, firewallRuleDirectionIN)
|
||||||
|
}
|
||||||
|
if peerInDestinations {
|
||||||
|
generateResources(rule, sourcePeers, firewallRuleDirectionOUT)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerInSources {
|
||||||
|
generateResources(rule, destinationPeers, firewallRuleDirectionOUT)
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerInDestinations {
|
||||||
|
generateResources(rule, sourcePeers, firewallRuleDirectionIN)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return getAccumulatedResources()
|
||||||
|
}
|
||||||
|
|
||||||
|
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
||||||
|
//
|
||||||
|
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
|
||||||
|
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
|
||||||
|
// generated. The accumulator function returns the result of all the generator calls.
|
||||||
|
func (nm *DefaultNetworkManager) connResourcesGenerator() (func(*access_control.PolicyRule, []*peers.Peer, int), func() ([]*peers.Peer, []*access_control.FirewallRule)) {
|
||||||
|
rulesExists := make(map[string]struct{})
|
||||||
|
peersExists := make(map[string]struct{})
|
||||||
|
rules := make([]*FirewallRule, 0)
|
||||||
|
peers := make([]*peers.Peer, 0)
|
||||||
|
|
||||||
|
all, err := a.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get group all: %v", err)
|
||||||
|
all = &Group{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(rule *PolicyRule, groupPeers []*Peer, direction int) {
|
||||||
|
isAll := (len(all.Peers) - 1) == len(groupPeers)
|
||||||
|
for _, peer := range groupPeers {
|
||||||
|
if peer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := peersExists[peer.ID]; !ok {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
peersExists[peer.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
fr := FirewallRule{
|
||||||
|
PeerIP: peer.IP.String(),
|
||||||
|
Direction: direction,
|
||||||
|
Action: string(rule.Action),
|
||||||
|
Protocol: string(rule.Protocol),
|
||||||
|
}
|
||||||
|
|
||||||
|
if isAll {
|
||||||
|
fr.PeerIP = "0.0.0.0"
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||||
|
fr.Protocol + fr.Action + strings.Join(rule.Ports, ","))
|
||||||
|
if _, ok := rulesExists[ruleID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rulesExists[ruleID] = struct{}{}
|
||||||
|
|
||||||
|
if len(rule.Ports) == 0 {
|
||||||
|
rules = append(rules, &fr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, port := range rule.Ports {
|
||||||
|
pr := fr // clone rule and add set new port
|
||||||
|
pr.Port = port
|
||||||
|
rules = append(rules, &pr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, func() ([]*Peer, []*FirewallRule) {
|
||||||
|
return peers, rules
|
||||||
|
}
|
||||||
|
}
|
||||||
225
management/server/management_refactor/server/peers/ephemeral.go
Normal file
225
management/server/management_refactor/server/peers/ephemeral.go
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
package peers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/management_refactor/accounts"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ephemeralLifeTime = 10 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
timeNow = time.Now
|
||||||
|
)
|
||||||
|
|
||||||
|
type ephemeralPeer struct {
|
||||||
|
id string
|
||||||
|
account *accounts.Account
|
||||||
|
deadline time.Time
|
||||||
|
next *ephemeralPeer
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
|
||||||
|
// in worst case we will get invalid error message in this manager.
|
||||||
|
|
||||||
|
// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted
|
||||||
|
// automatically. Inactivity means the peer disconnected from the Management server.
|
||||||
|
type EphemeralManager struct {
|
||||||
|
store Store
|
||||||
|
accountManager AccountManager
|
||||||
|
|
||||||
|
headPeer *ephemeralPeer
|
||||||
|
tailPeer *ephemeralPeer
|
||||||
|
peersLock sync.Mutex
|
||||||
|
timer *time.Timer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEphemeralManager instantiate new EphemeralManager
|
||||||
|
func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralManager {
|
||||||
|
return &EphemeralManager{
|
||||||
|
store: store,
|
||||||
|
accountManager: accountManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
|
||||||
|
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
|
||||||
|
// head.
|
||||||
|
func (e *EphemeralManager) LoadInitialPeers() {
|
||||||
|
e.peersLock.Lock()
|
||||||
|
defer e.peersLock.Unlock()
|
||||||
|
|
||||||
|
e.loadEphemeralPeers()
|
||||||
|
if e.headPeer != nil {
|
||||||
|
e.timer = time.AfterFunc(ephemeralLifeTime, e.cleanup)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop timer
|
||||||
|
func (e *EphemeralManager) Stop() {
|
||||||
|
e.peersLock.Lock()
|
||||||
|
defer e.peersLock.Unlock()
|
||||||
|
|
||||||
|
if e.timer != nil {
|
||||||
|
e.timer.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
|
||||||
|
// is active the manager will not delete it while it is active.
|
||||||
|
func (e *EphemeralManager) OnPeerConnected(peer *Peer) {
|
||||||
|
if !peer.Ephemeral {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("remove peer from ephemeral list: %s", peer.ID)
|
||||||
|
|
||||||
|
e.peersLock.Lock()
|
||||||
|
defer e.peersLock.Unlock()
|
||||||
|
|
||||||
|
e.removePeer(peer.ID)
|
||||||
|
|
||||||
|
// stop the unnecessary timer
|
||||||
|
if e.headPeer == nil && e.timer != nil {
|
||||||
|
e.timer.Stop()
|
||||||
|
e.timer = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
|
||||||
|
// is inactive it will be deleted after the ephemeralLifeTime period.
|
||||||
|
func (e *EphemeralManager) OnPeerDisconnected(peer *Peer) {
|
||||||
|
if !peer.Ephemeral {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("add peer to ephemeral list: %s", peer.ID)
|
||||||
|
|
||||||
|
a, err := e.store.GetAccountByPeerID(peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to add peer to ephemeral list: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.peersLock.Lock()
|
||||||
|
defer e.peersLock.Unlock()
|
||||||
|
|
||||||
|
if e.isPeerOnList(peer.ID) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.addPeer(peer.ID, a, newDeadLine())
|
||||||
|
if e.timer == nil {
|
||||||
|
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EphemeralManager) loadEphemeralPeers() {
|
||||||
|
accounts := e.store.GetAllAccounts()
|
||||||
|
t := newDeadLine()
|
||||||
|
count := 0
|
||||||
|
for _, a := range accounts {
|
||||||
|
for id, p := range a.Peers {
|
||||||
|
if p.Ephemeral {
|
||||||
|
count++
|
||||||
|
e.addPeer(id, a, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("loaded ephemeral peer(s): %d", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EphemeralManager) cleanup() {
|
||||||
|
log.Tracef("on ephemeral cleanup")
|
||||||
|
deletePeers := make(map[string]*ephemeralPeer)
|
||||||
|
|
||||||
|
e.peersLock.Lock()
|
||||||
|
now := timeNow()
|
||||||
|
for p := e.headPeer; p != nil; p = p.next {
|
||||||
|
if now.Before(p.deadline) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
deletePeers[p.id] = p
|
||||||
|
e.headPeer = p.next
|
||||||
|
if p.next == nil {
|
||||||
|
e.tailPeer = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.headPeer != nil {
|
||||||
|
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
|
||||||
|
} else {
|
||||||
|
e.timer = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
e.peersLock.Unlock()
|
||||||
|
|
||||||
|
for id, p := range deletePeers {
|
||||||
|
log.Debugf("delete ephemeral peer: %s", id)
|
||||||
|
err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator)
|
||||||
|
if err != nil {
|
||||||
|
log.Tracef("failed to delete ephemeral peer: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) {
|
||||||
|
ep := &ephemeralPeer{
|
||||||
|
id: id,
|
||||||
|
account: account,
|
||||||
|
deadline: deadline,
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.headPeer == nil {
|
||||||
|
e.headPeer = ep
|
||||||
|
}
|
||||||
|
if e.tailPeer != nil {
|
||||||
|
e.tailPeer.next = ep
|
||||||
|
}
|
||||||
|
e.tailPeer = ep
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EphemeralManager) removePeer(id string) {
|
||||||
|
if e.headPeer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.headPeer.id == id {
|
||||||
|
e.headPeer = e.headPeer.next
|
||||||
|
if e.tailPeer.id == id {
|
||||||
|
e.tailPeer = nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for p := e.headPeer; p.next != nil; p = p.next {
|
||||||
|
if p.next.id == id {
|
||||||
|
// if we remove the last element from the chain then set the last-1 as tail
|
||||||
|
if e.tailPeer.id == id {
|
||||||
|
e.tailPeer = p
|
||||||
|
}
|
||||||
|
p.next = p.next.next
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EphemeralManager) isPeerOnList(id string) bool {
|
||||||
|
for p := e.headPeer; p != nil; p = p.next {
|
||||||
|
if p.id == id {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDeadLine() time.Time {
|
||||||
|
return timeNow().Add(ephemeralLifeTime)
|
||||||
|
}
|
||||||
186
management/server/management_refactor/server/peers/peer.go
Normal file
186
management/server/management_refactor/server/peers/peer.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
package peers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Peer represents a machine connected to the network.
|
||||||
|
// The Peer is a WireGuard peer identified by a public key
|
||||||
|
type Peer struct {
|
||||||
|
// ID is an internal ID of the peer
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
// AccountID is a reference to Account that this object belongs
|
||||||
|
AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"`
|
||||||
|
// WireGuard public key
|
||||||
|
Key string `gorm:"index"`
|
||||||
|
// A setup key this peer was registered with
|
||||||
|
SetupKey string
|
||||||
|
// IP address of the Peer
|
||||||
|
IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"`
|
||||||
|
// Meta is a Peer system meta data
|
||||||
|
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||||
|
// Name is peer's name (machine name)
|
||||||
|
Name string
|
||||||
|
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||||
|
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||||
|
DNSLabel string
|
||||||
|
// Status peer's management connection status
|
||||||
|
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
||||||
|
// The user ID that registered the peer
|
||||||
|
UserID string
|
||||||
|
// SSHKey is a public SSH key of the peer
|
||||||
|
SSHKey string
|
||||||
|
// SSHEnabled indicates whether SSH server is enabled on the peer
|
||||||
|
SSHEnabled bool
|
||||||
|
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
||||||
|
// Works with LastLogin
|
||||||
|
LoginExpirationEnabled bool
|
||||||
|
// LastLogin the time when peer performed last login operation
|
||||||
|
LastLogin time.Time
|
||||||
|
// Indicate ephemeral peer attribute
|
||||||
|
Ephemeral bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerSystemMeta is a metadata of a Peer machine system
|
||||||
|
type PeerSystemMeta struct {
|
||||||
|
Hostname string
|
||||||
|
GoOS string
|
||||||
|
Kernel string
|
||||||
|
Core string
|
||||||
|
Platform string
|
||||||
|
OS string
|
||||||
|
WtVersion string
|
||||||
|
UIVersion string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
|
||||||
|
return p.Hostname == other.Hostname &&
|
||||||
|
p.GoOS == other.GoOS &&
|
||||||
|
p.Kernel == other.Kernel &&
|
||||||
|
p.Core == other.Core &&
|
||||||
|
p.Platform == other.Platform &&
|
||||||
|
p.OS == other.OS &&
|
||||||
|
p.WtVersion == other.WtVersion &&
|
||||||
|
p.UIVersion == other.UIVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user.
|
||||||
|
func (p *Peer) AddedWithSSOLogin() bool {
|
||||||
|
return p.UserID != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy copies Peer object
|
||||||
|
func (p *Peer) Copy() *Peer {
|
||||||
|
peerStatus := p.Status
|
||||||
|
if peerStatus != nil {
|
||||||
|
peerStatus = p.Status.Copy()
|
||||||
|
}
|
||||||
|
return &Peer{
|
||||||
|
ID: p.ID,
|
||||||
|
AccountID: p.AccountID,
|
||||||
|
Key: p.Key,
|
||||||
|
SetupKey: p.SetupKey,
|
||||||
|
IP: p.IP,
|
||||||
|
Meta: p.Meta,
|
||||||
|
Name: p.Name,
|
||||||
|
DNSLabel: p.DNSLabel,
|
||||||
|
Status: peerStatus,
|
||||||
|
UserID: p.UserID,
|
||||||
|
SSHKey: p.SSHKey,
|
||||||
|
SSHEnabled: p.SSHEnabled,
|
||||||
|
LoginExpirationEnabled: p.LoginExpirationEnabled,
|
||||||
|
LastLogin: p.LastLogin,
|
||||||
|
Ephemeral: p.Ephemeral,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMetaIfNew updates peer's system metadata if new information is provided
|
||||||
|
// returns true if meta was updated, false otherwise
|
||||||
|
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool {
|
||||||
|
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
|
||||||
|
if meta.UIVersion == "" {
|
||||||
|
meta.UIVersion = p.Meta.UIVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Meta.isEqual(meta) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
p.Meta = meta
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkLoginExpired marks peer's status expired or not
|
||||||
|
func (p *Peer) MarkLoginExpired(expired bool) {
|
||||||
|
newStatus := p.Status.Copy()
|
||||||
|
newStatus.LoginExpired = expired
|
||||||
|
if expired {
|
||||||
|
newStatus.Connected = false
|
||||||
|
}
|
||||||
|
p.Status = newStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginExpired indicates whether the peer's login has expired or not.
|
||||||
|
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
|
||||||
|
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
|
||||||
|
// Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
|
||||||
|
// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
|
||||||
|
// Only peers added by interactive SSO login can be expired.
|
||||||
|
func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) {
|
||||||
|
if !p.AddedWithSSOLogin() || !p.LoginExpirationEnabled {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
expiresAt := p.LastLogin.Add(expiresIn)
|
||||||
|
now := time.Now()
|
||||||
|
timeLeft := expiresAt.Sub(now)
|
||||||
|
return timeLeft <= 0, timeLeft
|
||||||
|
}
|
||||||
|
|
||||||
|
// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain
|
||||||
|
func (p *Peer) FQDN(dnsDomain string) string {
|
||||||
|
if dnsDomain == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventMeta returns activity event meta related to the peer
|
||||||
|
func (p *Peer) EventMeta(dnsDomain string) map[string]any {
|
||||||
|
return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy PeerStatus
|
||||||
|
func (p *PeerStatus) Copy() *PeerStatus {
|
||||||
|
return &PeerStatus{
|
||||||
|
LastSeen: p.LastSeen,
|
||||||
|
Connected: p.Connected,
|
||||||
|
LoginExpired: p.LoginExpired,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLastLogin and set login expired false
|
||||||
|
func (p *Peer) UpdateLastLogin() {
|
||||||
|
p.LastLogin = time.Now().UTC()
|
||||||
|
newStatus := p.Status.Copy()
|
||||||
|
newStatus.LoginExpired = false
|
||||||
|
p.Status = newStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) CheckAndUpdatePeerSSHKey(newSSHKey string) bool {
|
||||||
|
if len(newSSHKey) == 0 {
|
||||||
|
log.Debugf("no new SSH key provided for peer %s, skipping update", p.ID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.SSHKey == newSSHKey {
|
||||||
|
log.Debugf("same SSH key provided for peer %s, skipping update", p.ID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
p.SSHKey = newSSHKey
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
package peers
|
||||||
|
|
||||||
|
type PeerRepository interface {
|
||||||
|
findPeerByPubKey(pubKey string) (Peer, error)
|
||||||
|
updatePeer(peer Peer) error
|
||||||
|
}
|
||||||
@@ -0,0 +1,214 @@
|
|||||||
|
package peers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/management_refactor/server/accounts"
|
||||||
|
"github.com/netbirdio/netbird/management/server/management_refactor/server/events"
|
||||||
|
"github.com/netbirdio/netbird/management/server/management_refactor/server/users"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PeerLogin used as a data object between the gRPC API and AccountManager on Login request.
|
||||||
|
type PeerLogin struct {
|
||||||
|
// WireGuardPubKey is a peers WireGuard public key
|
||||||
|
WireGuardPubKey string
|
||||||
|
// SSHKey is a peer's ssh key. Can be empty (e.g., old version do not provide it, or this feature is disabled)
|
||||||
|
SSHKey string
|
||||||
|
// Meta is the system information passed by peer, must be always present.
|
||||||
|
Meta PeerSystemMeta
|
||||||
|
// UserID indicates that JWT was used to log in, and it was valid. Can be empty when SetupKey is used or auth is not required.
|
||||||
|
UserID string
|
||||||
|
// AccountID indicates that JWT was used to log in, and it was valid. Can be empty when SetupKey is used or auth is not required.
|
||||||
|
AccountID string
|
||||||
|
// SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required.
|
||||||
|
SetupKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeerStatus struct {
|
||||||
|
// LastSeen is the last time peer was connected to the management service
|
||||||
|
LastSeen time.Time
|
||||||
|
// Connected indicates whether peer is connected to the management service or not
|
||||||
|
Connected bool
|
||||||
|
// LoginExpired
|
||||||
|
LoginExpired bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerSync used as a data object between the gRPC API and AccountManager on Sync request.
|
||||||
|
type PeerSync struct {
|
||||||
|
// WireGuardPubKey is a peers WireGuard public key
|
||||||
|
WireGuardPubKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeersManager interface {
|
||||||
|
LoginPeer(login PeerLogin) (*Peer, *accounts.NetworkMap, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type DefaultPeersManager struct {
|
||||||
|
repository PeerRepository
|
||||||
|
userManager users.UserManager
|
||||||
|
accountManager accounts.AccountManager
|
||||||
|
eventsManager events.EventsManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginPeer logs in or registers a peer.
|
||||||
|
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
|
||||||
|
func (pm *DefaultPeersManager) LoginPeer(login PeerLogin) (*Peer, *accounts.NetworkMap, error) {
|
||||||
|
|
||||||
|
peer, err := pm.repository.findPeerByPubKey(login.WireGuardPubKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.AddedWithSSOLogin() {
|
||||||
|
user, err := pm.userManager.GetUser(peer.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if user.IsBlocked() {
|
||||||
|
return nil, nil, status.Errorf(status.PermissionDenied, "user is blocked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := pm.accountManager.GetAccount(peer.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// this flag prevents unnecessary calls to the persistent store.
|
||||||
|
shouldStorePeer := false
|
||||||
|
updateRemotePeers := false
|
||||||
|
if peerLoginExpired(peer, account) {
|
||||||
|
err = checkAuth(login.UserID, peer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
// If peer was expired before and if it reached this point, it is re-authenticated.
|
||||||
|
// UserID is present, meaning that JWT validation passed successfully in the API layer.
|
||||||
|
peer.UpdateLastLogin()
|
||||||
|
updateRemotePeers = true
|
||||||
|
shouldStorePeer = true
|
||||||
|
|
||||||
|
pm.eventsManager.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(pm.accountManager.GetDNSDomain()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.UpdateMetaIfNew(login.Meta) {
|
||||||
|
shouldStorePeer = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.CheckAndUpdatePeerSSHKey(login.SSHKey) {
|
||||||
|
shouldStorePeer = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldStorePeer {
|
||||||
|
err := pm.repository.updatePeer(peer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if updateRemotePeers {
|
||||||
|
am.updateAccountPeers(account)
|
||||||
|
}
|
||||||
|
return peer, account.GetPeerNetworkMap(peer.ID, pm.accountManager.GetDNSDomain()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||||
|
func (pm *DefaultPeersManager) SyncPeer(sync PeerSync) (*Peer, *accounts.NetworkMap, error) {
|
||||||
|
// we found the peer, and we follow a normal login flow
|
||||||
|
// unlock := am.Store.AcquireAccountLock(account.Id)
|
||||||
|
// defer unlock()
|
||||||
|
|
||||||
|
peer, err := pm.repository.findPeerByPubKey(sync.WireGuardPubKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.AddedWithSSOLogin() {
|
||||||
|
user, err := pm.userManager.GetUser(peer.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if user.IsBlocked() {
|
||||||
|
return nil, nil, status.Errorf(status.PermissionDenied, "user is blocked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := pm.accountManager.GetAccount(peer.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerLoginExpired(peer, account) {
|
||||||
|
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &peer, account.GetPeerNetworkMap(peer.ID, pm.accountManager.GetDNSDomain()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *DefaultPeersManager) GetNetworkMap(peerID string, dnsDomain string) (*accounts.NetworkMap, error) {
|
||||||
|
aclPeers, firewallRules := a.getPeerConnectionResources(peerID)
|
||||||
|
// exclude expired peers
|
||||||
|
var peersToConnect []*Peer
|
||||||
|
var expiredPeers []*Peer
|
||||||
|
for _, p := range aclPeers {
|
||||||
|
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
|
||||||
|
if a.Settings.PeerLoginExpirationEnabled && expired {
|
||||||
|
expiredPeers = append(expiredPeers, p)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peersToConnect = append(peersToConnect, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
routesUpdate := a.getRoutesToSync(peerID, peersToConnect)
|
||||||
|
|
||||||
|
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
|
||||||
|
dnsUpdate := nbdns.Config{
|
||||||
|
ServiceEnable: dnsManagementStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
if dnsManagementStatus {
|
||||||
|
var zones []nbdns.CustomZone
|
||||||
|
peersCustomZone := getPeersCustomZone(a, dnsDomain)
|
||||||
|
if peersCustomZone.Domain != "" {
|
||||||
|
zones = append(zones, peersCustomZone)
|
||||||
|
}
|
||||||
|
dnsUpdate.CustomZones = zones
|
||||||
|
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &NetworkMap{
|
||||||
|
Peers: peersToConnect,
|
||||||
|
Network: a.Network.Copy(),
|
||||||
|
Routes: routesUpdate,
|
||||||
|
DNSConfig: dnsUpdate,
|
||||||
|
OfflinePeers: expiredPeers,
|
||||||
|
FirewallRules: firewallRules,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerLoginExpired(peer Peer, account accounts.Account) bool {
|
||||||
|
expired, expiresIn := peer.LoginExpired(account.Settings.PeerLoginExpiration)
|
||||||
|
expired = account.Settings.PeerLoginExpirationEnabled && expired
|
||||||
|
if expired || peer.Status.LoginExpired {
|
||||||
|
log.Debugf("peer's %s login expired %v ago", peer.ID, expiresIn)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkAuth(loginUserID string, peer Peer) error {
|
||||||
|
if loginUserID == "" {
|
||||||
|
// absence of a user ID indicates that JWT wasn't provided.
|
||||||
|
return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||||
|
}
|
||||||
|
if peer.UserID != loginUserID {
|
||||||
|
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
|
||||||
|
return status.Errorf(status.Unauthenticated, "can't login")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,153 @@
|
|||||||
|
package peers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
)
|
||||||
|
|
||||||
|
const channelBufferSize = 100
|
||||||
|
|
||||||
|
type UpdateMessage struct {
|
||||||
|
Update *proto.SyncResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeersUpdateManager struct {
|
||||||
|
// PeerChannels is an update channel indexed by Peer.ID
|
||||||
|
PeerChannels map[string]chan *UpdateMessage
|
||||||
|
// channelsMux keeps the mutex to access PeerChannels
|
||||||
|
channelsMux *sync.Mutex
|
||||||
|
// metrics provides method to collect application metrics
|
||||||
|
metrics telemetry.AppMetrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||||
|
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||||
|
return &PeersUpdateManager{
|
||||||
|
PeerChannels: make(map[string]chan *UpdateMessage),
|
||||||
|
channelsMux: &sync.Mutex{},
|
||||||
|
metrics: metrics,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendUpdate sends update message to the peer's channel
|
||||||
|
func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) {
|
||||||
|
start := time.Now()
|
||||||
|
var found, dropped bool
|
||||||
|
|
||||||
|
p.channelsMux.Lock()
|
||||||
|
defer func() {
|
||||||
|
p.channelsMux.Unlock()
|
||||||
|
if p.metrics != nil {
|
||||||
|
p.metrics.UpdateChannelMetrics().CountSendUpdateDuration(time.Since(start), found, dropped)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if channel, ok := p.PeerChannels[peerID]; ok {
|
||||||
|
found = true
|
||||||
|
select {
|
||||||
|
case channel <- update:
|
||||||
|
log.Debugf("update was sent to channel for peer %s", peerID)
|
||||||
|
default:
|
||||||
|
dropped = true
|
||||||
|
log.Warnf("channel for peer %s is %d full", peerID, len(channel))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debugf("peer %s has no channel", peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
||||||
|
func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
closed := false
|
||||||
|
|
||||||
|
p.channelsMux.Lock()
|
||||||
|
defer func() {
|
||||||
|
p.channelsMux.Unlock()
|
||||||
|
if p.metrics != nil {
|
||||||
|
p.metrics.UpdateChannelMetrics().CountCreateChannelDuration(time.Since(start), closed)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if channel, ok := p.PeerChannels[peerID]; ok {
|
||||||
|
closed = true
|
||||||
|
delete(p.PeerChannels, peerID)
|
||||||
|
close(channel)
|
||||||
|
}
|
||||||
|
// mbragin: todo shouldn't it be more? or configurable?
|
||||||
|
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||||
|
p.PeerChannels[peerID] = channel
|
||||||
|
|
||||||
|
log.Debugf("opened updates channel for a peer %s", peerID)
|
||||||
|
|
||||||
|
return channel
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PeersUpdateManager) closeChannel(peerID string) {
|
||||||
|
if channel, ok := p.PeerChannels[peerID]; ok {
|
||||||
|
delete(p.PeerChannels, peerID)
|
||||||
|
close(channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("closed updates channel of a peer %s", peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseChannels closes updates channel for each given peer
|
||||||
|
func (p *PeersUpdateManager) CloseChannels(peerIDs []string) {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
p.channelsMux.Lock()
|
||||||
|
defer func() {
|
||||||
|
p.channelsMux.Unlock()
|
||||||
|
if p.metrics != nil {
|
||||||
|
p.metrics.UpdateChannelMetrics().CountCloseChannelsDuration(time.Since(start), len(peerIDs))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, id := range peerIDs {
|
||||||
|
p.closeChannel(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseChannel closes updates channel of a given peer
|
||||||
|
func (p *PeersUpdateManager) CloseChannel(peerID string) {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
p.channelsMux.Lock()
|
||||||
|
defer func() {
|
||||||
|
p.channelsMux.Unlock()
|
||||||
|
if p.metrics != nil {
|
||||||
|
p.metrics.UpdateChannelMetrics().CountCloseChannelDuration(time.Since(start))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
p.closeChannel(peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllConnectedPeers returns a copy of the connected peers map
|
||||||
|
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
p.channelsMux.Lock()
|
||||||
|
|
||||||
|
m := make(map[string]struct{})
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
p.channelsMux.Unlock()
|
||||||
|
if p.metrics != nil {
|
||||||
|
p.metrics.UpdateChannelMetrics().CountGetAllConnectedPeersDuration(time.Since(start), len(m))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for ID := range p.PeerChannels {
|
||||||
|
m[ID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
package routes
|
||||||
115
management/server/management_refactor/server/scheduler.go
Normal file
115
management/server/management_refactor/server/scheduler.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Scheduler is an interface which implementations can schedule and cancel jobs
|
||||||
|
type Scheduler interface {
|
||||||
|
Cancel(IDs []string)
|
||||||
|
Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockScheduler is a mock implementation of Scheduler
|
||||||
|
type MockScheduler struct {
|
||||||
|
CancelFunc func(IDs []string)
|
||||||
|
ScheduleFunc func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel mocks the Cancel function of the Scheduler interface
|
||||||
|
func (mock *MockScheduler) Cancel(IDs []string) {
|
||||||
|
if mock.CancelFunc != nil {
|
||||||
|
mock.CancelFunc(IDs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("MockScheduler doesn't have Cancel function defined ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schedule mocks the Schedule function of the Scheduler interface
|
||||||
|
func (mock *MockScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||||
|
if mock.ScheduleFunc != nil {
|
||||||
|
mock.ScheduleFunc(in, ID, job)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("MockScheduler doesn't have Schedule function defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them.
|
||||||
|
type DefaultScheduler struct {
|
||||||
|
// jobs map holds cancellation channels indexed by the job ID
|
||||||
|
jobs map[string]chan struct{}
|
||||||
|
mu *sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultScheduler creates an instance of a DefaultScheduler
|
||||||
|
func NewDefaultScheduler() *DefaultScheduler {
|
||||||
|
return &DefaultScheduler{
|
||||||
|
jobs: make(map[string]chan struct{}),
|
||||||
|
mu: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wm *DefaultScheduler) cancel(ID string) bool {
|
||||||
|
cancel, ok := wm.jobs[ID]
|
||||||
|
if ok {
|
||||||
|
delete(wm.jobs, ID)
|
||||||
|
select {
|
||||||
|
case cancel <- struct{}{}:
|
||||||
|
log.Debugf("cancelled scheduled job %s", ID)
|
||||||
|
default:
|
||||||
|
log.Warnf("couldn't cancel job %s because there was no routine listening on the cancel event", ID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel cancels the scheduled job by ID if present.
|
||||||
|
// If job wasn't found the function returns false.
|
||||||
|
func (wm *DefaultScheduler) Cancel(IDs []string) {
|
||||||
|
wm.mu.Lock()
|
||||||
|
defer wm.mu.Unlock()
|
||||||
|
|
||||||
|
for _, id := range IDs {
|
||||||
|
wm.cancel(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schedule a job to run in some time in the future. If job returns true then it will be scheduled one more time.
|
||||||
|
// If job with the provided ID already exists, a new one won't be scheduled.
|
||||||
|
func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||||
|
wm.mu.Lock()
|
||||||
|
defer wm.mu.Unlock()
|
||||||
|
cancel := make(chan struct{})
|
||||||
|
if _, ok := wm.jobs[ID]; ok {
|
||||||
|
log.Debugf("couldn't schedule a job %s because it already exists. There are %d total jobs scheduled.",
|
||||||
|
ID, len(wm.jobs))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wm.jobs[ID] = cancel
|
||||||
|
log.Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs))
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-time.After(in):
|
||||||
|
log.Debugf("time to do a scheduled job %s", ID)
|
||||||
|
runIn, reschedule := job()
|
||||||
|
wm.mu.Lock()
|
||||||
|
defer wm.mu.Unlock()
|
||||||
|
delete(wm.jobs, ID)
|
||||||
|
if reschedule {
|
||||||
|
go wm.Schedule(runIn, ID, job)
|
||||||
|
}
|
||||||
|
case <-cancel:
|
||||||
|
log.Debugf("stopped scheduled job %s ", ID)
|
||||||
|
wm.mu.Lock()
|
||||||
|
defer wm.mu.Unlock()
|
||||||
|
delete(wm.jobs, ID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
package setupkey
|
||||||
@@ -0,0 +1,459 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SqliteStore represents an account storage backed by a Sqlite DB persisted to disk
|
||||||
|
type SqliteStore struct {
|
||||||
|
db *gorm.DB
|
||||||
|
storeFile string
|
||||||
|
accountLocks sync.Map
|
||||||
|
globalAccountLock sync.Mutex
|
||||||
|
metrics telemetry.AppMetrics
|
||||||
|
installationPK int
|
||||||
|
}
|
||||||
|
|
||||||
|
type installation struct {
|
||||||
|
ID uint `gorm:"primaryKey"`
|
||||||
|
InstallationIDValue string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSqliteStore restores a store from the file located in the datadir
|
||||||
|
func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
|
||||||
|
storeStr := "store.db?cache=shared"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||||
|
storeStr = "store.db"
|
||||||
|
}
|
||||||
|
|
||||||
|
file := filepath.Join(dataDir, storeStr)
|
||||||
|
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
PrepareStmt: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sql, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conns := runtime.NumCPU()
|
||||||
|
sql.SetMaxOpenConns(conns) // TODO: make it configurable
|
||||||
|
|
||||||
|
err = db.AutoMigrate(
|
||||||
|
&SetupKey{}, &Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Rule{},
|
||||||
|
&Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||||
|
&installation{},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir
|
||||||
|
func NewSqliteStoreFromFileStore(filestore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
|
||||||
|
store, err := NewSqliteStore(dataDir, metrics)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = store.SaveInstallationID(filestore.InstallationID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, account := range filestore.GetAllAccounts() {
|
||||||
|
err := store.SaveAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return store, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
|
||||||
|
func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
|
||||||
|
log.Debugf("acquiring global lock")
|
||||||
|
start := time.Now()
|
||||||
|
s.globalAccountLock.Lock()
|
||||||
|
|
||||||
|
unlock = func() {
|
||||||
|
s.globalAccountLock.Unlock()
|
||||||
|
log.Debugf("released global lock in %v", time.Since(start))
|
||||||
|
}
|
||||||
|
|
||||||
|
took := time.Since(start)
|
||||||
|
log.Debugf("took %v to acquire global lock", took)
|
||||||
|
if s.metrics != nil {
|
||||||
|
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
|
||||||
|
}
|
||||||
|
|
||||||
|
return unlock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
|
||||||
|
log.Debugf("acquiring lock for account %s", accountID)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||||
|
mtx := value.(*sync.Mutex)
|
||||||
|
mtx.Lock()
|
||||||
|
|
||||||
|
unlock = func() {
|
||||||
|
mtx.Unlock()
|
||||||
|
log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
|
||||||
|
}
|
||||||
|
|
||||||
|
return unlock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) SaveAccount(account *Account) error {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
for _, key := range account.SetupKeys {
|
||||||
|
account.SetupKeysG = append(account.SetupKeysG, *key)
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, peer := range account.Peers {
|
||||||
|
peer.ID = id
|
||||||
|
account.PeersG = append(account.PeersG, *peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, user := range account.Users {
|
||||||
|
user.Id = id
|
||||||
|
for id, pat := range user.PATs {
|
||||||
|
pat.ID = id
|
||||||
|
user.PATsG = append(user.PATsG, *pat)
|
||||||
|
}
|
||||||
|
account.UsersG = append(account.UsersG, *user)
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, group := range account.Groups {
|
||||||
|
group.ID = id
|
||||||
|
account.GroupsG = append(account.GroupsG, *group)
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, rule := range account.Rules {
|
||||||
|
rule.ID = id
|
||||||
|
account.RulesG = append(account.RulesG, *rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, route := range account.Routes {
|
||||||
|
route.ID = id
|
||||||
|
account.RoutesG = append(account.RoutesG, *route)
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, ns := range account.NameServerGroups {
|
||||||
|
ns.ID = id
|
||||||
|
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
result = tx.Select(clause.Associations).Delete(account)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
result = tx.
|
||||||
|
Session(&gorm.Session{FullSaveAssociations: true}).
|
||||||
|
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
took := time.Since(start)
|
||||||
|
if s.metrics != nil {
|
||||||
|
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
||||||
|
}
|
||||||
|
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) SaveInstallationID(ID string) error {
|
||||||
|
installation := installation{InstallationIDValue: ID}
|
||||||
|
installation.ID = uint(s.installationPK)
|
||||||
|
|
||||||
|
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) GetInstallationID() string {
|
||||||
|
var installation installation
|
||||||
|
|
||||||
|
if result := s.db.First(&installation, "id = ?", s.installationPK); result.Error != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return installation.InstallationIDValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus PeerStatus) error {
|
||||||
|
var peer Peer
|
||||||
|
|
||||||
|
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID)
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.NotFound, "peer %s not found", peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.Status = &peerStatus
|
||||||
|
|
||||||
|
return s.db.Save(peer).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteHashedPAT2TokenIDIndex is noop in Sqlite
|
||||||
|
func (s *SqliteStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteTokenID2UserIDIndex is noop in Sqlite
|
||||||
|
func (s *SqliteStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||||
|
var account Account
|
||||||
|
|
||||||
|
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||||
|
strings.ToLower(domain), true, PrivateCategory)
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: rework to not call GetAccount
|
||||||
|
return s.GetAccount(account.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||||
|
var key SetupKey
|
||||||
|
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if key.AccountID == "" {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.GetAccount(key.AccountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error) {
|
||||||
|
var token PersonalAccessToken
|
||||||
|
result := s.db.First(&token, "hashed_token = ?", hashedToken)
|
||||||
|
if result.Error != nil {
|
||||||
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return token.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) {
|
||||||
|
var token PersonalAccessToken
|
||||||
|
result := s.db.First(&token, "id = ?", tokenID)
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.UserID == "" {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
var user User
|
||||||
|
result = s.db.Preload("PATsG").First(&user, "id = ?", token.UserID)
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG))
|
||||||
|
for _, pat := range user.PATsG {
|
||||||
|
user.PATs[pat.ID] = pat.Copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) GetAllAccounts() (all []*Account) {
|
||||||
|
var accounts []Account
|
||||||
|
result := s.db.Find(&accounts)
|
||||||
|
if result.Error != nil {
|
||||||
|
return all
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, account := range accounts {
|
||||||
|
if acc, err := s.GetAccount(account.Id); err == nil {
|
||||||
|
all = append(all, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return all
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
|
||||||
|
var account Account
|
||||||
|
|
||||||
|
result := s.db.Model(&account).
|
||||||
|
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
||||||
|
Preload(clause.Associations).
|
||||||
|
First(&account, "id = ?", accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
|
||||||
|
for i, policy := range account.Policies {
|
||||||
|
var rules []*PolicyRule
|
||||||
|
err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
account.Policies[i].Rules = rules
|
||||||
|
}
|
||||||
|
|
||||||
|
account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG))
|
||||||
|
for _, key := range account.SetupKeysG {
|
||||||
|
account.SetupKeys[key.Key] = key.Copy()
|
||||||
|
}
|
||||||
|
account.SetupKeysG = nil
|
||||||
|
|
||||||
|
account.Peers = make(map[string]*Peer, len(account.PeersG))
|
||||||
|
for _, peer := range account.PeersG {
|
||||||
|
account.Peers[peer.ID] = peer.Copy()
|
||||||
|
}
|
||||||
|
account.PeersG = nil
|
||||||
|
|
||||||
|
account.Users = make(map[string]*User, len(account.UsersG))
|
||||||
|
for _, user := range account.UsersG {
|
||||||
|
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs))
|
||||||
|
for _, pat := range user.PATsG {
|
||||||
|
user.PATs[pat.ID] = pat.Copy()
|
||||||
|
}
|
||||||
|
account.Users[user.Id] = user.Copy()
|
||||||
|
}
|
||||||
|
account.UsersG = nil
|
||||||
|
|
||||||
|
account.Groups = make(map[string]*Group, len(account.GroupsG))
|
||||||
|
for _, group := range account.GroupsG {
|
||||||
|
account.Groups[group.ID] = group.Copy()
|
||||||
|
}
|
||||||
|
account.GroupsG = nil
|
||||||
|
|
||||||
|
account.Rules = make(map[string]*Rule, len(account.RulesG))
|
||||||
|
for _, rule := range account.RulesG {
|
||||||
|
account.Rules[rule.ID] = rule.Copy()
|
||||||
|
}
|
||||||
|
account.RulesG = nil
|
||||||
|
|
||||||
|
account.Routes = make(map[string]*route.Route, len(account.RoutesG))
|
||||||
|
for _, route := range account.RoutesG {
|
||||||
|
account.Routes[route.ID] = route.Copy()
|
||||||
|
}
|
||||||
|
account.RoutesG = nil
|
||||||
|
|
||||||
|
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
|
||||||
|
for _, ns := range account.NameServerGroupsG {
|
||||||
|
account.NameServerGroups[ns.ID] = ns.Copy()
|
||||||
|
}
|
||||||
|
account.NameServerGroupsG = nil
|
||||||
|
|
||||||
|
return &account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) {
|
||||||
|
var user User
|
||||||
|
result := s.db.Select("account_id").First(&user, "id = ?", userID)
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.AccountID == "" {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.GetAccount(user.AccountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||||
|
var peer Peer
|
||||||
|
result := s.db.Select("account_id").First(&peer, "id = ?", peerID)
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.AccountID == "" {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.GetAccount(peer.AccountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||||
|
var peer Peer
|
||||||
|
|
||||||
|
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.AccountID == "" {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.GetAccount(peer.AccountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||||
|
func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||||
|
var user User
|
||||||
|
|
||||||
|
result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.NotFound, "user %s not found", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
user.LastLogin = lastLogin
|
||||||
|
|
||||||
|
return s.db.Save(user).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close is noop in Sqlite
|
||||||
|
func (s *SqliteStore) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStoreEngine returns SqliteStoreEngine
|
||||||
|
func (s *SqliteStore) GetStoreEngine() StoreEngine {
|
||||||
|
return SqliteStoreEngine
|
||||||
|
}
|
||||||
98
management/server/management_refactor/server/store/store.go
Normal file
98
management/server/management_refactor/server/store/store.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Store interface {
|
||||||
|
GetAllAccounts() []*Account
|
||||||
|
GetAccount(accountID string) (*Account, error)
|
||||||
|
GetAccountByUser(userID string) (*Account, error)
|
||||||
|
GetAccountByPeerPubKey(peerKey string) (*Account, error)
|
||||||
|
GetAccountByPeerID(peerID string) (*Account, error)
|
||||||
|
GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later
|
||||||
|
GetAccountByPrivateDomain(domain string) (*Account, error)
|
||||||
|
GetTokenIDByHashedToken(secret string) (string, error)
|
||||||
|
GetUserByTokenID(tokenID string) (*User, error)
|
||||||
|
SaveAccount(account *Account) error
|
||||||
|
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||||
|
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||||
|
GetInstallationID() string
|
||||||
|
SaveInstallationID(ID string) error
|
||||||
|
// AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock
|
||||||
|
AcquireAccountLock(accountID string) func()
|
||||||
|
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
||||||
|
AcquireGlobalLock() func()
|
||||||
|
SavePeerStatus(accountID, peerID string, status PeerStatus) error
|
||||||
|
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
|
||||||
|
// Close should close the store persisting all unsaved data.
|
||||||
|
Close() error
|
||||||
|
// GetStoreEngine should return StoreEngine of the current store implementation.
|
||||||
|
// This is also a method of metrics.DataSource interface.
|
||||||
|
GetStoreEngine() StoreEngine
|
||||||
|
}
|
||||||
|
|
||||||
|
type StoreEngine string
|
||||||
|
|
||||||
|
const (
|
||||||
|
FileStoreEngine StoreEngine = "jsonfile"
|
||||||
|
SqliteStoreEngine StoreEngine = "sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getStoreEngineFromEnv() StoreEngine {
|
||||||
|
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise rely on the config file.
|
||||||
|
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
|
||||||
|
if !ok {
|
||||||
|
return FileStoreEngine
|
||||||
|
}
|
||||||
|
|
||||||
|
value := StoreEngine(strings.ToLower(kind))
|
||||||
|
|
||||||
|
if value == FileStoreEngine || value == SqliteStoreEngine {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
return FileStoreEngine
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
|
||||||
|
if kind == "" {
|
||||||
|
// fallback to env. Normally this only should be used from tests
|
||||||
|
kind = getStoreEngineFromEnv()
|
||||||
|
}
|
||||||
|
switch kind {
|
||||||
|
case FileStoreEngine:
|
||||||
|
log.Info("using JSON file store engine")
|
||||||
|
return NewFileStore(dataDir, metrics)
|
||||||
|
case SqliteStoreEngine:
|
||||||
|
log.Info("using SQLite store engine")
|
||||||
|
return NewSqliteStore(dataDir, metrics)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported kind of store %s", kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, error) {
|
||||||
|
fstore, err := NewFileStore(dataDir, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
kind := getStoreEngineFromEnv()
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case FileStoreEngine:
|
||||||
|
return fstore, nil
|
||||||
|
case SqliteStoreEngine:
|
||||||
|
return NewSqliteStoreFromFileStore(fstore, dataDir, metrics)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported store engine %s", kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
125
management/server/management_refactor/server/turncredentials.go
Normal file
125
management/server/management_refactor/server/turncredentials.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TURNCredentialsManager used to manage TURN credentials
|
||||||
|
type TURNCredentialsManager interface {
|
||||||
|
GenerateCredentials() TURNCredentials
|
||||||
|
SetupRefresh(peerKey string)
|
||||||
|
CancelRefresh(peerKey string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
|
||||||
|
type TimeBasedAuthSecretsManager struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
config *TURNConfig
|
||||||
|
updateManager *PeersUpdateManager
|
||||||
|
cancelMap map[string]chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type TURNCredentials struct {
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, config *TURNConfig) *TimeBasedAuthSecretsManager {
|
||||||
|
return &TimeBasedAuthSecretsManager{
|
||||||
|
mux: sync.Mutex{},
|
||||||
|
config: config,
|
||||||
|
updateManager: updateManager,
|
||||||
|
cancelMap: make(map[string]chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateCredentials generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
|
||||||
|
func (m *TimeBasedAuthSecretsManager) GenerateCredentials() TURNCredentials {
|
||||||
|
mac := hmac.New(sha1.New, []byte(m.config.Secret))
|
||||||
|
|
||||||
|
timeAuth := time.Now().Add(m.config.CredentialsTTL.Duration).Unix()
|
||||||
|
|
||||||
|
username := fmt.Sprint(timeAuth)
|
||||||
|
|
||||||
|
_, err := mac.Write([]byte(username))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorln("Generating turn password failed with error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bytePassword := mac.Sum(nil)
|
||||||
|
password := base64.StdEncoding.EncodeToString(bytePassword)
|
||||||
|
|
||||||
|
return TURNCredentials{
|
||||||
|
Username: username,
|
||||||
|
Password: password,
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
|
||||||
|
if channel, ok := m.cancelMap[peerID]; ok {
|
||||||
|
close(channel)
|
||||||
|
delete(m.cancelMap, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelRefresh cancels scheduled peer credentials refresh
|
||||||
|
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
m.cancel(peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
|
||||||
|
// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
|
||||||
|
func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
m.cancel(peerID)
|
||||||
|
cancel := make(chan struct{}, 1)
|
||||||
|
m.cancelMap[peerID] = cancel
|
||||||
|
log.Debugf("starting turn refresh for %s", peerID)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
|
||||||
|
ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-cancel:
|
||||||
|
log.Debugf("stopping turn refresh for %s", peerID)
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
c := m.GenerateCredentials()
|
||||||
|
var turns []*proto.ProtectedHostConfig
|
||||||
|
for _, host := range m.config.Turns {
|
||||||
|
turns = append(turns, &proto.ProtectedHostConfig{
|
||||||
|
HostConfig: &proto.HostConfig{
|
||||||
|
Uri: host.URI,
|
||||||
|
Protocol: ToResponseProto(host.Proto),
|
||||||
|
},
|
||||||
|
User: c.Username,
|
||||||
|
Password: c.Password,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
update := &proto.SyncResponse{
|
||||||
|
WiretrusteeConfig: &proto.WiretrusteeConfig{
|
||||||
|
Turns: turns,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
log.Debugf("sending new TURN credentials to peer %s", peerID)
|
||||||
|
m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
package users
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
b64 "encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"hash/crc32"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
b "github.com/hashicorp/go-secure-stdlib/base62"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/base62"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PATPrefix is the globally used, 4 char prefix for personal access tokens
|
||||||
|
PATPrefix = "nbp_"
|
||||||
|
// PATSecretLength number of characters used for the secret inside the token
|
||||||
|
PATSecretLength = 30
|
||||||
|
// PATChecksumLength number of characters used for the encoded checksum of the secret inside the token
|
||||||
|
PATChecksumLength = 6
|
||||||
|
// PATLength total number of characters used for the token
|
||||||
|
PATLength = 40
|
||||||
|
)
|
||||||
|
|
||||||
|
// PersonalAccessToken holds all information about a PAT including a hashed version of it for verification
|
||||||
|
type PersonalAccessToken struct {
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
// User is a reference to Account that this object belongs
|
||||||
|
UserID string `gorm:"index"`
|
||||||
|
Name string
|
||||||
|
HashedToken string
|
||||||
|
ExpirationDate time.Time
|
||||||
|
// scope could be added in future
|
||||||
|
CreatedBy string
|
||||||
|
CreatedAt time.Time
|
||||||
|
LastUsed time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *PersonalAccessToken) Copy() *PersonalAccessToken {
|
||||||
|
return &PersonalAccessToken{
|
||||||
|
ID: t.ID,
|
||||||
|
Name: t.Name,
|
||||||
|
HashedToken: t.HashedToken,
|
||||||
|
ExpirationDate: t.ExpirationDate,
|
||||||
|
CreatedBy: t.CreatedBy,
|
||||||
|
CreatedAt: t.CreatedAt,
|
||||||
|
LastUsed: t.LastUsed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it
|
||||||
|
type PersonalAccessTokenGenerated struct {
|
||||||
|
PlainToken string
|
||||||
|
PersonalAccessToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
|
||||||
|
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
|
||||||
|
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) {
|
||||||
|
hashedToken, plainToken, err := generateNewToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
currentTime := time.Now()
|
||||||
|
return &PersonalAccessTokenGenerated{
|
||||||
|
PersonalAccessToken: PersonalAccessToken{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Name: name,
|
||||||
|
HashedToken: hashedToken,
|
||||||
|
ExpirationDate: currentTime.AddDate(0, 0, expirationInDays),
|
||||||
|
CreatedBy: createdBy,
|
||||||
|
CreatedAt: currentTime,
|
||||||
|
LastUsed: time.Time{},
|
||||||
|
},
|
||||||
|
PlainToken: plainToken,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateNewToken() (string, string, error) {
|
||||||
|
secret, err := b.Random(PATSecretLength)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
checksum := crc32.ChecksumIEEE([]byte(secret))
|
||||||
|
encodedChecksum := base62.Encode(checksum)
|
||||||
|
paddedChecksum := fmt.Sprintf("%06s", encodedChecksum)
|
||||||
|
plainToken := PATPrefix + secret + paddedChecksum
|
||||||
|
hashedToken := sha256.Sum256([]byte(plainToken))
|
||||||
|
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||||
|
return encodedHashedToken, plainToken, nil
|
||||||
|
}
|
||||||
203
management/server/management_refactor/server/users/user.go
Normal file
203
management/server/management_refactor/server/users/user.go
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
package users
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
UserRoleAdmin UserRole = "admin"
|
||||||
|
UserRoleUser UserRole = "user"
|
||||||
|
UserRoleUnknown UserRole = "unknown"
|
||||||
|
|
||||||
|
UserStatusActive UserStatus = "active"
|
||||||
|
UserStatusDisabled UserStatus = "disabled"
|
||||||
|
UserStatusInvited UserStatus = "invited"
|
||||||
|
|
||||||
|
UserIssuedAPI = "api"
|
||||||
|
UserIssuedIntegration = "integration"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserInfo struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
AutoGroups []string `json:"auto_groups"`
|
||||||
|
Status string `json:"-"`
|
||||||
|
IsServiceUser bool `json:"is_service_user"`
|
||||||
|
IsBlocked bool `json:"is_blocked"`
|
||||||
|
NonDeletable bool `json:"non_deletable"`
|
||||||
|
LastLogin time.Time `json:"last_login"`
|
||||||
|
Issued string `json:"issued"`
|
||||||
|
IntegrationReference IntegrationReference `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown
|
||||||
|
func StrRoleToUserRole(strRole string) UserRole {
|
||||||
|
switch strings.ToLower(strRole) {
|
||||||
|
case "admin":
|
||||||
|
return UserRoleAdmin
|
||||||
|
case "user":
|
||||||
|
return UserRoleUser
|
||||||
|
default:
|
||||||
|
return UserRoleUnknown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserStatus is the status of a User
|
||||||
|
type UserStatus string
|
||||||
|
|
||||||
|
// UserRole is the role of a User
|
||||||
|
type UserRole string
|
||||||
|
|
||||||
|
// IntegrationReference holds the reference to a particular integration
|
||||||
|
type IntegrationReference struct {
|
||||||
|
ID int
|
||||||
|
IntegrationType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ir IntegrationReference) String() string {
|
||||||
|
return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ir IntegrationReference) CacheKey(path ...string) string {
|
||||||
|
if len(path) == 0 {
|
||||||
|
return ir.String()
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// User represents a user of the system
|
||||||
|
type User struct {
|
||||||
|
Id string `gorm:"primaryKey"`
|
||||||
|
// AccountID is a reference to Account that this object belongs
|
||||||
|
AccountID string `json:"-" gorm:"index"`
|
||||||
|
Role UserRole
|
||||||
|
IsServiceUser bool
|
||||||
|
// NonDeletable indicates whether the service user can be deleted
|
||||||
|
NonDeletable bool
|
||||||
|
// ServiceUserName is only set if IsServiceUser is true
|
||||||
|
ServiceUserName string
|
||||||
|
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||||
|
AutoGroups []string `gorm:"serializer:json"`
|
||||||
|
PATs map[string]*PersonalAccessToken `gorm:"-"`
|
||||||
|
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"`
|
||||||
|
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||||
|
Blocked bool
|
||||||
|
// LastLogin is the last time the user logged in to IdP
|
||||||
|
LastLogin time.Time
|
||||||
|
|
||||||
|
// Issued of the user
|
||||||
|
Issued string `gorm:"default:api"`
|
||||||
|
|
||||||
|
IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsBlocked returns true if the user is blocked, false otherwise
|
||||||
|
func (u *User) IsBlocked() bool {
|
||||||
|
return u.Blocked
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
|
||||||
|
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAdmin returns true if the user is an admin, false otherwise
|
||||||
|
func (u *User) IsAdmin() bool {
|
||||||
|
return u.Role == UserRoleAdmin
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToUserInfo converts a User object to a UserInfo object.
|
||||||
|
func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||||
|
autoGroups := u.AutoGroups
|
||||||
|
if autoGroups == nil {
|
||||||
|
autoGroups = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if userData == nil {
|
||||||
|
return &UserInfo{
|
||||||
|
ID: u.Id,
|
||||||
|
Email: "",
|
||||||
|
Name: u.ServiceUserName,
|
||||||
|
Role: string(u.Role),
|
||||||
|
AutoGroups: u.AutoGroups,
|
||||||
|
Status: string(UserStatusActive),
|
||||||
|
IsServiceUser: u.IsServiceUser,
|
||||||
|
IsBlocked: u.Blocked,
|
||||||
|
LastLogin: u.LastLogin,
|
||||||
|
Issued: u.Issued,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
if userData.ID != u.Id {
|
||||||
|
return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
userStatus := UserStatusActive
|
||||||
|
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
|
||||||
|
userStatus = UserStatusInvited
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UserInfo{
|
||||||
|
ID: u.Id,
|
||||||
|
Email: userData.Email,
|
||||||
|
Name: userData.Name,
|
||||||
|
Role: string(u.Role),
|
||||||
|
AutoGroups: autoGroups,
|
||||||
|
Status: string(userStatus),
|
||||||
|
IsServiceUser: u.IsServiceUser,
|
||||||
|
IsBlocked: u.Blocked,
|
||||||
|
LastLogin: u.LastLogin,
|
||||||
|
Issued: u.Issued,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy the user
|
||||||
|
func (u *User) Copy() *User {
|
||||||
|
autoGroups := make([]string, len(u.AutoGroups))
|
||||||
|
copy(autoGroups, u.AutoGroups)
|
||||||
|
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
|
||||||
|
for k, v := range u.PATs {
|
||||||
|
pats[k] = v.Copy()
|
||||||
|
}
|
||||||
|
return &User{
|
||||||
|
Id: u.Id,
|
||||||
|
AccountID: u.AccountID,
|
||||||
|
Role: u.Role,
|
||||||
|
AutoGroups: autoGroups,
|
||||||
|
IsServiceUser: u.IsServiceUser,
|
||||||
|
NonDeletable: u.NonDeletable,
|
||||||
|
ServiceUserName: u.ServiceUserName,
|
||||||
|
PATs: pats,
|
||||||
|
Blocked: u.Blocked,
|
||||||
|
LastLogin: u.LastLogin,
|
||||||
|
Issued: u.Issued,
|
||||||
|
IntegrationReference: u.IntegrationReference,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUser creates a new user
|
||||||
|
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
|
||||||
|
return &User{
|
||||||
|
Id: id,
|
||||||
|
Role: role,
|
||||||
|
IsServiceUser: isServiceUser,
|
||||||
|
NonDeletable: nonDeletable,
|
||||||
|
ServiceUserName: serviceUserName,
|
||||||
|
AutoGroups: autoGroups,
|
||||||
|
Issued: issued,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRegularUser creates a new user with role UserRoleUser
|
||||||
|
func NewRegularUser(id string) *User {
|
||||||
|
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAdminUser creates a new user with role UserRoleAdmin
|
||||||
|
func NewAdminUser(id string) *User {
|
||||||
|
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
|
||||||
|
}
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package users
|
||||||
|
|
||||||
|
type UserManager interface {
|
||||||
|
GetUser(userID string) (User, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type DefaultUserManager struct {
|
||||||
|
repository UserRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserManager(repository UserRepository) *DefaultUserManager {
|
||||||
|
return &DefaultUserManager{
|
||||||
|
repository: repository,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (um *DefaultUserManager) GetUser(userID string) (User, error) {
|
||||||
|
return um.repository.findUserByID(userID)
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
package users
|
||||||
|
|
||||||
|
type UserRepository interface {
|
||||||
|
findUserByID(userID string) (User, error)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user