[management] Add API of new network concept (#3012)

This commit is contained in:
Pascal Fischer
2024-12-11 12:58:45 +01:00
committed by GitHub
parent 9f859a240e
commit 60ee31df3e
92 changed files with 5320 additions and 3562 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,16 @@
package types
// DNSSettings defines dns settings at the account level
type DNSSettings struct {
// DisabledManagementGroups groups whose DNS management is disabled
DisabledManagementGroups []string `gorm:"serializer:json"`
}
// Copy returns a copy of the DNS settings
func (d DNSSettings) Copy() DNSSettings {
settings := DNSSettings{
DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)),
}
copy(settings.DisabledManagementGroups, d.DisabledManagementGroups)
return settings
}

View File

@@ -0,0 +1,129 @@
package types
import (
"context"
"fmt"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
nbroute "github.com/netbirdio/netbird/route"
)
const (
FirewallRuleDirectionIN = 0
FirewallRuleDirectionOUT = 1
)
// FirewallRule is a rule of the firewall.
type FirewallRule struct {
// PeerIP of the peer
PeerIP string
// Direction of the traffic
Direction int
// Action of the traffic
Action string
// Protocol of the traffic
Protocol string
// Port of the traffic
Port string
}
// generateRouteFirewallRules generates a list of firewall rules for a given route.
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
rulesExists := make(map[string]struct{})
rules := make([]*RouteFirewallRule, 0)
sourceRanges := make([]string, 0, len(groupPeers))
for _, peer := range groupPeers {
if peer == nil {
continue
}
sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP))
}
baseRule := RouteFirewallRule{
SourceRanges: sourceRanges,
Action: string(rule.Action),
Destination: route.Network.String(),
Protocol: string(rule.Protocol),
IsDynamic: route.IsDynamic(),
}
// generate rule for port range
if len(rule.Ports) == 0 {
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
} else {
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
}
// TODO: generate IPv6 rules for dynamic routes
return rules
}
// generateRulesForPeer generates rules for a given peer based on ports and port ranges.
func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
rules := make([]*RouteFirewallRule, 0)
ruleIDBase := generateRuleIDBase(rule, baseRule)
if len(rule.Ports) == 0 {
if len(rule.PortRanges) == 0 {
if _, ok := rulesExists[ruleIDBase]; !ok {
rulesExists[ruleIDBase] = struct{}{}
rules = append(rules, &baseRule)
}
} else {
for _, portRange := range rule.PortRanges {
ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End)
if _, ok := rulesExists[ruleID]; !ok {
rulesExists[ruleID] = struct{}{}
pr := baseRule
pr.PortRange = portRange
rules = append(rules, &pr)
}
}
}
return rules
}
return rules
}
// generateRulesWithPorts generates rules when specific ports are provided.
func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
rules := make([]*RouteFirewallRule, 0)
ruleIDBase := generateRuleIDBase(rule, baseRule)
for _, port := range rule.Ports {
ruleID := ruleIDBase + port
if _, ok := rulesExists[ruleID]; ok {
continue
}
rulesExists[ruleID] = struct{}{}
pr := baseRule
p, err := strconv.ParseUint(port, 10, 16)
if err != nil {
log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID)
continue
}
pr.Port = uint16(p)
rules = append(rules, &pr)
}
return rules
}
// generateRuleIDBase generates the base rule ID for checking duplicates.
func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string {
return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(FirewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action
}

View File

@@ -0,0 +1,149 @@
package types
import (
"math/rand"
"net"
"sync"
"time"
"github.com/c-robinson/iplib"
"github.com/rs/xid"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
)
const (
// SubnetSize is a size of the subnet of the global network, e.g. 100.77.0.0/16
SubnetSize = 16
// NetSize is a global network size 100.64.0.0/10
NetSize = 10
// AllowedIPsFormat generates Wireguard AllowedIPs format (e.g. 100.64.30.1/32)
AllowedIPsFormat = "%s/32"
)
type NetworkMap struct {
Peers []*nbpeer.Peer
Network *Network
Routes []*route.Route
DNSConfig nbdns.Config
OfflinePeers []*nbpeer.Peer
FirewallRules []*FirewallRule
RoutesFirewallRules []*RouteFirewallRule
}
type Network struct {
Identifier string `json:"id"`
Net net.IPNet `gorm:"serializer:json"`
Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps.
Serial uint64
Mu sync.Mutex `json:"-" gorm:"-"`
}
// NewNetwork creates a new Network initializing it with a Serial=0
// It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets)
func NewNetwork() *Network {
n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize)
sub, _ := n.Subnet(SubnetSize)
s := rand.NewSource(time.Now().Unix())
r := rand.New(s)
intn := r.Intn(len(sub))
return &Network{
Identifier: xid.New().String(),
Net: sub[intn].IPNet,
Dns: "",
Serial: 0}
}
// IncSerial increments Serial by 1 reflecting that the network state has been changed
func (n *Network) IncSerial() {
n.Mu.Lock()
defer n.Mu.Unlock()
n.Serial++
}
// CurrentSerial returns the Network.Serial of the network (latest state id)
func (n *Network) CurrentSerial() uint64 {
n.Mu.Lock()
defer n.Mu.Unlock()
return n.Serial
}
func (n *Network) Copy() *Network {
return &Network{
Identifier: n.Identifier,
Net: n.Net,
Dns: n.Dns,
Serial: n.Serial,
}
}
// AllocatePeerIP pics an available IP from an net.IPNet.
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
takenIPMap := make(map[string]struct{})
takenIPMap[ipNet.IP.String()] = struct{}{}
for _, ip := range takenIps {
takenIPMap[ip.String()] = struct{}{}
}
ips, _ := generateIPs(&ipNet, takenIPMap)
if len(ips) == 0 {
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
}
// pick a random IP
s := rand.NewSource(time.Now().Unix())
r := rand.New(s)
intn := r.Intn(len(ips))
return ips[intn], nil
}
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list
func generateIPs(ipNet *net.IPNet, exclusions map[string]struct{}) ([]net.IP, int) {
var ips []net.IP
for ip := ipNet.IP.Mask(ipNet.Mask); ipNet.Contains(ip); incIP(ip) {
if _, ok := exclusions[ip.String()]; !ok && ip[3] != 0 {
ips = append(ips, copyIP(ip))
}
}
// remove network address, broadcast and Fake DNS resolver address
lenIPs := len(ips)
switch {
case lenIPs < 2:
return ips, lenIPs
case lenIPs < 3:
return ips[1 : len(ips)-1], lenIPs - 2
default:
return ips[1 : len(ips)-2], lenIPs - 3
}
}
func copyIP(ip net.IP) net.IP {
dup := make(net.IP, len(ip))
copy(dup, ip)
return dup
}
func incIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}

