mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management] move network map logic into new design (#4774)
This commit is contained in:
352
management/internals/shared/grpc/conversion.go
Normal file
352
management/internals/shared/grpc/conversion.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var stuns []*proto.HostConfig
|
||||
for _, stun := range config.Stuns {
|
||||
stuns = append(stuns, &proto.HostConfig{
|
||||
Uri: stun.URI,
|
||||
Protocol: ToResponseProto(stun.Proto),
|
||||
})
|
||||
}
|
||||
|
||||
var turns []*proto.ProtectedHostConfig
|
||||
if config.TURNConfig != nil {
|
||||
for _, turn := range config.TURNConfig.Turns {
|
||||
var username string
|
||||
var password string
|
||||
if turnCredentials != nil {
|
||||
username = turnCredentials.Payload
|
||||
password = turnCredentials.Signature
|
||||
} else {
|
||||
username = turn.Username
|
||||
password = turn.Password
|
||||
}
|
||||
turns = append(turns, &proto.ProtectedHostConfig{
|
||||
HostConfig: &proto.HostConfig{
|
||||
Uri: turn.URI,
|
||||
Protocol: ToResponseProto(turn.Proto),
|
||||
},
|
||||
User: username,
|
||||
Password: password,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var relayCfg *proto.RelayConfig
|
||||
if config.Relay != nil && len(config.Relay.Addresses) > 0 {
|
||||
relayCfg = &proto.RelayConfig{
|
||||
Urls: config.Relay.Addresses,
|
||||
}
|
||||
|
||||
if relayToken != nil {
|
||||
relayCfg.TokenPayload = relayToken.Payload
|
||||
relayCfg.TokenSignature = relayToken.Signature
|
||||
}
|
||||
}
|
||||
|
||||
var signalCfg *proto.HostConfig
|
||||
if config.Signal != nil {
|
||||
signalCfg = &proto.HostConfig{
|
||||
Uri: config.Signal.URI,
|
||||
Protocol: ToResponseProto(config.Signal.Proto),
|
||||
}
|
||||
}
|
||||
|
||||
nbConfig := &proto.NetbirdConfig{
|
||||
Stuns: stuns,
|
||||
Turns: turns,
|
||||
Signal: signalCfg,
|
||||
Relay: relayCfg,
|
||||
}
|
||||
|
||||
return nbConfig
|
||||
}
|
||||
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig {
|
||||
netmask, _ := network.Net.Mask.Size()
|
||||
fqdn := peer.FQDN(dnsName)
|
||||
return &proto.PeerConfig{
|
||||
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
|
||||
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
|
||||
Fqdn: fqdn,
|
||||
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
|
||||
LazyConnectionEnabled: settings.LazyConnectionEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
response.NetbirdConfig = extendedConfig
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
|
||||
response.RemotePeers = remotePeers
|
||||
response.NetworkMap.RemotePeers = remotePeers
|
||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||
response.NetworkMap.FirewallRules = firewallRules
|
||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
||||
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||
|
||||
if networkMap.ForwardingRules != nil {
|
||||
forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules))
|
||||
for _, rule := range networkMap.ForwardingRules {
|
||||
forwardingRules = append(forwardingRules, rule.ToProto())
|
||||
}
|
||||
response.NetworkMap.ForwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
AgentVersion: rPeer.Meta.WtVersion,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
ForwarderPort: forwardPort,
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
cacheKey := nsGroup.ID
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
} else {
|
||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
switch configProto {
|
||||
case nbconfig.UDP:
|
||||
return proto.HostConfig_UDP
|
||||
case nbconfig.DTLS:
|
||||
return proto.HostConfig_DTLS
|
||||
case nbconfig.HTTP:
|
||||
return proto.HostConfig_HTTP
|
||||
case nbconfig.HTTPS:
|
||||
return proto.HostConfig_HTTPS
|
||||
case nbconfig.TCP:
|
||||
return proto.HostConfig_TCP
|
||||
default:
|
||||
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
||||
}
|
||||
}
|
||||
|
||||
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
|
||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
||||
}
|
||||
return protoRoutes
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
return &proto.Route{
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
SkipAutoApply: route.SkipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||
func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
fwRule := &proto.FirewallRule{
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
PeerIP: rule.PeerIP,
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
|
||||
if shouldUsePortRange(fwRule) {
|
||||
fwRule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
|
||||
result[i] = fwRule
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getProtoDirection converts the direction to proto.RuleDirection.
|
||||
func getProtoDirection(direction int) proto.RuleDirection {
|
||||
if direction == types.FirewallRuleDirectionOUT {
|
||||
return proto.RuleDirection_OUT
|
||||
}
|
||||
return proto.RuleDirection_IN
|
||||
}
|
||||
|
||||
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
result[i] = &proto.RouteFirewallRule{
|
||||
SourceRanges: rule.SourceRanges,
|
||||
Action: getProtoAction(rule.Action),
|
||||
Destination: rule.Destination,
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
PortInfo: getProtoPortInfo(rule),
|
||||
IsDynamic: rule.IsDynamic,
|
||||
Domains: rule.Domains.ToPunycodeList(),
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
RouteID: string(rule.RouteID),
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// getProtoAction converts the action to proto.RuleAction.
|
||||
func getProtoAction(action string) proto.RuleAction {
|
||||
if action == string(types.PolicyTrafficActionDrop) {
|
||||
return proto.RuleAction_DROP
|
||||
}
|
||||
return proto.RuleAction_ACCEPT
|
||||
}
|
||||
|
||||
// getProtoProtocol converts the protocol to proto.RuleProtocol.
|
||||
func getProtoProtocol(protocol string) proto.RuleProtocol {
|
||||
switch types.PolicyRuleProtocolType(protocol) {
|
||||
case types.PolicyRuleProtocolALL:
|
||||
return proto.RuleProtocol_ALL
|
||||
case types.PolicyRuleProtocolTCP:
|
||||
return proto.RuleProtocol_TCP
|
||||
case types.PolicyRuleProtocolUDP:
|
||||
return proto.RuleProtocol_UDP
|
||||
case types.PolicyRuleProtocolICMP:
|
||||
return proto.RuleProtocol_ICMP
|
||||
default:
|
||||
return proto.RuleProtocol_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
// getProtoPortInfo converts the port info to proto.PortInfo.
|
||||
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
var portInfo proto.PortInfo
|
||||
if rule.Port != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
||||
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Range_{
|
||||
Range: &proto.PortInfo_Range{
|
||||
Start: uint32(portRange.Start),
|
||||
End: uint32(portRange.End),
|
||||
},
|
||||
}
|
||||
}
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
150
management/internals/shared/grpc/conversion_test.go
Normal file
150
management/internals/shared/grpc/conversion_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
)
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
var cache cache.DNSConfigCache
|
||||
|
||||
// Create two different configs
|
||||
config1 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.com",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group1",
|
||||
Name: "Group 1",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config2 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.org",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group2",
|
||||
Name: "Group 2",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
|
||||
}
|
||||
|
||||
// Verify that result2 is different from result1 and result3
|
||||
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
|
||||
t.Errorf("Results should be different for different inputs")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group1'")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group2"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group2'")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
sizes := []int{10, 100, 1000}
|
||||
|
||||
for _, size := range sizes {
|
||||
testData := generateTestData(size)
|
||||
|
||||
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
|
||||
cache := &cache.DNSConfigCache{}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &cache.DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateTestData(size int) nbdns.Config {
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: make([]nbdns.CustomZone, size),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, size),
|
||||
}
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
config.CustomZones[i] = nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("domain%d.com", i),
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: fmt.Sprintf("record%d", i),
|
||||
Type: 1,
|
||||
Class: "IN",
|
||||
TTL: 3600,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
||||
ID: fmt.Sprintf("group%d", i),
|
||||
Primary: i == 0,
|
||||
Domains: []string{fmt.Sprintf("domain%d.com", i)},
|
||||
SearchDomainsEnabled: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
Port: 53,
|
||||
NSType: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
160
management/internals/shared/grpc/loginfilter.go
Normal file
160
management/internals/shared/grpc/loginfilter.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
reconnThreshold = 5 * time.Minute
|
||||
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
|
||||
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
)
|
||||
|
||||
type lfConfig struct {
|
||||
reconnThreshold time.Duration
|
||||
baseBlockDuration time.Duration
|
||||
reconnLimitForBan int
|
||||
metaChangeLimit int
|
||||
}
|
||||
|
||||
func initCfg() *lfConfig {
|
||||
return &lfConfig{
|
||||
reconnThreshold: reconnThreshold,
|
||||
baseBlockDuration: baseBlockDuration,
|
||||
reconnLimitForBan: reconnLimitForBan,
|
||||
metaChangeLimit: metaChangeLimit,
|
||||
}
|
||||
}
|
||||
|
||||
type loginFilter struct {
|
||||
mu sync.RWMutex
|
||||
cfg *lfConfig
|
||||
logged map[string]*peerState
|
||||
}
|
||||
|
||||
type peerState struct {
|
||||
currentHash uint64
|
||||
sessionCounter int
|
||||
sessionStart time.Time
|
||||
lastSeen time.Time
|
||||
isBanned bool
|
||||
banLevel int
|
||||
banExpiresAt time.Time
|
||||
metaChangeCounter int
|
||||
metaChangeWindowStart time.Time
|
||||
}
|
||||
|
||||
func newLoginFilter() *loginFilter {
|
||||
return newLoginFilterWithCfg(initCfg())
|
||||
}
|
||||
|
||||
func newLoginFilterWithCfg(cfg *lfConfig) *loginFilter {
|
||||
return &loginFilter{
|
||||
logged: make(map[string]*peerState),
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *loginFilter) allowLogin(wgPubKey string, metaHash uint64) bool {
|
||||
l.mu.RLock()
|
||||
defer func() {
|
||||
l.mu.RUnlock()
|
||||
}()
|
||||
state, ok := l.logged[wgPubKey]
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
if state.isBanned && time.Now().Before(state.banExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if metaHash != state.currentHash {
|
||||
if time.Now().Before(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) && state.metaChangeCounter >= l.cfg.metaChangeLimit {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
|
||||
now := time.Now()
|
||||
l.mu.Lock()
|
||||
defer func() {
|
||||
l.mu.Unlock()
|
||||
}()
|
||||
|
||||
state, ok := l.logged[wgPubKey]
|
||||
|
||||
if !ok {
|
||||
l.logged[wgPubKey] = &peerState{
|
||||
currentHash: metaHash,
|
||||
sessionCounter: 1,
|
||||
sessionStart: now,
|
||||
lastSeen: now,
|
||||
metaChangeWindowStart: now,
|
||||
metaChangeCounter: 1,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if state.isBanned && now.After(state.banExpiresAt) {
|
||||
state.isBanned = false
|
||||
}
|
||||
|
||||
if state.banLevel > 0 && now.Sub(state.lastSeen) > (2*l.cfg.baseBlockDuration) {
|
||||
state.banLevel = 0
|
||||
}
|
||||
|
||||
if metaHash != state.currentHash {
|
||||
if now.After(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) {
|
||||
state.metaChangeWindowStart = now
|
||||
state.metaChangeCounter = 1
|
||||
} else {
|
||||
state.metaChangeCounter++
|
||||
}
|
||||
state.currentHash = metaHash
|
||||
state.sessionCounter = 1
|
||||
state.sessionStart = now
|
||||
state.lastSeen = now
|
||||
return
|
||||
}
|
||||
|
||||
state.sessionCounter++
|
||||
if state.sessionCounter > l.cfg.reconnLimitForBan && now.Sub(state.sessionStart) < l.cfg.reconnThreshold {
|
||||
state.isBanned = true
|
||||
state.banLevel++
|
||||
|
||||
backoffFactor := math.Pow(2, float64(state.banLevel-1))
|
||||
duration := time.Duration(float64(l.cfg.baseBlockDuration) * backoffFactor)
|
||||
state.banExpiresAt = now.Add(duration)
|
||||
|
||||
state.sessionCounter = 0
|
||||
state.sessionStart = now
|
||||
}
|
||||
state.lastSeen = now
|
||||
}
|
||||
|
||||
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
|
||||
h := fnv.New64a()
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
macs := uint64(0)
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
for _, r := range na.Mac {
|
||||
macs += uint64(r)
|
||||
}
|
||||
}
|
||||
|
||||
return h.Sum64() + macs
|
||||
}
|
||||
275
management/internals/shared/grpc/loginfilter_test.go
Normal file
275
management/internals/shared/grpc/loginfilter_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"math"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func testAdvancedCfg() *lfConfig {
|
||||
return &lfConfig{
|
||||
reconnThreshold: 50 * time.Millisecond,
|
||||
baseBlockDuration: 100 * time.Millisecond,
|
||||
reconnLimitForBan: 3,
|
||||
metaChangeLimit: 2,
|
||||
}
|
||||
}
|
||||
|
||||
type LoginFilterTestSuite struct {
|
||||
suite.Suite
|
||||
filter *loginFilter
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) SetupTest() {
|
||||
s.filter = newLoginFilterWithCfg(testAdvancedCfg())
|
||||
}
|
||||
|
||||
func TestLoginFilterTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(LoginFilterTestSuite))
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
|
||||
s.True(s.filter.allowLogin(pubKey, meta))
|
||||
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.Equal(1, s.filter.logged[pubKey].sessionCounter)
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
limit := s.filter.cfg.reconnLimitForBan
|
||||
|
||||
for i := 0; i <= limit; i++ {
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
}
|
||||
|
||||
s.False(s.filter.allowLogin(pubKey, meta))
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.True(s.filter.logged[pubKey].isBanned)
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
limit := s.filter.cfg.reconnLimitForBan
|
||||
baseBan := s.filter.cfg.baseBlockDuration
|
||||
|
||||
for i := 0; i <= limit; i++ {
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
}
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.True(s.filter.logged[pubKey].isBanned)
|
||||
s.Equal(1, s.filter.logged[pubKey].banLevel)
|
||||
firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
|
||||
s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond))
|
||||
|
||||
s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second)
|
||||
s.filter.logged[pubKey].isBanned = false
|
||||
|
||||
for i := 0; i <= limit; i++ {
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
}
|
||||
s.True(s.filter.logged[pubKey].isBanned)
|
||||
s.Equal(2, s.filter.logged[pubKey].banLevel)
|
||||
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
|
||||
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
|
||||
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
|
||||
s.filter.logged[pubKey] = &peerState{
|
||||
isBanned: true,
|
||||
banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)),
|
||||
}
|
||||
|
||||
s.True(s.filter.allowLogin(pubKey, meta))
|
||||
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.False(s.filter.logged[pubKey].isBanned)
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
|
||||
s.filter.logged[pubKey] = &peerState{
|
||||
currentHash: meta,
|
||||
banLevel: 3,
|
||||
lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration),
|
||||
}
|
||||
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.Equal(0, s.filter.logged[pubKey].banLevel)
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
limit := s.filter.cfg.metaChangeLimit
|
||||
|
||||
for i := range limit {
|
||||
s.filter.addLogin(pubKey, uint64(i+1))
|
||||
}
|
||||
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter)
|
||||
|
||||
isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1))
|
||||
|
||||
s.False(isAllowed, "should block new meta hash after limit is reached")
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta1 := uint64(1)
|
||||
meta2 := uint64(2)
|
||||
meta3 := uint64(3)
|
||||
|
||||
s.filter.addLogin(pubKey, meta1)
|
||||
s.filter.addLogin(pubKey, meta2)
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter)
|
||||
s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window")
|
||||
|
||||
s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second))
|
||||
|
||||
s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires")
|
||||
|
||||
s.filter.addLogin(pubKey, meta3)
|
||||
s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset")
|
||||
}
|
||||
|
||||
func BenchmarkHashingMethods(b *testing.B) {
|
||||
meta := nbpeer.PeerSystemMeta{
|
||||
WtVersion: "1.25.1",
|
||||
OSVersion: "Ubuntu 22.04.3 LTS",
|
||||
KernelVersion: "5.15.0-76-generic",
|
||||
Hostname: "prod-server-database-01",
|
||||
SystemSerialNumber: "PC-1234567890",
|
||||
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
|
||||
}
|
||||
pubip := "8.8.8.8"
|
||||
|
||||
var resultString string
|
||||
var resultUint uint64
|
||||
|
||||
b.Run("BuilderString", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = builderString(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("FnvHashToString", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = fnvHashToString(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("FnvHashToUint64 - used", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultUint = metaHash(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
_ = resultString
|
||||
_ = resultUint
|
||||
}
|
||||
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
h := fnv.New64a()
|
||||
|
||||
if len(meta.NetworkAddresses) != 0 {
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
h.Write([]byte(na.Mac))
|
||||
}
|
||||
}
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
return strconv.FormatUint(h.Sum64(), 16)
|
||||
}
|
||||
|
||||
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
mac := getMacAddress(meta.NetworkAddresses)
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
|
||||
len(pubip) + len(mac) + 6
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(estimatedSize)
|
||||
|
||||
b.WriteString(meta.WtVersion)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.OSVersion)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.KernelVersion)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.Hostname)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.SystemSerialNumber)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(pubip)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func getMacAddress(nas []nbpeer.NetworkAddress) string {
|
||||
if len(nas) == 0 {
|
||||
return ""
|
||||
}
|
||||
macs := make([]string, 0, len(nas))
|
||||
for _, na := range nas {
|
||||
macs = append(macs, na.Mac)
|
||||
}
|
||||
return strings.Join(macs, "/")
|
||||
}
|
||||
|
||||
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
|
||||
filter := newLoginFilterWithCfg(testAdvancedCfg())
|
||||
numKeys := 100000
|
||||
pubKeys := make([]string, numKeys)
|
||||
for i := range numKeys {
|
||||
pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
for pb.Next() {
|
||||
key := pubKeys[r.Intn(numKeys)]
|
||||
meta := r.Uint64()
|
||||
|
||||
if filter.allowLogin(key, meta) {
|
||||
filter.addLogin(key, meta)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
945
management/internals/shared/grpc/server.go
Normal file
945
management/internals/shared/grpc/server.go
Normal file
@@ -0,0 +1,945 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
internalStatus "github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
|
||||
envBlockPeers = "NB_BLOCK_SAME_PEERS"
|
||||
envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS"
|
||||
|
||||
defaultSyncLim = 1000
|
||||
)
|
||||
|
||||
// Server an instance of a Management gRPC API server
|
||||
type Server struct {
|
||||
accountManager account.Manager
|
||||
settingsManager settings.Manager
|
||||
wgKey wgtypes.Key
|
||||
proto.UnimplementedManagementServiceServer
|
||||
peersUpdateManager network_map.PeersUpdateManager
|
||||
config *nbconfig.Config
|
||||
secretsManager SecretsManager
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager ephemeral.Manager
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
|
||||
logBlockedPeers bool
|
||||
blockPeersWithSameConfig bool
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
loginFilter *loginFilter
|
||||
|
||||
networkMapController network_map.Controller
|
||||
|
||||
syncSem atomic.Int32
|
||||
syncLim int32
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
func NewServer(
|
||||
config *nbconfig.Config,
|
||||
accountManager account.Manager,
|
||||
settingsManager settings.Manager,
|
||||
peersUpdateManager network_map.PeersUpdateManager,
|
||||
secretsManager SecretsManager,
|
||||
appMetrics telemetry.AppMetrics,
|
||||
ephemeralManager ephemeral.Manager,
|
||||
authManager auth.Manager,
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
networkMapController network_map.Controller,
|
||||
) (*Server, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if appMetrics != nil {
|
||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
||||
return int64(peersUpdateManager.CountStreams())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
|
||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||
|
||||
syncLim := int32(defaultSyncLim)
|
||||
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
||||
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
||||
if err != nil {
|
||||
log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim)
|
||||
} else {
|
||||
//nolint:gosec
|
||||
syncLim = int32(syncLimParsed)
|
||||
}
|
||||
}
|
||||
|
||||
return &Server{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
config: config,
|
||||
secretsManager: secretsManager,
|
||||
authManager: authManager,
|
||||
appMetrics: appMetrics,
|
||||
ephemeralManager: ephemeralManager,
|
||||
logBlockedPeers: logBlockedPeers,
|
||||
blockPeersWithSameConfig: blockPeersWithSameConfig,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
networkMapController: networkMapController,
|
||||
|
||||
loginFilter: newLoginFilter(),
|
||||
|
||||
syncLim: syncLim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
|
||||
ip := ""
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if ok {
|
||||
ip = p.Addr.String()
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
|
||||
}()
|
||||
|
||||
// todo introduce something more meaningful with the key expiration/rotation
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountGetKeyRequest()
|
||||
}
|
||||
now := time.Now().Add(24 * time.Hour)
|
||||
secs := int64(now.Second())
|
||||
nanos := int32(now.Nanosecond())
|
||||
expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos}
|
||||
|
||||
return &proto.ServerKeyResponse{
|
||||
Key: s.wgKey.PublicKey().String(),
|
||||
ExpiresAt: expiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getRealIP(ctx context.Context) net.IP {
|
||||
if addr, ok := realip.FromContext(ctx); ok {
|
||||
return net.IP(addr.AsSlice())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
if s.syncSem.Load() >= s.syncLim {
|
||||
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
||||
}
|
||||
s.syncSem.Add(1)
|
||||
|
||||
reqStart := time.Now()
|
||||
|
||||
ctx := srv.Context()
|
||||
|
||||
syncReq := &proto.SyncRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, syncReq)
|
||||
if err != nil {
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
realIP := getRealIP(ctx)
|
||||
sRealIP := realIP.String()
|
||||
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||
}
|
||||
if s.logBlockedPeers {
|
||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
|
||||
}
|
||||
if s.blockPeersWithSameConfig {
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
|
||||
}
|
||||
}
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequest()
|
||||
}
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
|
||||
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
|
||||
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||
s.syncSem.Add(-1)
|
||||
return status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||
}
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
start := time.Now()
|
||||
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
|
||||
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||
|
||||
if syncReq.GetMeta() == nil {
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
|
||||
|
||||
s.ephemeralManager.OnPeerConnected(ctx, peer)
|
||||
|
||||
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
|
||||
}
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
case update, open := <-updates:
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
||||
}
|
||||
|
||||
if !open {
|
||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return nil
|
||||
}
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
// condition when client <-> server connection has been terminated
|
||||
case <-srv.Context().Done():
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return srv.Context().Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||
}
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.secretsManager.CancelRefresh(peer.ID)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
|
||||
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
|
||||
}
|
||||
|
||||
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
if s.authManager == nil {
|
||||
return "", status.Errorf(codes.Internal, "missing auth manager")
|
||||
}
|
||||
|
||||
userAuth, token, err := s.authManager.ValidateAndParseToken(ctx, jwtToken)
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
|
||||
}
|
||||
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := s.accountManager.GetAccountIDFromUserAuth(ctx, userAuth)
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||
}
|
||||
|
||||
if userAuth.AccountId != accountId {
|
||||
log.WithContext(ctx).Debugf("gRPC server sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
|
||||
userAuth.AccountId = accountId
|
||||
}
|
||||
|
||||
userAuth, err = s.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, token)
|
||||
if err != nil {
|
||||
return "", status.Error(codes.PermissionDenied, err.Error())
|
||||
}
|
||||
|
||||
err = s.accountManager.SyncUserJWTGroups(ctx, userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("gRPC server failed to sync user JWT groups: %s", err)
|
||||
}
|
||||
|
||||
return userAuth.UserId, nil
|
||||
}
|
||||
|
||||
func (s *Server) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
// maps internal internalStatus.Error to gRPC status.Error
|
||||
func mapError(ctx context.Context, err error) error {
|
||||
if e, ok := internalStatus.FromError(err); ok {
|
||||
switch e.Type() {
|
||||
case internalStatus.PermissionDenied:
|
||||
return status.Error(codes.PermissionDenied, e.Message)
|
||||
case internalStatus.Unauthorized:
|
||||
return status.Error(codes.PermissionDenied, e.Message)
|
||||
case internalStatus.Unauthenticated:
|
||||
return status.Error(codes.PermissionDenied, e.Message)
|
||||
case internalStatus.PreconditionFailed:
|
||||
return status.Error(codes.FailedPrecondition, e.Message)
|
||||
case internalStatus.NotFound:
|
||||
return status.Error(codes.NotFound, e.Message)
|
||||
default:
|
||||
}
|
||||
}
|
||||
if errors.Is(err, internalStatus.ErrPeerAlreadyLoggedIn) {
|
||||
return status.Error(codes.PermissionDenied, internalStatus.ErrPeerAlreadyLoggedIn.Error())
|
||||
}
|
||||
log.WithContext(ctx).Errorf("got an unhandled error: %s", err)
|
||||
return status.Errorf(codes.Internal, "failed handling request")
|
||||
}
|
||||
|
||||
func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
|
||||
if meta == nil {
|
||||
return nbpeer.PeerSystemMeta{}
|
||||
}
|
||||
|
||||
osVersion := meta.GetOSVersion()
|
||||
if osVersion == "" {
|
||||
osVersion = meta.GetCore()
|
||||
}
|
||||
|
||||
networkAddresses := make([]nbpeer.NetworkAddress, 0, len(meta.GetNetworkAddresses()))
|
||||
for _, addr := range meta.GetNetworkAddresses() {
|
||||
netAddr, err := netip.ParsePrefix(addr.GetNetIP())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
|
||||
continue
|
||||
}
|
||||
networkAddresses = append(networkAddresses, nbpeer.NetworkAddress{
|
||||
NetIP: netAddr,
|
||||
Mac: addr.GetMac(),
|
||||
})
|
||||
}
|
||||
|
||||
files := make([]nbpeer.File, 0, len(meta.GetFiles()))
|
||||
for _, file := range meta.GetFiles() {
|
||||
files = append(files, nbpeer.File{
|
||||
Path: file.GetPath(),
|
||||
Exist: file.GetExist(),
|
||||
ProcessIsRunning: file.GetProcessIsRunning(),
|
||||
})
|
||||
}
|
||||
|
||||
return nbpeer.PeerSystemMeta{
|
||||
Hostname: meta.GetHostname(),
|
||||
GoOS: meta.GetGoOS(),
|
||||
Kernel: meta.GetKernel(),
|
||||
Platform: meta.GetPlatform(),
|
||||
OS: meta.GetOS(),
|
||||
OSVersion: osVersion,
|
||||
WtVersion: meta.GetNetbirdVersion(),
|
||||
UIVersion: meta.GetUiVersion(),
|
||||
KernelVersion: meta.GetKernelVersion(),
|
||||
NetworkAddresses: networkAddresses,
|
||||
SystemSerialNumber: meta.GetSysSerialNumber(),
|
||||
SystemProductName: meta.GetSysProductName(),
|
||||
SystemManufacturer: meta.GetSysManufacturer(),
|
||||
Environment: nbpeer.Environment{
|
||||
Cloud: meta.GetEnvironment().GetCloud(),
|
||||
Platform: meta.GetEnvironment().GetPlatform(),
|
||||
},
|
||||
Flags: nbpeer.Flags{
|
||||
RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(),
|
||||
ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(),
|
||||
DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(),
|
||||
DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(),
|
||||
DisableDNS: meta.GetFlags().GetDisableDNS(),
|
||||
DisableFirewall: meta.GetFlags().GetDisableFirewall(),
|
||||
BlockLANAccess: meta.GetFlags().GetBlockLANAccess(),
|
||||
BlockInbound: meta.GetFlags().GetBlockInbound(),
|
||||
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
|
||||
},
|
||||
Files: files,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
|
||||
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed)
|
||||
if err != nil {
|
||||
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
|
||||
}
|
||||
|
||||
return peerKey, nil
|
||||
}
|
||||
|
||||
// Login endpoint first checks whether peer is registered under any account
|
||||
// In case it is, the login is successful
|
||||
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
||||
// In case of the successful registration login is also successful
|
||||
func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
reqStart := time.Now()
|
||||
realIP := getRealIP(ctx)
|
||||
sRealIP := realIP.String()
|
||||
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||
|
||||
loginReq := &proto.LoginRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, loginReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.logBlockedPeers {
|
||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
||||
}
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
|
||||
}
|
||||
if s.blockPeersWithSameConfig {
|
||||
return nil, internalStatus.ErrPeerAlreadyLoggedIn
|
||||
}
|
||||
}
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
}
|
||||
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
|
||||
accountID = "UNKNOWN"
|
||||
}
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
|
||||
|
||||
defer func() {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
||||
}
|
||||
took := time.Since(reqStart)
|
||||
if took > 7*time.Second {
|
||||
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
|
||||
}
|
||||
}()
|
||||
|
||||
if loginReq.GetMeta() == nil {
|
||||
msg := status.Errorf(codes.FailedPrecondition,
|
||||
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
log.WithContext(ctx).Warn(msg)
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
userID, err := s.processJwtToken(ctx, loginReq, peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sshKey []byte
|
||||
if loginReq.GetPeerKeys() != nil {
|
||||
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
|
||||
WireGuardPubKey: peerKey.String(),
|
||||
SSHKey: string(sshKey),
|
||||
Meta: peerMeta,
|
||||
UserID: userID,
|
||||
SetupKey: loginReq.GetSetupKey(),
|
||||
ConnectionIP: realIP,
|
||||
ExtraDNSLabels: loginReq.GetDnsLabels(),
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
|
||||
|
||||
// if the login request contains setup key then it is a registration request
|
||||
if loginReq.GetSetupKey() != "" {
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart))
|
||||
}
|
||||
|
||||
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
||||
var relayToken *Token
|
||||
var err error
|
||||
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
|
||||
relayToken, err = s.secretsManager.GenerateRelayToken()
|
||||
if err != nil {
|
||||
log.Errorf("failed generating Relay token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
settings, err := s.settingsManager.GetSettings(ctx, peer.AccountID, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed getting settings for peer %s: %s", peer.Key, err)
|
||||
return nil, status.Errorf(codes.Internal, "failed getting settings")
|
||||
}
|
||||
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings),
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if
|
||||
// the token is valid.
|
||||
//
|
||||
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already
|
||||
// registered or if it uses a setup key to register.
|
||||
func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
|
||||
userID := ""
|
||||
if loginReq.GetJwtToken() != "" {
|
||||
var err error
|
||||
for i := 0; i < 3; i++ {
|
||||
userID, err = s.validateToken(ctx, loginReq.GetJwtToken())
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed validating JWT token sent from peer %s with error %v. "+
|
||||
"Trying again as it may be due to the IdP cache issue", peerKey.String(), err)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// IsHealthy indicates whether the service is healthy
|
||||
func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||
func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error {
|
||||
var err error
|
||||
|
||||
var turnToken *Token
|
||||
if s.config.TURNConfig != nil && s.config.TURNConfig.TimeBasedCredentials {
|
||||
turnToken, err = s.secretsManager.GenerateTurnToken()
|
||||
if err != nil {
|
||||
log.Errorf("failed generating TURN token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
var relayToken *Token
|
||||
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
|
||||
relayToken, err = s.secretsManager.GenerateRelayToken()
|
||||
if err != nil {
|
||||
log.Errorf("failed generating Relay token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
settings, err := s.settingsManager.GetSettings(ctx, peer.AccountID, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
peerGroups, err := s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, peer.AccountID, peer.ID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||
}
|
||||
|
||||
plainResp := ToSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
sendStart := time.Now()
|
||||
err = srv.Send(&proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlow returns a device authorization flow information
|
||||
// This is used for initiating an Oauth 2 device authorization grant flow
|
||||
// which will be used by our clients to Login
|
||||
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
|
||||
}()
|
||||
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(nbconfig.NONE) {
|
||||
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
|
||||
}
|
||||
|
||||
provider, ok := proto.DeviceAuthorizationFlowProvider_value[strings.ToUpper(s.config.DeviceAuthorizationFlow.Provider)]
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "no provider found in the protocol for %s", s.config.DeviceAuthorizationFlow.Provider)
|
||||
}
|
||||
|
||||
flowInfoResp := &proto.DeviceAuthorizationFlow{
|
||||
Provider: proto.DeviceAuthorizationFlowProvider(provider),
|
||||
ProviderConfig: &proto.ProviderConfig{
|
||||
ClientID: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientID,
|
||||
ClientSecret: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientSecret,
|
||||
Domain: s.config.DeviceAuthorizationFlow.ProviderConfig.Domain,
|
||||
Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience,
|
||||
DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint,
|
||||
TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint,
|
||||
Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope,
|
||||
UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken,
|
||||
},
|
||||
}
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information")
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
|
||||
// This is used for initiating an Oauth 2 pkce authorization grant flow
|
||||
// which will be used by our clients to Login
|
||||
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
|
||||
}()
|
||||
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
if s.config.PKCEAuthorizationFlow == nil {
|
||||
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
|
||||
}
|
||||
|
||||
initInfoFlow := &proto.PKCEAuthorizationFlow{
|
||||
ProviderConfig: &proto.ProviderConfig{
|
||||
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
|
||||
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
|
||||
ClientSecret: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientSecret,
|
||||
TokenEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint,
|
||||
AuthorizationEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint,
|
||||
Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope,
|
||||
RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs,
|
||||
UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken,
|
||||
DisablePromptLogin: s.config.PKCEAuthorizationFlow.ProviderConfig.DisablePromptLogin,
|
||||
LoginFlag: uint32(s.config.PKCEAuthorizationFlow.ProviderConfig.LoginFlag),
|
||||
},
|
||||
}
|
||||
|
||||
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected,
|
||||
// peer's under the same account of any updates.
|
||||
func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
realIP := getRealIP(ctx)
|
||||
log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
syncMetaReq := &proto.SyncMetaRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, syncMetaReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if syncMetaReq.GetMeta() == nil {
|
||||
msg := status.Errorf(codes.FailedPrecondition,
|
||||
"peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
log.WithContext(ctx).Warn(msg)
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
|
||||
if err != nil {
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
|
||||
start := time.Now()
|
||||
|
||||
empty := &proto.Empty{}
|
||||
peerKey, err := s.parseRequest(ctx, req, empty)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String())
|
||||
// TODO: consider idempotency
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.ID)
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, peer.AccountID)
|
||||
|
||||
userID := peer.UserID
|
||||
if userID == "" {
|
||||
userID = activity.SystemInitiator
|
||||
}
|
||||
|
||||
if err = s.accountManager.DeletePeer(ctx, peer.AccountID, peer.ID, userID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to logout peer %s: %v", peerKey.String(), err)
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
|
||||
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
// toProtocolChecks converts posture checks to protocol checks.
|
||||
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
|
||||
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
|
||||
for _, postureCheck := range postureChecks {
|
||||
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
|
||||
}
|
||||
|
||||
return protoChecks
|
||||
}
|
||||
|
||||
// toProtocolCheck converts a posture.Checks to a proto.Checks.
|
||||
func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
|
||||
protoCheck := &proto.Checks{}
|
||||
|
||||
if check := postureCheck.Checks.ProcessCheck; check != nil {
|
||||
for _, process := range check.Processes {
|
||||
if process.LinuxPath != "" {
|
||||
protoCheck.Files = append(protoCheck.Files, process.LinuxPath)
|
||||
}
|
||||
if process.MacPath != "" {
|
||||
protoCheck.Files = append(protoCheck.Files, process.MacPath)
|
||||
}
|
||||
if process.WindowsPath != "" {
|
||||
protoCheck.Files = append(protoCheck.Files, process.WindowsPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return protoCheck
|
||||
}
|
||||
106
management/internals/shared/grpc/server_test.go
Normal file
106
management/internals/shared/grpc/server_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
testingServerKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
}
|
||||
|
||||
testingClientKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputFlow *config.DeviceAuthorizationFlow
|
||||
expectedFlow *mgmtProto.DeviceAuthorizationFlow
|
||||
expectedErrFunc require.ErrorAssertionFunc
|
||||
expectedErrMSG string
|
||||
expectedComparisonFunc require.ComparisonAssertionFunc
|
||||
expectedComparisonMSG string
|
||||
}{
|
||||
{
|
||||
name: "Testing No Device Flow Config",
|
||||
inputFlow: nil,
|
||||
expectedErrFunc: require.Error,
|
||||
expectedErrMSG: "should return error",
|
||||
},
|
||||
{
|
||||
name: "Testing Invalid Device Flow Provider Config",
|
||||
inputFlow: &config.DeviceAuthorizationFlow{
|
||||
Provider: "NoNe",
|
||||
ProviderConfig: config.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedErrFunc: require.Error,
|
||||
expectedErrMSG: "should return error",
|
||||
},
|
||||
{
|
||||
name: "Testing Full Device Flow Config",
|
||||
inputFlow: &config.DeviceAuthorizationFlow{
|
||||
Provider: "hosted",
|
||||
ProviderConfig: config.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
|
||||
Provider: 0,
|
||||
ProviderConfig: &mgmtProto.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedErrFunc: require.NoError,
|
||||
expectedErrMSG: "should not return error",
|
||||
expectedComparisonFunc: require.Equal,
|
||||
expectedComparisonMSG: "should match",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
mgmtServer := &Server{
|
||||
wgKey: testingServerKey,
|
||||
config: &config.Config{
|
||||
DeviceAuthorizationFlow: testCase.inputFlow,
|
||||
},
|
||||
}
|
||||
|
||||
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
|
||||
|
||||
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
|
||||
require.NoError(t, err, "should be able to encrypt message")
|
||||
|
||||
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
|
||||
context.TODO(),
|
||||
&mgmtProto.EncryptedMessage{
|
||||
WgPubKey: testingClientKey.PublicKey().String(),
|
||||
Body: encryptedMSG,
|
||||
},
|
||||
)
|
||||
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
|
||||
if testCase.expectedComparisonFunc != nil {
|
||||
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
|
||||
|
||||
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
|
||||
require.NoError(t, err, "should be able to decrypt")
|
||||
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
271
management/internals/shared/grpc/token_mgr.go
Normal file
271
management/internals/shared/grpc/token_mgr.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
|
||||
)
|
||||
|
||||
const defaultDuration = 12 * time.Hour
|
||||
|
||||
// SecretsManager used to manage TURN and relay secrets
|
||||
type SecretsManager interface {
|
||||
GenerateTurnToken() (*Token, error)
|
||||
GenerateRelayToken() (*Token, error)
|
||||
SetupRefresh(ctx context.Context, accountID, peerKey string)
|
||||
CancelRefresh(peerKey string)
|
||||
}
|
||||
|
||||
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
|
||||
type TimeBasedAuthSecretsManager struct {
|
||||
mux sync.Mutex
|
||||
turnCfg *nbconfig.TURNConfig
|
||||
relayCfg *nbconfig.Relay
|
||||
turnHmacToken *auth.TimedHMAC
|
||||
relayHmacToken *authv2.Generator
|
||||
updateManager network_map.PeersUpdateManager
|
||||
settingsManager settings.Manager
|
||||
groupsManager groups.Manager
|
||||
turnCancelMap map[string]chan struct{}
|
||||
relayCancelMap map[string]chan struct{}
|
||||
}
|
||||
|
||||
type Token auth.Token
|
||||
|
||||
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
|
||||
mgr := &TimeBasedAuthSecretsManager{
|
||||
updateManager: updateManager,
|
||||
turnCfg: turnCfg,
|
||||
relayCfg: relayCfg,
|
||||
turnCancelMap: make(map[string]chan struct{}),
|
||||
relayCancelMap: make(map[string]chan struct{}),
|
||||
settingsManager: settingsManager,
|
||||
groupsManager: groupsManager,
|
||||
}
|
||||
|
||||
if turnCfg != nil {
|
||||
duration := turnCfg.CredentialsTTL.Duration
|
||||
if turnCfg.CredentialsTTL.Duration <= 0 {
|
||||
log.Warnf("TURN credentials TTL is not set or invalid, using default value %s", defaultDuration)
|
||||
duration = defaultDuration
|
||||
}
|
||||
mgr.turnHmacToken = auth.NewTimedHMAC(turnCfg.Secret, duration)
|
||||
}
|
||||
|
||||
if relayCfg != nil {
|
||||
duration := relayCfg.CredentialsTTL.Duration
|
||||
if relayCfg.CredentialsTTL.Duration <= 0 {
|
||||
log.Warnf("Relay credentials TTL is not set or invalid, using default value %s", defaultDuration)
|
||||
duration = defaultDuration
|
||||
}
|
||||
|
||||
hashedSecret := sha256.Sum256([]byte(relayCfg.Secret))
|
||||
var err error
|
||||
if mgr.relayHmacToken, err = authv2.NewGenerator(authv2.AuthAlgoHMACSHA256, hashedSecret[:], duration); err != nil {
|
||||
log.Errorf("failed to create relay token generator: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return mgr
|
||||
}
|
||||
|
||||
// GenerateTurnToken generates new time-based secret credentials for TURN
|
||||
func (m *TimeBasedAuthSecretsManager) GenerateTurnToken() (*Token, error) {
|
||||
if m.turnHmacToken == nil {
|
||||
return nil, fmt.Errorf("TURN configuration is not set")
|
||||
}
|
||||
turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate TURN token: %s", err)
|
||||
}
|
||||
return (*Token)(turnToken), nil
|
||||
}
|
||||
|
||||
// GenerateRelayToken generates new time-based secret credentials for relay
|
||||
func (m *TimeBasedAuthSecretsManager) GenerateRelayToken() (*Token, error) {
|
||||
if m.relayHmacToken == nil {
|
||||
return nil, fmt.Errorf("relay configuration is not set")
|
||||
}
|
||||
relayToken, err := m.relayHmacToken.GenerateToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate relay token: %s", err)
|
||||
}
|
||||
|
||||
return &Token{
|
||||
Payload: string(relayToken.Payload),
|
||||
Signature: base64.StdEncoding.EncodeToString(relayToken.Signature),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) {
|
||||
if channel, ok := m.turnCancelMap[peerID]; ok {
|
||||
close(channel)
|
||||
delete(m.turnCancelMap, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) cancelRelay(peerID string) {
|
||||
if channel, ok := m.relayCancelMap[peerID]; ok {
|
||||
close(channel)
|
||||
delete(m.relayCancelMap, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
// CancelRefresh cancels scheduled peer credentials refresh
|
||||
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
m.cancelTURN(peerID)
|
||||
m.cancelRelay(peerID)
|
||||
}
|
||||
|
||||
// SetupRefresh starts peer credentials refresh
|
||||
func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountID, peerID string) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
m.cancelTURN(peerID)
|
||||
m.cancelRelay(peerID)
|
||||
|
||||
if m.turnCfg != nil && m.turnCfg.TimeBasedCredentials {
|
||||
turnCancel := make(chan struct{}, 1)
|
||||
m.turnCancelMap[peerID] = turnCancel
|
||||
go m.refreshTURNTokens(ctx, accountID, peerID, turnCancel)
|
||||
log.WithContext(ctx).Debugf("starting TURN refresh for %s", peerID)
|
||||
}
|
||||
|
||||
if m.relayCfg != nil {
|
||||
relayCancel := make(chan struct{}, 1)
|
||||
m.relayCancelMap[peerID] = relayCancel
|
||||
go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel)
|
||||
log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, accountID, peerID string, cancel chan struct{}) {
|
||||
ticker := time.NewTicker(m.turnCfg.CredentialsTTL.Duration / 4 * 3)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cancel:
|
||||
log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID)
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.pushNewTURNAndRelayTokens(ctx, accountID, peerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, accountID, peerID string, cancel chan struct{}) {
|
||||
ticker := time.NewTicker(m.relayCfg.CredentialsTTL.Duration / 4 * 3)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cancel:
|
||||
log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID)
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.pushNewRelayTokens(ctx, accountID, peerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Context, accountID, peerID string) {
|
||||
turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to generate token for peer '%s': %s", peerID, err)
|
||||
return
|
||||
}
|
||||
|
||||
var turns []*proto.ProtectedHostConfig
|
||||
for _, host := range m.turnCfg.Turns {
|
||||
turn := &proto.ProtectedHostConfig{
|
||||
HostConfig: &proto.HostConfig{
|
||||
Uri: host.URI,
|
||||
Protocol: ToResponseProto(host.Proto),
|
||||
},
|
||||
User: turnToken.Payload,
|
||||
Password: turnToken.Signature,
|
||||
}
|
||||
turns = append(turns, turn)
|
||||
}
|
||||
|
||||
update := &proto.SyncResponse{
|
||||
NetbirdConfig: &proto.NetbirdConfig{
|
||||
Turns: turns,
|
||||
},
|
||||
}
|
||||
|
||||
// workaround for the case when client is unable to handle turn and relay updates at different time
|
||||
if m.relayCfg != nil {
|
||||
token, err := m.GenerateRelayToken()
|
||||
if err == nil {
|
||||
update.NetbirdConfig.Relay = &proto.RelayConfig{
|
||||
Urls: m.relayCfg.Addresses,
|
||||
TokenPayload: token.Payload,
|
||||
TokenSignature: token.Signature,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
||||
relayToken, err := m.relayHmacToken.GenerateToken()
|
||||
if err != nil {
|
||||
log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err)
|
||||
return
|
||||
}
|
||||
|
||||
update := &proto.SyncResponse{
|
||||
NetbirdConfig: &proto.NetbirdConfig{
|
||||
Relay: &proto.RelayConfig{
|
||||
Urls: m.relayCfg.Addresses,
|
||||
TokenPayload: string(relayToken.Payload),
|
||||
TokenSignature: base64.StdEncoding.EncodeToString(relayToken.Signature),
|
||||
},
|
||||
// omit Turns to avoid updates there
|
||||
},
|
||||
}
|
||||
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
||||
extraSettings, err := m.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
|
||||
}
|
||||
|
||||
peerGroups, err := m.groupsManager.GetPeerGroupIDs(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peer groups: %v", err)
|
||||
}
|
||||
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, peerGroups, update.NetbirdConfig, extraSettings)
|
||||
update.NetbirdConfig = extendedConfig
|
||||
}
|
||||
247
management/internals/shared/grpc/token_mgr_test.go
Normal file
247
management/internals/shared/grpc/token_mgr_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"hash"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var TurnTestHost = &config.Host{
|
||||
Proto: config.UDP,
|
||||
URI: "turn:turn.netbird.io:77777",
|
||||
Username: "username",
|
||||
Password: "",
|
||||
}
|
||||
|
||||
func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
ttl := util.Duration{Duration: time.Hour}
|
||||
secret := "some_secret"
|
||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
||||
|
||||
rc := &config.Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*config.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
|
||||
turnCredentials, err := tested.GenerateTurnToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
if turnCredentials.Payload == "" {
|
||||
t.Errorf("expected generated TURN username not to be empty, got empty")
|
||||
}
|
||||
if turnCredentials.Signature == "" {
|
||||
t.Errorf("expected generated TURN password not to be empty, got empty")
|
||||
}
|
||||
|
||||
validateMAC(t, sha1.New, turnCredentials.Payload, turnCredentials.Signature, []byte(secret))
|
||||
|
||||
relayCredentials, err := tested.GenerateRelayToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
if relayCredentials.Payload == "" {
|
||||
t.Errorf("expected generated relay payload not to be empty, got empty")
|
||||
}
|
||||
if relayCredentials.Signature == "" {
|
||||
t.Errorf("expected generated relay signature not to be empty, got empty")
|
||||
}
|
||||
|
||||
hashedSecret := sha256.Sum256([]byte(secret))
|
||||
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
|
||||
}
|
||||
|
||||
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
ttl := util.Duration{Duration: 2 * time.Second}
|
||||
secret := "some_secret"
|
||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
||||
peer := "some_peer"
|
||||
updateChannel := peersManager.CreateChannel(context.Background(), peer)
|
||||
|
||||
rc := &config.Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*config.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tested.SetupRefresh(ctx, "someAccountID", peer)
|
||||
|
||||
if _, ok := tested.turnCancelMap[peer]; !ok {
|
||||
t.Errorf("expecting peer to be present in the turn cancel map, got not present")
|
||||
}
|
||||
|
||||
if _, ok := tested.relayCancelMap[peer]; !ok {
|
||||
t.Errorf("expecting peer to be present in the relay cancel map, got not present")
|
||||
}
|
||||
|
||||
var updates []*network_map.UpdateMessage
|
||||
|
||||
loop:
|
||||
for timeout := time.After(5 * time.Second); ; {
|
||||
select {
|
||||
case update := <-updateChannel:
|
||||
updates = append(updates, update)
|
||||
case <-timeout:
|
||||
break loop
|
||||
}
|
||||
|
||||
if len(updates) >= 2 {
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
if len(updates) < 2 {
|
||||
t.Errorf("expecting at least 2 peer credentials updates, got %v", len(updates))
|
||||
}
|
||||
|
||||
var turnUpdates, relayUpdates int
|
||||
var firstTurnUpdate, secondTurnUpdate *proto.ProtectedHostConfig
|
||||
var firstRelayUpdate, secondRelayUpdate *proto.RelayConfig
|
||||
|
||||
for _, update := range updates {
|
||||
if turns := update.Update.GetNetbirdConfig().GetTurns(); len(turns) > 0 {
|
||||
turnUpdates++
|
||||
if turnUpdates == 1 {
|
||||
firstTurnUpdate = turns[0]
|
||||
} else {
|
||||
secondTurnUpdate = turns[0]
|
||||
}
|
||||
}
|
||||
if relay := update.Update.GetNetbirdConfig().GetRelay(); relay != nil {
|
||||
// avoid updating on turn updates since they also send relay credentials
|
||||
if update.Update.GetNetbirdConfig().GetTurns() == nil {
|
||||
relayUpdates++
|
||||
if relayUpdates == 1 {
|
||||
firstRelayUpdate = relay
|
||||
} else {
|
||||
secondRelayUpdate = relay
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if turnUpdates < 1 {
|
||||
t.Errorf("expecting at least 1 TURN credential update, got %v", turnUpdates)
|
||||
}
|
||||
if relayUpdates < 1 {
|
||||
t.Errorf("expecting at least 1 relay credential update, got %v", relayUpdates)
|
||||
}
|
||||
|
||||
if firstTurnUpdate != nil && secondTurnUpdate != nil {
|
||||
if firstTurnUpdate.Password == secondTurnUpdate.Password {
|
||||
t.Errorf("expecting first TURN credential update password %v to be different from second, got equal", firstTurnUpdate.Password)
|
||||
}
|
||||
}
|
||||
|
||||
if firstRelayUpdate != nil && secondRelayUpdate != nil {
|
||||
if firstRelayUpdate.TokenSignature == secondRelayUpdate.TokenSignature {
|
||||
t.Errorf("expecting first relay credential update signature %v to be different from second, got equal", firstRelayUpdate.TokenSignature)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
||||
ttl := util.Duration{Duration: time.Hour}
|
||||
secret := "some_secret"
|
||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
||||
peer := "some_peer"
|
||||
|
||||
rc := &config.Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*config.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
|
||||
tested.SetupRefresh(context.Background(), "someAccountID", peer)
|
||||
if _, ok := tested.turnCancelMap[peer]; !ok {
|
||||
t.Errorf("expecting peer to be present in turn cancel map, got not present")
|
||||
}
|
||||
if _, ok := tested.relayCancelMap[peer]; !ok {
|
||||
t.Errorf("expecting peer to be present in relay cancel map, got not present")
|
||||
}
|
||||
|
||||
tested.CancelRefresh(peer)
|
||||
if _, ok := tested.turnCancelMap[peer]; ok {
|
||||
t.Errorf("expecting peer to be not present in turn cancel map, got present")
|
||||
}
|
||||
if _, ok := tested.relayCancelMap[peer]; ok {
|
||||
t.Errorf("expecting peer to be not present in relay cancel map, got present")
|
||||
}
|
||||
}
|
||||
|
||||
func validateMAC(t *testing.T, algo func() hash.Hash, username string, actualMAC string, key []byte) {
|
||||
t.Helper()
|
||||
mac := hmac.New(algo, key)
|
||||
|
||||
_, err := mac.Write([]byte(username))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedMAC := mac.Sum(nil)
|
||||
decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
equal := hmac.Equal(decodedMAC, expectedMAC)
|
||||
|
||||
if !equal {
|
||||
t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user