View File

@@ -0,0 +1,51 @@
package types
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewNetwork(t *testing.T) {
network := NewNetwork()
// generated net should be a subnet of a larger 100.64.0.0/10 net
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 192, 0, 0}}
assert.Equal(t, ipNet.Contains(network.Net.IP), true)
}
func TestAllocatePeerIP(t *testing.T) {
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}}
var ips []net.IP
for i := 0; i < 252; i++ {
ip, err := AllocatePeerIP(ipNet, ips)
if err != nil {
t.Fatal(err)
}
ips = append(ips, ip)
}
assert.Len(t, ips, 252)
uniq := make(map[string]struct{})
for _, ip := range ips {
if _, ok := uniq[ip.String()]; !ok {
uniq[ip.String()] = struct{}{}
} else {
t.Errorf("found duplicate IP %s", ip.String())
}
}
}
func TestGenerateIPs(t *testing.T) {
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}}
ips, ipsLen := generateIPs(&ipNet, map[string]struct{}{"100.64.0.0": {}})
if ipsLen != 252 {
t.Errorf("expected 252 ips, got %d", len(ips))
return
}
if ips[len(ips)-1].String() != "100.64.0.253" {
t.Errorf("expected last ip to be: 100.64.0.253, got %s", ips[len(ips)-1].String())
}
}

View File

@@ -0,0 +1,95 @@
package types
import (
"crypto/sha256"
b64 "encoding/base64"
"fmt"
"hash/crc32"
"time"
b "github.com/hashicorp/go-secure-stdlib/base62"
"github.com/rs/xid"
"github.com/netbirdio/netbird/base62"
)
const (
// PATPrefix is the globally used, 4 char prefix for personal access tokens
PATPrefix = "nbp_"
// PATSecretLength number of characters used for the secret inside the token
PATSecretLength = 30
// PATChecksumLength number of characters used for the encoded checksum of the secret inside the token
PATChecksumLength = 6
// PATLength total number of characters used for the token
PATLength = 40
)
// PersonalAccessToken holds all information about a PAT including a hashed version of it for verification
type PersonalAccessToken struct {
ID string `gorm:"primaryKey"`
// User is a reference to Account that this object belongs
UserID string `gorm:"index"`
Name string
HashedToken string
ExpirationDate time.Time
// scope could be added in future
CreatedBy string
CreatedAt time.Time
LastUsed time.Time
}
func (t *PersonalAccessToken) Copy() *PersonalAccessToken {
return &PersonalAccessToken{
ID: t.ID,
Name: t.Name,
HashedToken: t.HashedToken,
ExpirationDate: t.ExpirationDate,
CreatedBy: t.CreatedBy,
CreatedAt: t.CreatedAt,
LastUsed: t.LastUsed,
}
}
// PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it
type PersonalAccessTokenGenerated struct {
PlainToken string
PersonalAccessToken
}
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) {
hashedToken, plainToken, err := generateNewToken()
if err != nil {
return nil, err
}
currentTime := time.Now()
return &PersonalAccessTokenGenerated{
PersonalAccessToken: PersonalAccessToken{
ID: xid.New().String(),
Name: name,
HashedToken: hashedToken,
ExpirationDate: currentTime.AddDate(0, 0, expirationInDays),
CreatedBy: createdBy,
CreatedAt: currentTime,
LastUsed: time.Time{},
},
PlainToken: plainToken,
}, nil
}
func generateNewToken() (string, string, error) {
secret, err := b.Random(PATSecretLength)
if err != nil {
return "", "", err
}
checksum := crc32.ChecksumIEEE([]byte(secret))
encodedChecksum := base62.Encode(checksum)
paddedChecksum := fmt.Sprintf("%06s", encodedChecksum)
plainToken := PATPrefix + secret + paddedChecksum
hashedToken := sha256.Sum256([]byte(plainToken))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
return encodedHashedToken, plainToken, nil
}

View File

@@ -0,0 +1,46 @@
package types
import (
"crypto/sha256"
b64 "encoding/base64"
"hash/crc32"
"math/big"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/base62"
)
func TestPAT_GenerateToken_Hashing(t *testing.T) {
hashedToken, plainToken, _ := generateNewToken()
expectedToken := sha256.Sum256([]byte(plainToken))
encodedExpectedToken := b64.StdEncoding.EncodeToString(expectedToken[:])
assert.Equal(t, hashedToken, encodedExpectedToken)
}
func TestPAT_GenerateToken_Prefix(t *testing.T) {
_, plainToken, _ := generateNewToken()
fourCharPrefix := plainToken[:4]
assert.Equal(t, PATPrefix, fourCharPrefix)
}
func TestPAT_GenerateToken_Checksum(t *testing.T) {
_, plainToken, _ := generateNewToken()
tokenWithoutPrefix := strings.Split(plainToken, "_")[1]
if len(tokenWithoutPrefix) != 36 {
t.Fatal("Token has wrong length")
}
secret := tokenWithoutPrefix[:len(tokenWithoutPrefix)-6]
tokenCheckSum := tokenWithoutPrefix[len(tokenWithoutPrefix)-6:]
var i big.Int
i.SetString(secret, 62)
expectedChecksum := crc32.ChecksumIEEE([]byte(secret))
actualChecksum, err := base62.Decode(tokenCheckSum)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, expectedChecksum, actualChecksum)
}

View File

@@ -0,0 +1,116 @@
package types
const (
// PolicyTrafficActionAccept indicates that the traffic is accepted
PolicyTrafficActionAccept = PolicyTrafficActionType("accept")
// PolicyTrafficActionDrop indicates that the traffic is dropped
PolicyTrafficActionDrop = PolicyTrafficActionType("drop")
)
const (
// PolicyRuleProtocolALL type of traffic
PolicyRuleProtocolALL = PolicyRuleProtocolType("all")
// PolicyRuleProtocolTCP type of traffic
PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp")
// PolicyRuleProtocolUDP type of traffic
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
// PolicyRuleProtocolICMP type of traffic
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
)
const (
// PolicyRuleFlowDirect allows traffic from source to destination
PolicyRuleFlowDirect = PolicyRuleDirection("direct")
// PolicyRuleFlowBidirect allows traffic to both directions
PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect")
)
const (
// DefaultRuleName is a name for the Default rule that is created for every account
DefaultRuleName = "Default"
// DefaultRuleDescription is a description for the Default rule that is created for every account
DefaultRuleDescription = "This is a default rule that allows connections between all the resources"
// DefaultPolicyName is a name for the Default policy that is created for every account
DefaultPolicyName = "Default"
// DefaultPolicyDescription is a description for the Default policy that is created for every account
DefaultPolicyDescription = "This is a default policy that allows connections between all the resources"
)
// PolicyUpdateOperation operation object with type and values to be applied
type PolicyUpdateOperation struct {
Type PolicyUpdateOperationType
Values []string
}
// Policy of the Rego query
type Policy struct {
// ID of the policy'
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// Name of the Policy
Name string
// Description of the policy visible in the UI
Description string
// Enabled status of the policy
Enabled bool
// Rules of the policy
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"`
// SourcePostureChecks are ID references to Posture checks for policy source groups
SourcePostureChecks []string `gorm:"serializer:json"`
}
// Copy returns a copy of the policy.
func (p *Policy) Copy() *Policy {
c := &Policy{
ID: p.ID,
AccountID: p.AccountID,
Name: p.Name,
Description: p.Description,
Enabled: p.Enabled,
Rules: make([]*PolicyRule, len(p.Rules)),
SourcePostureChecks: make([]string, len(p.SourcePostureChecks)),
}
for i, r := range p.Rules {
c.Rules[i] = r.Copy()
}
copy(c.SourcePostureChecks, p.SourcePostureChecks)
return c
}
// EventMeta returns activity event meta related to this policy
func (p *Policy) EventMeta() map[string]any {
return map[string]any{"name": p.Name}
}
// UpgradeAndFix different version of policies to latest version
func (p *Policy) UpgradeAndFix() {
for _, r := range p.Rules {
// start migrate from version v0.20.3
if r.Protocol == "" {
r.Protocol = PolicyRuleProtocolALL
}
if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional {
r.Bidirectional = true
}
// -- v0.20.4
}
}
// RuleGroups returns a list of all groups referenced in the policy's rules,
// including sources and destinations.
func (p *Policy) RuleGroups() []string {
groups := make([]string, 0)
for _, rule := range p.Rules {
groups = append(groups, rule.Sources...)
groups = append(groups, rule.Destinations...)
}
return groups
}

View File

@@ -0,0 +1,81 @@
package types
// PolicyUpdateOperationType operation type
type PolicyUpdateOperationType int
// PolicyTrafficActionType action type for the firewall
type PolicyTrafficActionType string
// PolicyRuleProtocolType type of traffic
type PolicyRuleProtocolType string
// PolicyRuleDirection direction of traffic
type PolicyRuleDirection string
// RulePortRange represents a range of ports for a firewall rule.
type RulePortRange struct {
Start uint16
End uint16
}
// PolicyRule is the metadata of the policy
type PolicyRule struct {
// ID of the policy rule
ID string `gorm:"primaryKey"`
// PolicyID is a reference to Policy that this object belongs
PolicyID string `json:"-" gorm:"index"`
// Name of the rule visible in the UI
Name string
// Description of the rule visible in the UI
Description string
// Enabled status of rule in the system
Enabled bool
// Action policy accept or drops packets
Action PolicyTrafficActionType
// Destinations policy destination groups
Destinations []string `gorm:"serializer:json"`
// Sources policy source groups
Sources []string `gorm:"serializer:json"`
// Bidirectional define if the rule is applicable in both directions, sources, and destinations
Bidirectional bool
// Protocol type of the traffic
Protocol PolicyRuleProtocolType
// Ports or it ranges list
Ports []string `gorm:"serializer:json"`
// PortRanges a list of port ranges.
PortRanges []RulePortRange `gorm:"serializer:json"`
}
// Copy returns a copy of a policy rule
func (pm *PolicyRule) Copy() *PolicyRule {
rule := &PolicyRule{
ID: pm.ID,
PolicyID: pm.PolicyID,
Name: pm.Name,
Description: pm.Description,
Enabled: pm.Enabled,
Action: pm.Action,
Destinations: make([]string, len(pm.Destinations)),
Sources: make([]string, len(pm.Sources)),
Bidirectional: pm.Bidirectional,
Protocol: pm.Protocol,
Ports: make([]string, len(pm.Ports)),
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
}
copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources)
copy(rule.Ports, pm.Ports)
copy(rule.PortRanges, pm.PortRanges)
return rule
}

View File

@@ -0,0 +1,25 @@
package types
// RouteFirewallRule a firewall rule applicable for a routed network.
type RouteFirewallRule struct {
// SourceRanges IP ranges of the routing peers.
SourceRanges []string
// Action of the traffic when the rule is applicable
Action string
// Destination a network prefix for the routed traffic
Destination string
// Protocol of the traffic
Protocol string
// Port of the traffic
Port uint16
// PortRange represents the range of ports for a firewall rule
PortRange RulePortRange
// isDynamic indicates whether the rule is for DNS routing
IsDynamic bool
}

View File

@@ -0,0 +1,63 @@
package types
import (
"time"
"github.com/netbirdio/netbird/management/server/account"
)
// Settings represents Account settings structure that can be modified via API and Dashboard
type Settings struct {
// PeerLoginExpirationEnabled globally enables or disables peer login expiration
PeerLoginExpirationEnabled bool
// PeerLoginExpiration is a setting that indicates when peer login expires.
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
PeerLoginExpiration time.Duration
// PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration
PeerInactivityExpirationEnabled bool
// PeerInactivityExpiration is a setting that indicates when peer inactivity expires.
// Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true.
PeerInactivityExpiration time.Duration
// RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements
RegularUsersViewBlocked bool
// GroupsPropagationEnabled allows to propagate auto groups from the user to the peer
GroupsPropagationEnabled bool
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
// and add it to account groups.
JWTGroupsEnabled bool
// JWTGroupsClaimName from which we extract groups name to add it to account groups
JWTGroupsClaimName string
// JWTAllowGroups list of groups to which users are allowed access
JWTAllowGroups []string `gorm:"serializer:json"`
// Extra is a dictionary of Account settings
Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
}
// Copy copies the Settings struct
func (s *Settings) Copy() *Settings {
settings := &Settings{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpiration: s.PeerLoginExpiration,
JWTGroupsEnabled: s.JWTGroupsEnabled,
JWTGroupsClaimName: s.JWTGroupsClaimName,
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
JWTAllowGroups: s.JWTAllowGroups,
RegularUsersViewBlocked: s.RegularUsersViewBlocked,
PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: s.PeerInactivityExpiration,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()
}
return settings
}

View File

@@ -0,0 +1,181 @@
package types
import (
"crypto/sha256"
b64 "encoding/base64"
"hash/fnv"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/google/uuid"
)
const (
// SetupKeyReusable is a multi-use key (can be used for multiple machines)
SetupKeyReusable SetupKeyType = "reusable"
// SetupKeyOneOff is a single use key (can be used only once)
SetupKeyOneOff SetupKeyType = "one-off"
// DefaultSetupKeyDuration = 1 month
DefaultSetupKeyDuration = 24 * 30 * time.Hour
// DefaultSetupKeyName is a default name of the default setup key
DefaultSetupKeyName = "Default key"
// SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key
SetupKeyUnlimitedUsage = 0
)
// SetupKeyType is the type of setup key
type SetupKeyType string
// SetupKey represents a pre-authorized key used to register machines (peers)
type SetupKey struct {
Id string
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
Key string
KeySecret string
Name string
Type SetupKeyType
CreatedAt time.Time
ExpiresAt time.Time
UpdatedAt time.Time `gorm:"autoUpdateTime:false"`
// Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes)
Revoked bool
// UsedTimes indicates how many times the key was used
UsedTimes int
// LastUsed last time the key was used for peer registration
LastUsed time.Time
// AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register
AutoGroups []string `gorm:"serializer:json"`
// UsageLimit indicates the number of times this key can be used to enroll a machine.
// The value of 0 indicates the unlimited usage.
UsageLimit int
// Ephemeral indicate if the peers will be ephemeral or not
Ephemeral bool
}
// Copy copies SetupKey to a new object
func (key *SetupKey) Copy() *SetupKey {
autoGroups := make([]string, len(key.AutoGroups))
copy(autoGroups, key.AutoGroups)
if key.UpdatedAt.IsZero() {
key.UpdatedAt = key.CreatedAt
}
return &SetupKey{
Id: key.Id,
AccountID: key.AccountID,
Key: key.Key,
KeySecret: key.KeySecret,
Name: key.Name,
Type: key.Type,
CreatedAt: key.CreatedAt,
ExpiresAt: key.ExpiresAt,
UpdatedAt: key.UpdatedAt,
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
AutoGroups: autoGroups,
UsageLimit: key.UsageLimit,
Ephemeral: key.Ephemeral,
}
}
// EventMeta returns activity event meta related to the setup key
func (key *SetupKey) EventMeta() map[string]any {
return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret}
}
// HiddenKey returns the Key value hidden with "*" and a 5 character prefix.
// E.g., "831F6*******************************"
func HiddenKey(key string, length int) string {
prefix := key[0:5]
if length > utf8.RuneCountInString(key) {
length = utf8.RuneCountInString(key) - len(prefix)
}
return prefix + strings.Repeat("*", length)
}
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
func (key *SetupKey) IncrementUsage() *SetupKey {
c := key.Copy()
c.UsedTimes++
c.LastUsed = time.Now().UTC()
return c
}
// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to
func (key *SetupKey) IsValid() bool {
return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed()
}
// IsRevoked if key was revoked
func (key *SetupKey) IsRevoked() bool {
return key.Revoked
}
// IsExpired if key was expired
func (key *SetupKey) IsExpired() bool {
if key.ExpiresAt.IsZero() {
return false
}
return time.Now().After(key.ExpiresAt)
}
// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage.
func (key *SetupKey) IsOverUsed() bool {
limit := key.UsageLimit
if key.Type == SetupKeyOneOff {
limit = 1
}
return limit > 0 && key.UsedTimes >= limit
}
// GenerateSetupKey generates a new setup key
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string,
usageLimit int, ephemeral bool) (*SetupKey, string) {
key := strings.ToUpper(uuid.New().String())
limit := usageLimit
if t == SetupKeyOneOff {
limit = 1
}
expiresAt := time.Time{}
if validFor != 0 {
expiresAt = time.Now().UTC().Add(validFor)
}
hashedKey := sha256.Sum256([]byte(key))
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
return &SetupKey{
Id: strconv.Itoa(int(Hash(key))),
Key: encodedHashedKey,
KeySecret: HiddenKey(key, 4),
Name: name,
Type: t,
CreatedAt: time.Now().UTC(),
ExpiresAt: expiresAt,
UpdatedAt: time.Now().UTC(),
Revoked: false,
UsedTimes: 0,
AutoGroups: autoGroups,
UsageLimit: limit,
Ephemeral: ephemeral,
}, key
}
// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration
func GenerateDefaultSetupKey() (*SetupKey, string) {
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{},
SetupKeyUnlimitedUsage, false)
}
func Hash(s string) uint32 {
h := fnv.New32a()
_, err := h.Write([]byte(s))
if err != nil {
panic(err)
}
return h.Sum32()
}

View File

@@ -0,0 +1,231 @@
package types
import (
"fmt"
"strings"
"time"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
)
const (
UserRoleOwner UserRole = "owner"
UserRoleAdmin UserRole = "admin"
UserRoleUser UserRole = "user"
UserRoleUnknown UserRole = "unknown"
UserRoleBillingAdmin UserRole = "billing_admin"
UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled"
UserStatusInvited UserStatus = "invited"
UserIssuedAPI = "api"
UserIssuedIntegration = "integration"
)
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown
func StrRoleToUserRole(strRole string) UserRole {
switch strings.ToLower(strRole) {
case "owner":
return UserRoleOwner
case "admin":
return UserRoleAdmin
case "user":
return UserRoleUser
case "billing_admin":
return UserRoleBillingAdmin
default:
return UserRoleUnknown
}
}
// UserStatus is the status of a User
type UserStatus string
// UserRole is the role of a User
type UserRole string
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
IntegrationReference integration_reference.IntegrationReference `json:"-"`
Permissions UserPermissions `json:"permissions"`
}
type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
// User represents a user of the system
type User struct {
Id string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
Role UserRole
IsServiceUser bool
// NonDeletable indicates whether the service user can be deleted
NonDeletable bool
// ServiceUserName is only set if IsServiceUser is true
ServiceUserName string
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string `gorm:"serializer:json"`
PATs map[string]*PersonalAccessToken `gorm:"-"`
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"`
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
// LastLogin is the last time the user logged in to IdP
LastLogin time.Time
// CreatedAt records the time the user was created
CreatedAt time.Time
// Issued of the user
Issued string `gorm:"default:api"`
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
}
// IsBlocked returns true if the user is blocked, false otherwise
func (u *User) IsBlocked() bool {
return u.Blocked
}
func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
}
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
func (u *User) HasAdminPower() bool {
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
}
// IsAdminOrServiceUser checks if the user has admin power or is a service user.
func (u *User) IsAdminOrServiceUser() bool {
return u.HasAdminPower() || u.IsServiceUser
}
// IsRegularUser checks if the user is a regular user.
func (u *User) IsRegularUser() bool {
return !u.HasAdminPower() && !u.IsServiceUser
}
// ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
autoGroups := u.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
dashboardViewPermissions := "full"
if !u.HasAdminPower() {
dashboardViewPermissions = "limited"
if settings.RegularUsersViewBlocked {
dashboardViewPermissions = "blocked"
}
}
if userData == nil {
return &UserInfo{
ID: u.Id,
Email: "",
Name: u.ServiceUserName,
Role: string(u.Role),
AutoGroups: u.AutoGroups,
Status: string(UserStatusActive),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
LastLogin: u.LastLogin,
Issued: u.Issued,
Permissions: UserPermissions{
DashboardView: dashboardViewPermissions,
},
}, nil
}
if userData.ID != u.Id {
return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id)
}
userStatus := UserStatusActive
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
userStatus = UserStatusInvited
}
return &UserInfo{
ID: u.Id,
Email: userData.Email,
Name: userData.Name,
Role: string(u.Role),
AutoGroups: autoGroups,
Status: string(userStatus),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
LastLogin: u.LastLogin,
Issued: u.Issued,
Permissions: UserPermissions{
DashboardView: dashboardViewPermissions,
},
}, nil
}
// Copy the user
func (u *User) Copy() *User {
autoGroups := make([]string, len(u.AutoGroups))
copy(autoGroups, u.AutoGroups)
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
for k, v := range u.PATs {
pats[k] = v.Copy()
}
return &User{
Id: u.Id,
AccountID: u.AccountID,
Role: u.Role,
AutoGroups: autoGroups,
IsServiceUser: u.IsServiceUser,
NonDeletable: u.NonDeletable,
ServiceUserName: u.ServiceUserName,
PATs: pats,
Blocked: u.Blocked,
LastLogin: u.LastLogin,
CreatedAt: u.CreatedAt,
Issued: u.Issued,
IntegrationReference: u.IntegrationReference,
}
}
// NewUser creates a new user
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
return &User{
Id: id,
Role: role,
IsServiceUser: isServiceUser,
NonDeletable: nonDeletable,
ServiceUserName: serviceUserName,
AutoGroups: autoGroups,
Issued: issued,
CreatedAt: time.Now().UTC(),
}
}
// NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
}
// NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
}
// NewOwnerUser creates a new user with role UserRoleOwner
func NewOwnerUser(id string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
}