Compare commits

...

13 Commits

Author SHA1 Message Date
aliamerj
b7cd2ee252 just testing 2025-09-01 14:05:19 +03:00
Ali Amer
3f6d95552f implement remote debug api (#4418)
fix lint

clean up

fix MarkPendingJobsAsFailed

apply feedbacks 1

fix typo

change api and apply new schema

fix lint

fix api object

clean switch case

apply feedback 2

fix error handle in create job

get rid of any/interface type in job database

fix sonar issue

use RawJson for both parameters and results

running go mod tidy

update package

fix 1

update codegen

fix code-gen

fix snyk

fix snyk hopefully
2025-08-29 18:00:40 +03:00
Ali Amer
d4ac7f8df9 [management/client] create job channel between management and client (#4367)
* new bi-directional stream for jobs

* create bidirectional job channel to send requests from the server and receive responses from the client

* fix tests

* fix lint and close bug

* fix lint

* clean up & fix close of closed channel

* add nolint:staticcheck

* remove some redundant code from the job channel PR since this one is a cleaner rewrite

* cleanup removes a pending job safely

* change proto

* rename to jobRequest

* apply feedback 1

* apply feedback 2

* fix typo

* apply feedback 3

* apply last feedback
2025-08-28 16:49:09 +03:00
dependabot[bot]
9685411246 [misc] Bump golang.org/x/oauth2 from 0.24.0 to 0.27.0 (#4176)
Bumps [golang.org/x/oauth2](https://github.com/golang/oauth2) from 0.24.0 to 0.27.0
2025-08-19 16:26:46 +03:00
hakansa
d00a226556 [management] Add CreatedAt field to Peer and PeerBatch models (#4371)
[management] Add CreatedAt field to Peer and PeerBatch models (#4371)
2025-08-19 16:02:11 +03:00
Pascal Fischer
5d361b5421 [management] add nil handling for route domains (#4366) 2025-08-19 11:35:03 +02:00
dependabot[bot]
a889c4108b [misc] Bump github.com/containerd/containerd from 1.7.16 to 1.7.27 (#3527)
Bumps [github.com/containerd/containerd](https://github.com/containerd/containerd) from 1.7.16 to 1.7.27
2025-08-18 21:57:21 +03:00
Zoltan Papp
12cad854b2 [client] Fix/ice handshake (#4281)
In this PR, speed up the GRPC message processing, force the recreation of the ICE agent when getting a new, remote offer (do not wait for local STUN timeout).
2025-08-18 20:09:50 +02:00
Pascal Fischer
6a3846a8b7 [management] Remove save account calls (#4349) 2025-08-18 12:37:20 +02:00
Viktor Liu
7cd5dcae59 [client] Fix rule order for deny rules in peer ACLs (#4147) 2025-08-18 11:17:00 +02:00
Pascal Fischer
0e62325d46 [management] fail on geo location init failure (#4362) 2025-08-18 10:53:55 +02:00
Pascal Fischer
b3056d0937 [management] Use DI containers for server bootstrapping (#4343) 2025-08-15 17:14:48 +02:00
Zoltan Papp
ab853ac2a5 [server] Add MySQL initialization script and update Docker configuration (#4345) 2025-08-14 17:53:59 +02:00
72 changed files with 4873 additions and 1821 deletions

View File

@@ -83,6 +83,15 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup MySQL privileges
if: matrix.store == 'mysql'
run: |
sleep 10
mysql -h 127.0.0.1 -u root -pmysqlroot -e "
GRANT SYSTEM_VARIABLES_ADMIN ON *.* TO 'netbird'@'%';
FLUSH PRIVILEGES;
"
- name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/

View File

@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -35,7 +36,7 @@ import (
func startTestingServices(t *testing.T) string {
t.Helper()
config := &types.Config{}
config := &config.Config{}
_, err := util.ReadJson("../testdata/management.json", config)
if err != nil {
t.Fatal(err)
@@ -70,7 +71,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis
}
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
func startManagement(t *testing.T, config *config.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
@@ -85,6 +86,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
jobManager := mgmt.NewJobManager(nil, store)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
@@ -105,13 +107,13 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)
}

View File

@@ -85,7 +85,7 @@ func (m *aclManager) AddPeerFiltering(
) ([]firewall.Rule, error) {
chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPort, dPort)
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs)
@@ -135,7 +135,14 @@ func (m *aclManager) AddPeerFiltering(
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil {
// Insert DROP rules at the beginning, append ACCEPT rules at the end
if action == firewall.ActionDrop {
// Insert at the beginning of the chain (position 1)
err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...)
} else {
err = m.iptablesClient.Append(tableFilter, chain, specs...)
}
if err != nil {
return nil, err
}
@@ -388,17 +395,25 @@ func actionToStr(action firewall.Action) string {
return "DROP"
}
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string {
switch {
case ipsetName == "":
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string {
if ipsetName == "" {
return ""
}
// Include action in the ipset name to prevent squashing rules with different actions
actionSuffix := ""
if action == firewall.ActionDrop {
actionSuffix = "-drop"
}
switch {
case sPort != nil && dPort != nil:
return ipsetName + "-sport-dport"
return ipsetName + "-sport-dport" + actionSuffix
case sPort != nil:
return ipsetName + "-sport"
return ipsetName + "-sport" + actionSuffix
case dPort != nil:
return ipsetName + "-dport"
return ipsetName + "-dport" + actionSuffix
default:
return ipsetName
return ipsetName + actionSuffix
}
}

View File

@@ -3,6 +3,7 @@ package iptables
import (
"fmt"
"net/netip"
"strings"
"testing"
"time"
@@ -15,7 +16,7 @@ import (
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
return "wg-test"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
@@ -109,10 +110,84 @@ func TestIptablesManager(t *testing.T) {
})
}
func TestIptablesManagerDenyRules(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
manager, err := Create(ifaceMock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
defer func() {
err := manager.Close(nil)
require.NoError(t, err)
}()
t.Run("add deny rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{22}}
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh")
require.NoError(t, err, "failed to add deny rule")
require.NotEmpty(t, rule, "deny rule should not be empty")
// Verify the rule was added by checking iptables
for _, r := range rule {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
})
t.Run("deny rule precedence test", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.4")
port := &fw.Port{Values: []uint16{80}}
// Add accept rule first
_, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http")
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for same IP/port - this should take precedence
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http")
require.NoError(t, err, "failed to add deny rule")
// Inspect the actual iptables rules to verify deny rule comes before accept rule
rules, err := ipv4Client.List("filter", chainNameInputRules)
require.NoError(t, err, "failed to list iptables rules")
// Debug: print all rules
t.Logf("All iptables rules in chain %s:", chainNameInputRules)
for i, rule := range rules {
t.Logf(" [%d] %s", i, rule)
}
var denyRuleIndex, acceptRuleIndex int = -1, -1
for i, rule := range rules {
if strings.Contains(rule, "DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)
if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") {
denyRuleIndex = i
}
}
if strings.Contains(rule, "ACCEPT") {
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") {
acceptRuleIndex = i
}
}
}
require.NotEqual(t, -1, denyRuleIndex, "deny rule should exist in iptables")
require.NotEqual(t, -1, acceptRuleIndex, "accept rule should exist in iptables")
require.Less(t, denyRuleIndex, acceptRuleIndex,
"deny rule should come before accept rule in iptables chain (deny at index %d, accept at index %d)",
denyRuleIndex, acceptRuleIndex)
})
}
func TestIptablesManagerIPSet(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
return "wg-test"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
@@ -176,7 +251,7 @@ func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName strin
func TestIptablesCreatePerformance(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
return "wg-test"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{

View File

@@ -341,30 +341,38 @@ func (m *AclManager) addIOFiltering(
userData := []byte(ruleId)
chain := m.chainInputRules
nftRule := m.rConn.AddRule(&nftables.Rule{
rule := &nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: mainExpressions,
UserData: userData,
})
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end
var nftRule *nftables.Rule
if action == firewall.ActionDrop {
nftRule = m.rConn.InsertRule(rule)
} else {
nftRule = m.rConn.AddRule(rule)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
rule := &Rule{
ruleStruct := &Rule{
nftRule: nftRule,
mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset,
ruleID: ruleId,
ip: ip,
}
m.rules[ruleId] = rule
m.rules[ruleId] = ruleStruct
if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name)
}
return rule, nil
return ruleStruct, nil
}
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {

View File

@@ -2,6 +2,7 @@ package nftables
import (
"bytes"
"encoding/binary"
"fmt"
"net/netip"
"os/exec"
@@ -20,7 +21,7 @@ import (
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
return "wg-test"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
@@ -103,9 +104,8 @@ func TestNftablesManager(t *testing.T) {
Kind: expr.VerdictAccept,
},
}
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
expectedExprs2 := []expr.Any{
// Since DROP rules are inserted at position 0, the DROP rule comes first
expectedDropExprs := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
@@ -141,7 +141,12 @@ func TestNftablesManager(t *testing.T) {
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions")
// Compare DROP rule at position 0 (inserted first due to InsertRule)
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedDropExprs)
// Compare connection tracking rule at position 1 (pushed down by DROP rule insertion)
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
for _, r := range rule {
err = manager.DeletePeerRule(r)
@@ -160,10 +165,90 @@ func TestNftablesManager(t *testing.T) {
require.NoError(t, err, "failed to reset")
}
func TestNftablesManagerRuleOrder(t *testing.T) {
// This test verifies rule insertion order in nftables peer ACLs
// We add accept rule first, then deny rule to test ordering behavior
manager, err := Create(ifaceMock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
defer func() {
err = manager.Close(nil)
require.NoError(t, err)
}()
ip := netip.MustParseAddr("100.96.0.2").Unmap()
testClient := &nftables.Conn{}
// Add accept rule first
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http")
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for the same traffic
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http")
require.NoError(t, err, "failed to add deny rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
t.Logf("Found %d rules in nftables chain", len(rules))
// Find the accept and deny rules and verify deny comes before accept
var acceptRuleIndex, denyRuleIndex int = -1, -1
for i, rule := range rules {
hasAcceptHTTPSet := false
hasDenyHTTPSet := false
hasPort80 := false
var action string
for _, e := range rule.Exprs {
// Check for set lookup
if lookup, ok := e.(*expr.Lookup); ok {
if lookup.SetName == "accept-http" {
hasAcceptHTTPSet = true
} else if lookup.SetName == "deny-http" {
hasDenyHTTPSet = true
}
}
// Check for port 80
if cmp, ok := e.(*expr.Cmp); ok {
if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
hasPort80 = true
}
}
// Check for verdict
if verdict, ok := e.(*expr.Verdict); ok {
if verdict.Kind == expr.VerdictAccept {
action = "ACCEPT"
} else if verdict.Kind == expr.VerdictDrop {
action = "DROP"
}
}
}
if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" {
t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i)
acceptRuleIndex = i
} else if hasDenyHTTPSet && hasPort80 && action == "DROP" {
t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i)
denyRuleIndex = i
}
}
require.NotEqual(t, -1, acceptRuleIndex, "accept rule should exist in nftables")
require.NotEqual(t, -1, denyRuleIndex, "deny rule should exist in nftables")
require.Less(t, denyRuleIndex, acceptRuleIndex,
"deny rule should come before accept rule in nftables chain (deny at index %d, accept at index %d)",
denyRuleIndex, acceptRuleIndex)
}
func TestNFtablesCreatePerformance(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
return "wg-test"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{

View File

@@ -18,6 +18,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {

View File

@@ -27,6 +27,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {

View File

@@ -70,14 +70,13 @@ func (r RouteRules) Sort() {
// Manager userspace firewall manager
type Manager struct {
// outgoingRules is used for hooks only
outgoingRules map[netip.Addr]RuleSet
// incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
outgoingRules map[netip.Addr]RuleSet
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
mutex sync.RWMutex
@@ -186,6 +185,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
},
nativeFirewall: nativeFirewall,
outgoingRules: make(map[netip.Addr]RuleSet),
incomingDenyRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[netip.Addr]RuleSet),
wgIface: iface,
localipmanager: newLocalIPManager(),
@@ -417,10 +417,17 @@ func (m *Manager) AddPeerFiltering(
}
m.mutex.Lock()
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(RuleSet)
var targetMap map[netip.Addr]RuleSet
if r.drop {
targetMap = m.incomingDenyRules
} else {
targetMap = m.incomingRules
}
m.incomingRules[r.ip][r.id] = r
if _, ok := targetMap[r.ip]; !ok {
targetMap[r.ip] = make(RuleSet)
}
targetMap[r.ip][r.id] = r
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
}
@@ -507,10 +514,24 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
if _, ok := m.incomingRules[r.ip][r.id]; !ok {
var sourceMap map[netip.Addr]RuleSet
if r.drop {
sourceMap = m.incomingDenyRules
} else {
sourceMap = m.incomingRules
}
if ruleset, ok := sourceMap[r.ip]; ok {
if _, exists := ruleset[r.id]; !exists {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(ruleset, r.id)
if len(ruleset) == 0 {
delete(sourceMap, r.ip)
}
} else {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip], r.id)
return nil
}
@@ -572,7 +593,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil
}
// FilterOutBound filters outgoing packets
// FilterOutbound filters outgoing packets
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
return m.filterOutbound(packetData, size)
}
@@ -761,7 +782,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
if blocked {
_, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
@@ -971,26 +992,28 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
icmpType == layers.ICMPv4TypeTimeExceeded
}
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
if m.isSpecialICMP(d) {
return nil, false
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter
}
// Default policy: DROP ALL
return nil, true
}
@@ -1013,6 +1036,7 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules {
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
continue
@@ -1045,6 +1069,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true
}
}
return nil, false, false
}
@@ -1116,6 +1141,7 @@ func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook fu
m.mutex.Lock()
if in {
// Incoming UDP hooks are stored in allow rules map
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
}
@@ -1136,6 +1162,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// Check incoming hooks (stored in allow rules)
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
@@ -1144,6 +1171,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
}
}
}
// Check outgoing hooks
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {

View File

@@ -458,6 +458,31 @@ func TestPeerACLFiltering(t *testing.T) {
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Peer ACL - Drop rule should override accept all rule",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 22,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{22}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Peer ACL - Drop all traffic from specific IP",
srcIP: "100.10.0.99",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.99",
ruleProto: fw.ProtocolALL,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
}
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
@@ -468,13 +493,11 @@ func TestPeerACLFiltering(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
// add general accept rule to test drop rule
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
// add general accept rule for the same IP to test drop rule precedence
rules, err := manager.AddPeerFiltering(
nil,
net.ParseIP("0.0.0.0"),
net.ParseIP(tc.ruleIP),
fw.ProtocolALL,
nil,
nil,

View File

@@ -136,9 +136,22 @@ func TestManagerDeleteRule(t *testing.T) {
return
}
// Check rules exist in appropriate maps
for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
peerRule, ok := r.(*PeerRule)
if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule exists in deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
} else {
_, found = m.incomingRules[ip][r.ID()]
}
if !found {
t.Errorf("rule2 is not in the expected rules map")
}
}
@@ -150,9 +163,22 @@ func TestManagerDeleteRule(t *testing.T) {
}
}
// Check rules are removed from appropriate maps
for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
peerRule, ok := r.(*PeerRule)
if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule is removed from deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
} else {
_, found = m.incomingRules[ip][r.ID()]
}
if found {
t.Errorf("rule2 should be removed from the rules map")
}
}
}
@@ -196,16 +222,17 @@ func TestAddUDPPacketHook(t *testing.T) {
var addedRule PeerRule
if tt.in {
// Incoming UDP hooks are stored in allow rules map
if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
return
}
for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule
}
} else {
if len(manager.outgoingRules) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
if len(manager.outgoingRules[tt.ip]) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
return
}
for _, rule := range manager.outgoingRules[tt.ip] {
@@ -261,8 +288,8 @@ func TestManagerReset(t *testing.T) {
return
}
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
t.Errorf("rules is not empty")
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
t.Errorf("rules are not empty")
}
}

View File

@@ -314,7 +314,7 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
ruleId, blocked := m.peerACLsBlock(srcIP, d, packetData)
strRuleId := "<no id>"
if ruleId != nil {

View File

@@ -22,6 +22,8 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/proto"
nberrors "github.com/netbirdio/netbird/client/errors"
@@ -53,16 +55,19 @@ import (
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
nbssh "github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/shared/signal/client"
sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/upload-server/types"
"github.com/netbirdio/netbird/util"
"google.golang.org/grpc/status"
)
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -125,6 +130,8 @@ type EngineConfig struct {
BlockInbound bool
LazyConnectionEnabled bool
peerDaemonAddr string
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -254,6 +261,7 @@ func NewEngine(
}
engine.stateManager = statemanager.New(path)
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine
}
@@ -460,6 +468,7 @@ func (e *Engine) Start() error {
e.receiveSignalEvents()
e.receiveManagementEvents()
e.receiveJobEvents()
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
@@ -885,6 +894,101 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return nil
}
func (e *Engine) getPeerClient(addr string) (*grpc.ClientConn, error) {
conn, err := grpc.NewClient(
strings.TrimPrefix(addr, "tcp://"),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
return conn, nil
}
func (e *Engine) receiveJobEvents() {
go func() {
err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
// Simple test handler — replace with real logic
log.Infof("Received job request: %+v", msg)
// TODO: trigger local debug bundle or other job
return &mgmProto.JobResponse{
ID: msg.ID,
WorkloadResults: &mgmProto.JobResponse_Bundle{
Bundle: &mgmProto.BundleResult{
UploadKey: "upload-key",
},
},
}
})
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return
}
log.Debugf("stopped receiving jobs from Management Service")
}()
log.Debugf("connecting to Management Service jobs stream")
}
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (string, error) {
conn, err := e.getPeerClient("unix:///var/run/netbird.sock")
if err != nil {
return "", err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
}()
statusOutput, err := e.getStatusOutput(params.Anonymize)
if err != nil {
return "", err
}
request := &cProto.DebugBundleRequest{
Anonymize: params.Anonymize,
SystemInfo: true,
Status: statusOutput,
LogFileCount: uint32(params.LogFileCount),
UploadURL: types.DefaultBundleURL,
}
service := cProto.NewDaemonServiceClient(conn)
resp, err := service.DebugBundle(e.clientCtx, request)
if err != nil {
return "", fmt.Errorf("failed to bundle debug: " + status.Convert(err).Message())
}
if resp.GetUploadFailureReason() != "" {
return "", fmt.Errorf("upload failed: " + resp.GetUploadFailureReason())
}
return resp.GetUploadedKey(), nil
}
func (e *Engine) getStatusOutput(anon bool) (string, error) {
conn, err := e.getPeerClient("unix:///var/run/netbird.sock")
if err != nil {
return "", err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
}()
statusResp, err := cProto.NewDaemonServiceClient(conn).Status(e.clientCtx, &cProto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
if err != nil {
return "", fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
return nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
), nil
}
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
@@ -1330,52 +1434,17 @@ func (e *Engine) receiveSignalEvents() {
}
switch msg.GetBody().Type {
case sProto.Body_OFFER:
remoteCred, err := signal.UnMarshalCredential(msg)
case sProto.Body_OFFER, sProto.Body_ANSWER:
offerAnswer, err := convertToOfferAnswer(msg)
if err != nil {
return err
}
var rosenpassPubKey []byte
rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
if msg.Body.Type == sProto.Body_OFFER {
conn.OnRemoteOffer(*offerAnswer)
} else {
conn.OnRemoteAnswer(*offerAnswer)
}
conn.OnRemoteOffer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {
return err
}
var rosenpassPubKey []byte
rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
}
conn.OnRemoteAnswer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
if err != nil {
@@ -2073,3 +2142,44 @@ func createFile(path string) error {
}
return file.Close()
}
func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {
return nil, err
}
var (
rosenpassPubKey []byte
rosenpassAddr string
)
if cfg := msg.GetBody().GetRosenpassConfig(); cfg != nil {
rosenpassPubKey = cfg.GetRosenpassPubKey()
rosenpassAddr = cfg.GetRosenpassServerAddr()
}
// Handle optional SessionID
var sessionID *peer.ICESessionID
if sessionBytes := msg.GetBody().GetSessionId(); sessionBytes != nil {
if id, err := peer.ICESessionIDFromBytes(sessionBytes); err != nil {
log.Warnf("Invalid session ID in message: %v", err)
sessionID = nil // Set to nil if conversion fails
} else {
sessionID = &id
}
}
offerAnswer := peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
SessionID: sessionID,
}
return &offerAnswer, nil
}

View File

@@ -27,6 +27,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface"
@@ -44,8 +45,6 @@ import (
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -55,8 +54,10 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/monotime"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
@@ -1514,15 +1515,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &types.Config{
Stuns: []*types.Host{},
TURNConfig: &types.TURNConfig{},
Relay: &types.Relay{
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &types.Host{
Signal: &config.Host{
Proto: "http",
URI: "localhost:10000",
},
@@ -1543,6 +1544,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil)
jobManager := server.NewJobManager(nil, store)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
@@ -1567,13 +1569,13 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil {
return nil, "", err
}

View File

@@ -24,8 +24,8 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)
@@ -200,19 +200,11 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done(conn.ctx)
conn.dumpState.SendOffer()
if err := conn.handshaker.sendOffer(); err != nil {
conn.Log.Errorf("failed to send initial offer: %v", err)
}
conn.wg.Add(1)
go func() {
conn.guard.Start(conn.ctx, conn.onGuardEvent)
conn.wg.Done()
}()
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()
conn.opened = true
return nil
@@ -274,10 +266,10 @@ func (conn *Conn) Close(signalToRemote bool) {
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
conn.dumpState.RemoteAnswer()
conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteAnswer(answer)
conn.handshaker.OnRemoteAnswer(answer)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
@@ -296,10 +288,10 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
conn.onDisconnected = handler
}
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) {
conn.dumpState.RemoteOffer()
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteOffer(offer)
conn.handshaker.OnRemoteOffer(offer)
}
// WgConfig returns the WireGuard config
@@ -548,7 +540,6 @@ func (conn *Conn) onRelayDisconnected() {
}
func (conn *Conn) onGuardEvent() {
conn.Log.Debugf("send offer to peer")
conn.dumpState.SendOffer()
if err := conn.handshaker.SendOffer(); err != nil {
conn.Log.Errorf("failed to send offer: %v", err)
@@ -672,7 +663,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
}
}()
if conn.statusICE.Get() == worker.StatusDisconnected {
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}

View File

@@ -1,9 +1,9 @@
package peer
import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
@@ -79,31 +79,30 @@ func TestConn_OnRemoteOffer(t *testing.T) {
return
}
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
<-conn.handshaker.remoteOffersCh
wg.Done()
}()
onNewOffeChan := make(chan struct{})
go func() {
for {
accepted := conn.OnRemoteOffer(OfferAnswer{
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
if accepted {
wg.Done()
return
}
}
}()
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOffeChan <- struct{}{}
})
wg.Wait()
conn.OnRemoteOffer(OfferAnswer{
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
select {
case <-onNewOffeChan:
// success
case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out")
}
}
func TestConn_OnRemoteAnswer(t *testing.T) {
@@ -119,31 +118,29 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
return
}
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
<-conn.handshaker.remoteAnswerCh
wg.Done()
}()
onNewOffeChan := make(chan struct{})
go func() {
for {
accepted := conn.OnRemoteAnswer(OfferAnswer{
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
if accepted {
wg.Done()
return
}
}
}()
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOffeChan <- struct{}{}
})
wg.Wait()
conn.OnRemoteAnswer(OfferAnswer{
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
select {
case <-onNewOffeChan:
// success
case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out")
}
}
func TestConn_presharedKey(t *testing.T) {

View File

@@ -19,7 +19,6 @@ type isConnectedFunc func() bool
// - Relayed connection disconnected
// - ICE candidate changes
type Guard struct {
Reconnect chan struct{}
log *log.Entry
isConnectedOnAllWay isConnectedFunc
timeout time.Duration
@@ -30,7 +29,6 @@ type Guard struct {
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{
Reconnect: make(chan struct{}, 1),
log: log,
isConnectedOnAllWay: isConnectedFn,
timeout: timeout,
@@ -41,6 +39,7 @@ func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Durati
}
func (g *Guard) Start(ctx context.Context, eventCallback func()) {
g.log.Infof("starting guard for reconnection with MaxInterval: %s", g.timeout)
g.reconnectLoopWithRetry(ctx, eventCallback)
}
@@ -61,17 +60,14 @@ func (g *Guard) SetICEConnDisconnected() {
// reconnectLoopWithRetry periodically check the connection status.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
waitForInitialConnectionTry(ctx)
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
ticker := g.prepareExponentTicker(ctx)
ticker := g.initialTicker(ctx)
defer ticker.Stop()
tickerChannel := ticker.C
g.log.Infof("start reconnect loop...")
for {
select {
case t := <-tickerChannel:
@@ -85,7 +81,6 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
if !g.isConnectedOnAllWay() {
callback()
}
case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop()
@@ -111,6 +106,20 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
}
}
// initialTicker give chance to the peer to establish the initial connection.
func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 3 * time.Second,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: g.timeout,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
return backoff.NewTicker(bo)
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
@@ -126,13 +135,3 @@ func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
return ticker
}
// Give chance to the peer to establish the initial connection.
// With it, we can decrease to send necessary offer
func waitForInitialConnectionTry(ctx context.Context) {
select {
case <-ctx.Done():
return
case <-time.After(3 * time.Second):
}
}

View File

@@ -39,6 +39,15 @@ type OfferAnswer struct {
// relay server address
RelaySrvAddress string
// SessionID is the unique identifier of the session, used to discard old messages
SessionID *ICESessionID
}
func (oa *OfferAnswer) SessionIDString() string {
if oa.SessionID == nil {
return "unknown"
}
return oa.SessionID.String()
}
type Handshaker struct {
@@ -74,21 +83,25 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
func (h *Handshaker) Listen(ctx context.Context) {
for {
h.log.Info("wait for remote offer confirmation")
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation(ctx)
if err != nil {
var connectionClosedError *ConnectionClosedError
if errors.As(err, &connectionClosedError) {
h.log.Info("exit from handshaker")
return
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed
if err := h.sendAnswer(); err != nil {
h.log.Errorf("failed to send remote offer confirmation: %s", err)
continue
}
h.log.Errorf("failed to received remote offer confirmation: %s", err)
continue
}
h.log.Infof("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
for _, listener := range h.onNewOfferListeners {
go listener(remoteOfferAnswer)
for _, listener := range h.onNewOfferListeners {
listener(&remoteOfferAnswer)
}
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
for _, listener := range h.onNewOfferListeners {
listener(&remoteOfferAnswer)
}
case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers")
return
}
}
}
@@ -101,43 +114,27 @@ func (h *Handshaker) SendOffer() error {
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool {
func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) {
select {
case h.remoteOffersCh <- offer:
return true
return
default:
h.log.Warnf("OnRemoteOffer skipping message because is not ready")
h.log.Warnf("skipping remote offer message because receiver not ready")
// connection might not be ready yet to receive so we ignore the message
return false
return
}
}
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) {
select {
case h.remoteAnswerCh <- answer:
return true
return
default:
// connection might not be ready yet to receive so we ignore the message
h.log.Debugf("OnRemoteAnswer skipping message because is not ready")
return false
}
}
func (h *Handshaker) waitForRemoteOfferConfirmation(ctx context.Context) (*OfferAnswer, error) {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed
if err := h.sendAnswer(); err != nil {
return nil, err
}
return &remoteOfferAnswer, nil
case remoteOfferAnswer := <-h.remoteAnswerCh:
return &remoteOfferAnswer, nil
case <-ctx.Done():
// closed externally
return nil, NewConnectionClosedError(h.config.Key)
h.log.Warnf("skipping remote answer message because receiver not ready")
return
}
}
@@ -147,43 +144,34 @@ func (h *Handshaker) sendOffer() error {
return ErrSignalIsNotReady
}
iceUFrag, icePwd := h.ice.GetLocalUserCredentials()
offer := OfferAnswer{
IceCredentials: IceCredentials{iceUFrag, icePwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr,
}
addr, err := h.relay.RelayInstanceAddress()
if err == nil {
offer.RelaySrvAddress = addr
}
offer := h.buildOfferAnswer()
h.log.Infof("sending offer with serial: %s", offer.SessionIDString())
return h.signaler.SignalOffer(offer, h.config.Key)
}
func (h *Handshaker) sendAnswer() error {
h.log.Infof("sending answer")
uFrag, pwd := h.ice.GetLocalUserCredentials()
answer := h.buildOfferAnswer()
h.log.Infof("sending answer with serial: %s", answer.SessionIDString())
return h.signaler.SignalAnswer(answer, h.config.Key)
}
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr,
SessionID: &sid,
}
addr, err := h.relay.RelayInstanceAddress()
if err == nil {
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
answer.RelaySrvAddress = addr
}
err = h.signaler.SignalAnswer(answer, h.config.Key)
if err != nil {
return err
}
return nil
return answer
}

View File

@@ -0,0 +1,47 @@
package peer
import (
"crypto/rand"
"encoding/hex"
"fmt"
"io"
)
const sessionIDSize = 5
type ICESessionID string
// NewICESessionID generates a new session ID for distinguishing sessions
func NewICESessionID() (ICESessionID, error) {
b := make([]byte, sessionIDSize)
if _, err := io.ReadFull(rand.Reader, b); err != nil {
return "", fmt.Errorf("failed to generate session ID: %w", err)
}
return ICESessionID(hex.EncodeToString(b)), nil
}
func ICESessionIDFromBytes(b []byte) (ICESessionID, error) {
if len(b) != sessionIDSize {
return "", fmt.Errorf("invalid session ID length: %d", len(b))
}
return ICESessionID(hex.EncodeToString(b)), nil
}
// Bytes returns the raw bytes of the session ID for protobuf serialization
func (id ICESessionID) Bytes() ([]byte, error) {
if len(id) == 0 {
return nil, fmt.Errorf("ICE session ID is empty")
}
b, err := hex.DecodeString(string(id))
if err != nil {
return nil, fmt.Errorf("invalid ICE session ID encoding: %w", err)
}
if len(b) != sessionIDSize {
return nil, fmt.Errorf("invalid ICE session ID length: expected %d bytes, got %d", sessionIDSize, len(b))
}
return b, nil
}
func (id ICESessionID) String() string {
return string(id)
}

View File

@@ -2,6 +2,7 @@ package peer
import (
"github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
signal "github.com/netbirdio/netbird/shared/signal/client"
@@ -45,6 +46,10 @@ func (s *Signaler) Ready() bool {
// SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
}
msg, err := signal.MarshalCredential(
s.wgPrivateKey,
offerAnswer.WgListenPort,
@@ -56,13 +61,13 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
bodyType,
offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassAddr,
offerAnswer.RelaySrvAddress)
offerAnswer.RelaySrvAddress,
sessionIDBytes)
if err != nil {
return err
}
err = s.signal.Send(msg)
if err != nil {
if err = s.signal.Send(msg); err != nil {
return err
}

View File

@@ -42,8 +42,18 @@ type WorkerICE struct {
statusRecorder *Status
hasRelayOnLocally bool
agent *ice.Agent
muxAgent sync.Mutex
agent *ice.Agent
agentDialerCancel context.CancelFunc
agentConnecting bool // while it is true, drop all incoming offers
lastSuccess time.Time // with this avoid the too frequent ICE agent recreation
// remoteSessionID represents the peer's session identifier from the latest remote offer.
remoteSessionID ICESessionID
// sessionID is used to track the current session ID of the ICE agent
// increase by one when disconnecting the agent
// with it the remote peer can discard the already deprecated offer/answer
// Without it the remote peer may recreate a workable ICE connection
sessionID ICESessionID
muxAgent sync.Mutex
StunTurn []*stun.URI
@@ -57,6 +67,11 @@ type WorkerICE struct {
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
sessionID, err := NewICESessionID()
if err != nil {
return nil, err
}
w := &WorkerICE{
ctx: ctx,
log: log,
@@ -67,6 +82,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
statusRecorder: statusRecorder,
hasRelayOnLocally: hasRelayOnLocally,
lastKnownState: ice.ConnectionStateDisconnected,
sessionID: sessionID,
}
localUfrag, localPwd, err := icemaker.GenerateICECredentials()
@@ -79,15 +95,35 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
}
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE")
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Lock()
if w.agent != nil {
w.log.Debugf("agent already exists, skipping the offer")
if w.agentConnecting {
w.log.Debugf("agent connection is in progress, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.agent != nil {
// backward compatibility with old clients that do not send session ID
if remoteOfferAnswer.SessionID == nil {
w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Unlock()
return
}
w.log.Debugf("agent already exists, recreate the connection")
w.agentDialerCancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
// todo consider to switch to Relay connection while establishing a new ICE connection
}
var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
preferredCandidateTypes = icemaker.CandidateTypesP2P()
@@ -96,36 +132,124 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
}
w.log.Debugf("recreate ICE agent")
agentCtx, agentCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(agentCancel, preferredCandidateTypes)
dialerCtx, dialerCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock()
return
}
w.sentExtraSrflx = false
w.agent = agent
w.agentDialerCancel = dialerCancel
w.agentConnecting = true
w.muxAgent.Unlock()
w.log.Debugf("gather candidates")
err = w.agent.GatherCandidates()
if err != nil {
w.log.Debugf("failed to gather candidates: %s", err)
go w.connect(dialerCtx, agent, remoteOfferAnswer)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String())
if w.agent == nil {
w.log.Warnf("ICE Agent is not initialized yet")
return
}
if candidateViaRoutes(candidate, haRoutes) {
return
}
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
}
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
return w.localUfrag, w.localPwd
}
func (w *WorkerICE) InProgress() bool {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.agentConnecting
}
func (w *WorkerICE) Close() {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent == nil {
return
}
w.agentDialerCancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) {
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
}
if err := agent.OnCandidate(w.onICECandidate); err != nil {
return nil, err
}
if err := agent.OnConnectionStateChange(w.onConnectionStateChange(agent, dialerCancel)); err != nil {
return nil, err
}
if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil {
return nil, err
}
if err := agent.OnSuccessfulSelectedPairBindingResponse(w.onSuccessfulSelectedPairBindingResponse); err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil
}
func (w *WorkerICE) SessionID() ICESessionID {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.sessionID
}
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
func (w *WorkerICE) connect(ctx context.Context, agent *ice.Agent, remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("gather candidates")
if err := agent.GatherCandidates(); err != nil {
w.log.Warnf("failed to gather candidates: %s", err)
w.closeAgent(agent, w.agentDialerCancel)
return
}
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
w.log.Debugf("turn agent dial")
remoteConn, err := w.turnAgentDial(agentCtx, remoteOfferAnswer)
remoteConn, err := w.turnAgentDial(ctx, remoteOfferAnswer)
if err != nil {
w.log.Debugf("failed to dial the remote peer: %s", err)
w.closeAgent(agent, w.agentDialerCancel)
return
}
w.log.Debugf("agent dial succeeded")
pair, err := w.agent.GetSelectedCandidatePair()
pair, err := agent.GetSelectedCandidatePair()
if err != nil {
w.closeAgent(agent, w.agentDialerCancel)
return
}
@@ -152,114 +276,38 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RelayedOnLocal: isRelayCandidate(pair.Local),
}
w.log.Debugf("on ICE conn is ready to use")
go w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
w.log.Infof("connection succeeded with offer session: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String())
if w.agent == nil {
w.log.Warnf("ICE Agent is not initialized yet")
return
w.agentConnecting = false
w.lastSuccess = time.Now()
if remoteOfferAnswer.SessionID != nil {
w.remoteSessionID = *remoteOfferAnswer.SessionID
}
w.muxAgent.Unlock()
if candidateViaRoutes(candidate, haRoutes) {
return
}
err := w.agent.AddRemoteCandidate(candidate)
if err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
// todo: the potential problem is a race between the onConnectionStateChange
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
}
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.localUfrag, w.localPwd
}
func (w *WorkerICE) Close() {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent == nil {
return
}
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) {
w.sentExtraSrflx = false
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
}
err = agent.OnCandidate(w.onICECandidate)
if err != nil {
return nil, err
}
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
}
w.closeAgent(agentCancel)
default:
return
}
})
if err != nil {
return nil, err
}
err = agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair)
if err != nil {
return nil, err
}
err = agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
err := w.statusRecorder.UpdateLatency(w.config.Key, p.Latency())
if err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
})
if err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil
}
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
func (w *WorkerICE) closeAgent(agent *ice.Agent, cancel context.CancelFunc) {
cancel()
if w.agent == nil {
return
}
if err := w.agent.Close(); err != nil {
if err := agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
w.muxAgent.Lock()
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
if w.agent == agent {
w.agent = nil
w.agentConnecting = false
}
w.muxAgent.Unlock()
}
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
@@ -331,6 +379,32 @@ func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidat
w.config.Key)
}
func (w *WorkerICE) onConnectionStateChange(agent *ice.Agent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
return func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
}
w.closeAgent(agent, dialerCancel)
default:
return
}
}
}
func (w *WorkerICE) onSuccessfulSelectedPairBindingResponse(pair *ice.CandidatePair) {
if err := w.statusRecorder.UpdateLatency(w.config.Key, pair.Latency()); err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
}
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true

View File

@@ -14,6 +14,7 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus"
@@ -32,7 +33,6 @@ import (
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
@@ -267,10 +267,10 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Helper()
dataDir := t.TempDir()
config := &types.Config{
Stuns: []*types.Host{},
TURNConfig: &types.TURNConfig{},
Signal: &types.Host{
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Signal: &config.Host{
Proto: "http",
URI: signalAddr,
},
@@ -290,6 +290,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil)
jobManager := server.NewJobManager(nil, store)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
@@ -305,13 +306,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil {
return nil, "", err
}

7
go.mod
View File

@@ -65,6 +65,7 @@ require (
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -102,7 +103,7 @@ require (
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/net v0.39.0
golang.org/x/oauth2 v0.24.0
golang.org/x/oauth2 v0.27.0
golang.org/x/sync v0.13.0
golang.org/x/term v0.31.0
google.golang.org/api v0.177.0
@@ -125,6 +126,7 @@ require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
@@ -144,7 +146,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/containerd v1.7.26 // indirect
github.com/containerd/containerd v1.7.27 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
@@ -220,6 +222,7 @@ require (
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/rymdport/portal v0.3.0 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect

20
go.sum
View File

@@ -66,11 +66,14 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/Microsoft/hcsshim v0.12.3 h1:LS9NXqXhMoqNCplK1ApmVSfB4UnVLRDWRapB6EIlxE0=
github.com/Microsoft/hcsshim v0.12.3/go.mod h1:Iyl1WVpZzr+UkzjekHZbV8o5Z9ZkxNGx6CtY2Qg/JVQ=
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
@@ -116,6 +119,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqOes/6LfM=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -142,8 +146,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/containerd v1.7.26 h1:3cs8K2RHlMQaPifLqgRyI4VBkoldNdEw62cb7qQga7k=
github.com/containerd/containerd v1.7.26/go.mod h1:m4JU0E+h0ebbo9yXD7Hyt+sWnc8tChm7MudCjj4jRvQ=
github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
@@ -416,6 +420,7 @@ github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/X
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e h1:LvL4XsI70QxOGHed6yhQtAU34Kx3Qq2wwBzGFKY8zKk=
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
@@ -516,6 +521,8 @@ github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7g
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/oapi-codegen/runtime v1.1.2 h1:P2+CubHq8fO4Q6fV1tqDBZHCwpVpvPg7oKiYzQgXIyI=
github.com/oapi-codegen/runtime v1.1.2/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg=
github.com/okta/okta-sdk-golang/v2 v2.18.0 h1:cfDasMb7CShbZvOrF6n+DnLevWwiHgedWMGJ8M8xKDc=
github.com/okta/okta-sdk-golang/v2 v2.18.0/go.mod h1:dz30v3ctAiMb7jpsCngGfQUAEGm1/NsWT92uTbNDQIs=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
@@ -588,8 +595,8 @@ github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=
github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM=
github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
@@ -627,6 +634,7 @@ github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.8.1/go.mod h1:o0Pch8wJ9BVSWGQMbra6iw0oQ5oktSIBaujf1rJH9Ns=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=
@@ -868,8 +876,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ
golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M=
golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@@ -2,88 +2,40 @@ package cmd
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"io/fs"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"os/signal"
"path"
"slices"
"strings"
"time"
"syscall"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now.
const ManagementLegacyPort = 33073
var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server {
return server.NewServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
}
func SetNewServer(fn func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server) {
newServer = fn
}
var (
mgmtPort int
mgmtMetricsPort int
mgmtLetsencryptDomain string
mgmtSingleAccModeDomain string
certFile string
certKey string
config *types.Config
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
config *nbconfig.Config
mgmtCmd = &cobra.Command{
Use: "management",
@@ -102,9 +54,9 @@ var (
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(ctx, types.MgmtConfigPath)
config, err = loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", types.MgmtConfigPath, err)
return fmt.Errorf("failed reading provided config file: %s: %v", nbconfig.MgmtConfigPath, err)
}
if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed {
@@ -151,356 +103,38 @@ var (
return fmt.Errorf("failed creating datadir: %s: %v", config.Datadir, err)
}
}
appMetrics, err := telemetry.NewDefaultAppMetrics(cmd.Context())
if err != nil {
return err
}
err = appMetrics.Expose(ctx, mgmtMetricsPort, "/metrics")
if err != nil {
return err
}
integrationMetrics, err := integrations.InitIntegrationMetrics(ctx, appMetrics)
if err != nil {
return err
}
store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics, false)
if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager(appMetrics)
var idpManager idp.Manager
if config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(ctx, *config.IdpManagerConfig, appMetrics)
if err != nil {
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
}
}
if disableSingleAccMode {
mgmtSingleAccModeDomain = ""
}
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
return fmt.Errorf("initialize database: %s", err)
}
if config.DataStoreEncryptionKey != key {
log.WithContext(ctx).Infof("update config with activity store key")
config.DataStoreEncryptionKey = key
err := updateMgmtConfig(ctx, types.MgmtConfigPath, config)
srv := newServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
go func() {
if err := srv.Start(cmd.Context()); err != nil {
log.Fatalf("Server error: %v", err)
}
}()
stopChan := make(chan os.Signal, 1)
signal.Notify(stopChan, os.Interrupt, syscall.SIGTERM)
select {
case <-stopChan:
log.Info("Received shutdown signal, stopping server...")
err = srv.Stop()
if err != nil {
return fmt.Errorf("write out store encryption key: %s", err)
log.Errorf("Failed to stop server gracefully: %v", err)
}
case err := <-srv.Errors():
log.Fatalf("Server stopped unexpectedly: %v", err)
}
geo, err := geolocation.NewGeolocation(ctx, config.Datadir, !disableGeoliteUpdate)
if err != nil {
log.WithContext(ctx).Warnf("could not initialize geolocation service. proceeding without geolocation support: %v", err)
} else {
log.WithContext(ctx).Infof("geolocation service has been initialized from %s", config.Datadir)
}
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
if err != nil {
return fmt.Errorf("initialize integrated peer validator: %v", err)
}
permissionsManager := integrations.InitPermissionsManager(store)
userManager := users.NewManager(store)
extraSettingsManager := integrations.NewManager(eventStore)
settingsManager := settings.NewManager(store, userManager, extraSettingsManager, permissionsManager)
peersManager := peers.NewManager(store, permissionsManager)
proxyController := integrations.NewController(store)
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
if err != nil {
return fmt.Errorf("build default manager: %v", err)
}
groupsManager := groups.NewManager(store, permissionsManager, accountManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager, groupsManager)
trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.WithContext(ctx).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers
}
trustedHTTPProxies := config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.WithContext(ctx).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
}
realipOpts := []realip.Option{
realip.WithTrustedPeers(trustedPeers),
realip.WithTrustedProxies(trustedHTTPProxies),
realip.WithTrustedProxiesCount(trustedProxiesCount),
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
}
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
var certManager *autocert.Manager
var tlsConfig *tls.Config
tlsEnabled := false
if config.HttpConfig.LetsEncryptDomain != "" {
certManager, err = encryption.CreateCertManager(config.Datadir, config.HttpConfig.LetsEncryptDomain)
if err != nil {
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
tlsEnabled = true
} else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey)
if err != nil {
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
return err
}
transportCredentials := credentials.NewTLS(tlsConfig)
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
tlsEnabled = true
}
authManager := auth.NewManager(store,
config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience,
config.HttpConfig.AuthKeysLocation,
config.HttpConfig.AuthUserIDClaim,
config.GetAuthAudiences(),
config.HttpConfig.IdpSignKeyRefreshEnabled)
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager)
routersManager := routers.NewManager(store, permissionsManager, accountManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, authManager, appMetrics, integratedPeerValidator, proxyController, permissionsManager, peersManager, settingsManager)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
ephemeralManager := server.NewEphemeralManager(store, accountManager)
ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
installationID, err := getInstallationID(ctx, store)
if err != nil {
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
return err
}
if !disableMetrics {
idpManager := "disabled"
if config.IdpManagerConfig != nil && config.IdpManagerConfig.ManagerType != "" {
idpManager = config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager)
go metricsWorker.Run(ctx)
}
var compatListener net.Listener
if mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
compatListener, err = serveGRPC(ctx, gRPCAPIHandler, ManagementLegacyPort)
if err != nil {
return err
}
log.WithContext(ctx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := handlerFunc(gRPCAPIHandler, httpAPIHandler)
var listener net.Listener
if certManager != nil {
// a call to certManager.Listener() always creates a new listener so we do it once
cml := certManager.Listener()
if mgmtPort == 443 {
// CertManager, HTTP and gRPC API all on the same port
rootHandler = certManager.HTTPHandler(rootHandler)
listener = cml
} else {
listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), certManager.TLSConfig())
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err)
}
log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
serveHTTP(ctx, cml, certManager.HTTPHandler(nil))
}
} else if tlsConfig != nil {
listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), tlsConfig)
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err)
}
} else {
listener, err = net.Listen("tcp", fmt.Sprintf(":%d", mgmtPort))
if err != nil {
return fmt.Errorf("failed creating TCP listener on port %d: %v", mgmtPort, err)
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
update := version.NewUpdate("nb/management")
update.SetDaemonVersion(version.NetbirdVersion())
update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
})
defer update.StopWatch()
SetupCloseHandler()
<-stopCh
integratedPeerValidator.Stop(ctx)
if geo != nil {
_ = geo.Stop()
}
ephemeralManager.Stop()
_ = appMetrics.Close()
_ = listener.Close()
if certManager != nil {
_ = certManager.Listener().Close()
}
gRPCAPIHandler.Stop()
_ = store.Close(ctx)
_ = eventStore.Close(ctx)
log.WithContext(ctx).Infof("stopped Management Service")
return nil
},
}
)
func unaryInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := uuid.New().String()
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource)
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(ctx, req)
}
func streamInterceptor(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := uuid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource)
//nolint
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(srv, wrapped)
}
func notifyStop(ctx context.Context, msg string) {
select {
case stopCh <- 1:
log.WithContext(ctx).Error(msg)
default:
// stop has been already called, nothing to report
}
}
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {
return installationID, nil
}
installationID = strings.ToUpper(uuid.New().String())
err := store.SaveInstallationID(ctx, installationID)
if err != nil {
return "", err
}
return installationID, nil
}
func serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
}
go func() {
err := grpcServer.Serve(listener)
if err != nil {
notifyStop(ctx, fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
}
}()
return listener, nil
}
func serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
go func() {
err := http.Serve(httpListener, handler)
if err != nil {
notifyStop(ctx, fmt.Sprintf("failed running HTTP server: %v", err))
}
}()
}
func serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
go func() {
var err error
if tlsEnabled {
err = http.Serve(listener, handler)
} else {
// the following magic is needed to support HTTP2 without TLS
// and still share a single port between gRPC and HTTP APIs
h1s := &http.Server{
Handler: h2c.NewHandler(handler, &http2.Server{}),
}
err = h1s.Serve(listener)
}
if err != nil {
select {
case stopCh <- 1:
log.WithContext(ctx).Errorf("failed to serve HTTP and gRPC server: %v", err)
default:
// stop has been already called, nothing to report
}
}
}()
}
func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")
if request.ProtoMajor == 2 && grpcHeader {
gRPCHandler.ServeHTTP(writer, request)
} else {
httpHandler.ServeHTTP(writer, request)
}
})
}
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config, error) {
loadedConfig := &types.Config{}
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
loadedConfig := &nbconfig.Config{}
_, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
if err != nil {
return nil, err
@@ -535,7 +169,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config,
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(types.NONE)) {
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE)) {
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
@@ -552,7 +186,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config,
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = types.DefaultDeviceAuthFlowScope
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = nbconfig.DefaultDeviceAuthFlowScope
}
}
@@ -573,10 +207,6 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config,
return loadedConfig, err
}
func updateMgmtConfig(ctx context.Context, path string, config *types.Config) error {
return util.DirectWriteJson(ctx, path, config)
}
// OIDCConfigResponse used for parsing OIDC config response
type OIDCConfigResponse struct {
Issuer string `json:"issuer"`
@@ -619,25 +249,6 @@ func fetchOIDCConfig(ctx context.Context, oidcEndpoint string) (OIDCConfigRespon
return config, nil
}
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
// Load server's certificate and private key
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
if err != nil {
return nil, err
}
// NewDefaultAppMetrics the credentials and return it
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert,
NextProtos: []string{
"h2", "http/1.1", // enable HTTP/2
},
}
return config, nil
}
func handleRebrand(cmd *cobra.Command) error {
var err error
if logFile == defaultLogFile {
@@ -649,7 +260,7 @@ func handleRebrand(cmd *cobra.Command) error {
}
}
}
if types.MgmtConfigPath == defaultMgmtConfig {
if nbconfig.MgmtConfigPath == defaultMgmtConfig {
if migrateToNetbird(oldDefaultMgmtConfig, defaultMgmtConfig) {
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultMgmtConfigDir, defaultMgmtConfigDir)
err = cpDir(oldDefaultMgmtConfigDir, defaultMgmtConfigDir)

View File

@@ -2,12 +2,10 @@ package cmd
import (
"fmt"
"os"
"os/signal"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/management/server/types"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/version"
)
@@ -27,6 +25,12 @@ var (
disableGeoliteUpdate bool
idpSignKeyRefreshEnabled bool
userDeleteFromIDPEnabled bool
mgmtPort int
mgmtMetricsPort int
mgmtLetsencryptDomain string
mgmtSingleAccModeDomain string
certFile string
certKey string
rootCmd = &cobra.Command{
Use: "netbird-mgmt",
@@ -42,8 +46,6 @@ var (
Long: "",
SilenceUsage: true,
}
// Execution control channel for stopCh signal
stopCh chan int
)
// Execute executes the root command.
@@ -52,11 +54,10 @@ func Execute() error {
}
func init() {
stopCh = make(chan int)
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&types.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain)
mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.")
@@ -80,15 +81,3 @@ func init() {
rootCmd.AddCommand(migrationCmd)
}
// SetupCloseHandler handles SIGTERM signal and exits with success
func SetupCloseHandler() {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
go func() {
for range c {
fmt.Println("\r- Ctrl+C pressed in Terminal")
stopCh <- 0
}
}()
}

View File

@@ -0,0 +1,204 @@
package server
// @note this file includes all the lower level dependencies, db, http and grpc BaseServer, metrics, logger, etc.
import (
"context"
"crypto/tls"
"net/http"
"net/netip"
"slices"
"time"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
func (s *BaseServer) Metrics() telemetry.AppMetrics {
return Create(s, func() telemetry.AppMetrics {
appMetrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
log.Fatalf("error while creating app metrics: %s", err)
}
return appMetrics
})
}
func (s *BaseServer) Store() store.Store {
return Create(s, func() store.Store {
store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false)
if err != nil {
log.Fatalf("failed to create store: %v", err)
}
return store
})
}
func (s *BaseServer) EventStore() activity.Store {
return Create(s, func() activity.Store {
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
if err != nil {
log.Fatalf("failed to initialize integration metrics: %v", err)
}
eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
log.Fatalf("failed to initialize event store: %v", err)
}
if s.config.DataStoreEncryptionKey != key {
log.WithContext(context.Background()).Infof("update config with activity store key")
s.config.DataStoreEncryptionKey = key
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config)
if err != nil {
log.Fatalf("failed to update config with activity store: %v", err)
}
}
return eventStore
})
}
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
return httpAPIHandler
})
}
func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server {
trustedPeers := s.config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers
}
trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
}
realipOpts := []realip.Option{
realip.WithTrustedPeers(trustedPeers),
realip.WithTrustedProxies(trustedHTTPProxies),
realip.WithTrustedProxiesCount(trustedProxiesCount),
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
}
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
if s.config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
if err != nil {
log.Fatalf("failed to create certificate manager: %v", err)
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
if err != nil {
log.Fatalf("cannot load TLS credentials: %v", err)
}
transportCredentials := credentials.NewTLS(tlsConfig)
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator())
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
return gRPCAPIHandler
})
}
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
// Load server's certificate and private key
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
if err != nil {
return nil, err
}
// NewDefaultAppMetrics the credentials and return it
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert,
NextProtos: []string{
"h2", "http/1.1", // enable HTTP/2
},
}
return config, nil
}
func unaryInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := uuid.New().String()
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource)
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(ctx, req)
}
func streamInterceptor(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := uuid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource)
//nolint
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(srv, wrapped)
}

View File

@@ -1,10 +1,11 @@
package types
package config
import (
"net/netip"
"github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/util"
)
@@ -166,7 +167,7 @@ type ProviderConfig struct {
// StoreConfig contains Store configuration
type StoreConfig struct {
Engine Engine
Engine types.Engine
}
// ReverseProxy contains reverse proxy configuration in front of management.

View File

@@ -0,0 +1,55 @@
package server
import "fmt"
// Create a dependency and add it to the BaseServer's container. A string key identifier will be based on its type definition.
func Create[T any](s Server, createFunc func() T) T {
result, _ := maybeCreate(s, createFunc)
return result
}
// CreateNamed is the same as Create but will suffix the dependency string key identifier with a custom name.
// Useful if you want to have multiple named instances of the same object type.
func CreateNamed[T any](s Server, name string, createFunc func() T) T {
result, _ := maybeCreateNamed(s, name, createFunc)
return result
}
// Inject lets you override a specific service from outside the BaseServer itself.
// This is useful for tests
func Inject[T any](c Server, thing T) {
_, _ = maybeCreate(c, func() T {
return thing
})
}
// InjectNamed is like Inject() but with a custom name.
func InjectNamed[T any](c Server, name string, thing T) {
_, _ = maybeCreateKeyed(c, name, func() T {
return thing
})
}
func maybeCreate[T any](s Server, createFunc func() T) (result T, isNew bool) {
key := fmt.Sprintf("%T", (*T)(nil))[1:]
return maybeCreateKeyed(s, key, createFunc)
}
func maybeCreateNamed[T any](s Server, name string, createFunc func() T) (result T, isNew bool) {
key := fmt.Sprintf("%T:%s", (*T)(nil), name)[1:]
return maybeCreateKeyed(s, key, createFunc)
}
func maybeCreateKeyed[T any](s Server, key string, createFunc func() T) (result T, isNew bool) {
if t, ok := s.GetContainer(key); ok {
return t.(T), false
}
t := createFunc()
s.SetContainer(key, t)
return t, true
}

View File

@@ -0,0 +1,65 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
)
func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
return Create(s, func() *server.PeersUpdateManager {
return server.NewPeersUpdateManager(s.Metrics())
})
}
func (s *BaseServer) JobManager() *server.JobManager {
return Create(s, func() *server.JobManager {
return server.NewJobManager(s.Metrics(), s.Store())
})
}
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator {
integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore())
if err != nil {
log.Errorf("failed to create integrated peer validator: %v", err)
}
return integratedPeerValidator
})
}
func (s *BaseServer) ProxyController() port_forwarding.Controller {
return Create(s, func() port_forwarding.Controller {
return integrations.NewController(s.Store())
})
}
func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager {
return Create(s, func() *server.TimeBasedAuthSecretsManager {
return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
})
}
func (s *BaseServer) AuthManager() auth.Manager {
return Create(s, func() auth.Manager {
return auth.NewManager(s.Store(),
s.config.HttpConfig.AuthIssuer,
s.config.HttpConfig.AuthAudience,
s.config.HttpConfig.AuthKeysLocation,
s.config.HttpConfig.AuthUserIDClaim,
s.config.GetAuthAudiences(),
s.config.HttpConfig.IdpSignKeyRefreshEnabled)
})
}
func (s *BaseServer) EphemeralManager() *server.EphemeralManager {
return Create(s, func() *server.EphemeralManager {
return server.NewEphemeralManager(s.Store(), s.AccountManager())
})
}

View File

@@ -0,0 +1,108 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users"
)
func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
return Create(s, func() geolocation.Geolocation {
geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate)
if err != nil {
log.Fatalf("could not initialize geolocation service: %v", err)
}
log.Infof("geolocation service has been initialized from %s", s.config.Datadir)
return geo
})
}
func (s *BaseServer) PermissionsManager() permissions.Manager {
return Create(s, func() permissions.Manager {
return integrations.InitPermissionsManager(s.Store())
})
}
func (s *BaseServer) UsersManager() users.Manager {
return Create(s, func() users.Manager {
return users.NewManager(s.Store())
})
}
func (s *BaseServer) SettingsManager() settings.Manager {
return Create(s, func() settings.Manager {
extraSettingsManager := integrations.NewManager(s.EventStore())
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager())
})
}
func (s *BaseServer) PeersManager() peers.Manager {
return Create(s, func() peers.Manager {
return peers.NewManager(s.Store(), s.PermissionsManager())
})
}
func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain,
s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}
return accountManager
})
}
func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
if s.config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics())
if err != nil {
log.Fatalf("failed to create IDP manager: %v", err)
}
}
return idpManager
})
}
func (s *BaseServer) GroupsManager() groups.Manager {
return Create(s, func() groups.Manager {
return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
})
}
func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager())
})
}
func (s *BaseServer) RoutesManager() routers.Manager {
return Create(s, func() routers.Manager {
return routers.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
})
}
func (s *BaseServer) NetworksManager() networks.Manager {
return Create(s, func() networks.Manager {
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
})
}

View File

@@ -0,0 +1,341 @@
package server
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/encryption"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now.
const ManagementLegacyPort = 33073
type Server interface {
Start(ctx context.Context) error
Stop() error
Errors() <-chan error
GetContainer(key string) (any, bool)
SetContainer(key string, container any)
}
// Server holds the HTTP BaseServer instance.
// Add any additional fields you need, such as database connections, config, etc.
type BaseServer struct {
// config holds the server configuration
config *nbconfig.Config
// container of dependencies, each dependency is identified by a unique string.
container map[string]any
// AfterInit is a function that will be called after the server is initialized
afterInit []func(s *BaseServer)
disableMetrics bool
dnsDomain string
disableGeoliteUpdate bool
userDeleteFromIDPEnabled bool
mgmtSingleAccModeDomain string
mgmtMetricsPort int
mgmtPort int
listener net.Listener
certManager *autocert.Manager
update *version.Update
errCh chan error
wg sync.WaitGroup
cancel context.CancelFunc
}
// NewServer initializes and configures a new Server instance
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
return &BaseServer{
config: config,
container: make(map[string]any),
dnsDomain: dnsDomain,
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
disableMetrics: disableMetrics,
disableGeoliteUpdate: disableGeoliteUpdate,
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
mgmtPort: mgmtPort,
mgmtMetricsPort: mgmtMetricsPort,
}
}
func (s *BaseServer) AfterInit(fn func(s *BaseServer)) {
s.afterInit = append(s.afterInit, fn)
}
// Start begins listening for HTTP requests on the configured address
func (s *BaseServer) Start(ctx context.Context) error {
srvCtx, cancel := context.WithCancel(ctx)
s.cancel = cancel
s.errCh = make(chan error, 4)
s.PeersManager()
s.GeoLocationManager()
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics")
if err != nil {
return fmt.Errorf("failed to expose metrics: %v", err)
}
s.EphemeralManager().LoadInitialPeers(srvCtx)
var tlsConfig *tls.Config
tlsEnabled := false
if s.config.HttpConfig.LetsEncryptDomain != "" {
s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
if err != nil {
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
}
tlsEnabled = true
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
if err != nil {
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
return err
}
tlsEnabled = true
}
installationID, err := getInstallationID(srvCtx, s.Store())
if err != nil {
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
return err
}
if !s.disableMetrics {
idpManager := "disabled"
if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" {
idpManager = s.config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager)
go metricsWorker.Run(srvCtx)
}
var compatListener net.Listener
if s.mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
compatListener, err = s.serveGRPC(srvCtx, s.GRPCServer(), ManagementLegacyPort)
if err != nil {
return err
}
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := handlerFunc(s.GRPCServer(), s.APIHandler())
switch {
case s.certManager != nil:
// a call to certManager.Listener() always creates a new listener so we do it once
cml := s.certManager.Listener()
if s.mgmtPort == 443 {
// CertManager, HTTP and gRPC API all on the same port
rootHandler = s.certManager.HTTPHandler(rootHandler)
s.listener = cml
} else {
s.listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", s.mgmtPort), s.certManager.TLSConfig())
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", s.mgmtPort, err)
}
log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
s.serveHTTP(ctx, cml, s.certManager.HTTPHandler(nil))
}
case tlsConfig != nil:
s.listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", s.mgmtPort), tlsConfig)
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", s.mgmtPort, err)
}
default:
s.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", s.mgmtPort))
if err != nil {
return fmt.Errorf("failed creating TCP listener on port %d: %v", s.mgmtPort, err)
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
s.update = version.NewUpdate("nb/management")
s.update.SetDaemonVersion(version.NetbirdVersion())
s.update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
})
return nil
}
// Stop attempts a graceful shutdown, waiting up to 5 seconds for active connections to finish
func (s *BaseServer) Stop() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.IntegratedValidator().Stop(ctx)
if s.GeoLocationManager() != nil {
_ = s.GeoLocationManager().Stop()
}
s.EphemeralManager().Stop()
_ = s.Metrics().Close()
if s.listener != nil {
_ = s.listener.Close()
}
if s.certManager != nil {
_ = s.certManager.Listener().Close()
}
s.GRPCServer().Stop()
_ = s.Store().Close(ctx)
_ = s.EventStore().Close(ctx)
if s.update != nil {
s.update.StopWatch()
}
select {
case <-s.Errors():
log.WithContext(ctx).Infof("stopped Management Service")
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Done returns a channel that is closed when the server stops
func (s *BaseServer) Errors() <-chan error {
return s.errCh
}
// GetContainer retrieves a dependency from the BaseServer's container by its key
func (s *BaseServer) GetContainer(key string) (any, bool) {
container, exists := s.container[key]
return container, exists
}
// SetContainer stores a dependency in the BaseServer's container with the specified key
func (s *BaseServer) SetContainer(key string, container any) {
if _, exists := s.container[key]; exists {
log.Tracef("container with key %s already exists", key)
return
}
s.container[key] = container
log.Tracef("container with key %s set successfully", key)
}
func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) error {
return util.DirectWriteJson(ctx, path, config)
}
func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")
if request.ProtoMajor == 2 && grpcHeader {
gRPCHandler.ServeHTTP(writer, request)
} else {
httpHandler.ServeHTTP(writer, request)
}
})
}
func (s *BaseServer) serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
err := grpcServer.Serve(listener)
if ctx.Err() != nil {
return
}
select {
case s.errCh <- err:
default:
}
}()
return listener, nil
}
func (s *BaseServer) serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
err := http.Serve(httpListener, handler)
if ctx.Err() != nil {
return
}
select {
case s.errCh <- err:
default:
}
}()
}
func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
var err error
if tlsEnabled {
err = http.Serve(listener, handler)
} else {
// the following magic is needed to support HTTP2 without TLS
// and still share a single port between gRPC and HTTP APIs
h1s := &http.Server{
Handler: h2c.NewHandler(handler, &http2.Server{}),
}
err = h1s.Serve(listener)
}
if ctx.Err() != nil {
return
}
select {
case s.errCh <- err:
default:
}
}()
}
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {
return installationID, nil
}
installationID = strings.ToUpper(uuid.New().String())
err := store.SaveInstallationID(ctx, installationID)
if err != nil {
return "", err
}
return installationID, nil
}

View File

@@ -68,6 +68,7 @@ type DefaultAccountManager struct {
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
cacheLoading map[string]chan struct{}
peersUpdateManager *PeersUpdateManager
jobManager *JobManager
idpManager idp.Manager
cacheManager *nbcache.AccountUserDataCache
externalCacheManager nbcache.UserDataCache
@@ -174,6 +175,7 @@ func BuildManager(
ctx context.Context,
store store.Store,
peersUpdateManager *PeersUpdateManager,
jobManager *JobManager,
idpManager idp.Manager,
singleAccountModeDomain string,
dnsDomain string,
@@ -196,6 +198,7 @@ func BuildManager(
Store: store,
geo: geo,
peersUpdateManager: peersUpdateManager,
jobManager: jobManager,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},
@@ -1952,20 +1955,19 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain")
}
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
var account *types.Account
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
account, err = transaction.GetAccount(ctx, accountId)
ok, domain, err := transaction.IsPrimaryAccount(ctx, accountId)
if err != nil {
return err
}
if account.IsDomainPrimaryAccount {
if ok {
return nil
}
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain)
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
// error is not a not found error
if handleNotFound(err) != nil {
@@ -1981,9 +1983,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return status.Errorf(status.Internal, "cannot update account to primary")
}
account.IsDomainPrimaryAccount = true
if err := transaction.SaveAccount(ctx, account); err != nil {
if err := transaction.MarkAccountPrimary(ctx, accountId); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
}).Errorf("failed to update account to primary: %v", err)
@@ -1993,10 +1993,10 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return nil
})
if err != nil {
return nil, err
return err
}
return account, nil
return nil
}
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
@@ -2067,14 +2067,12 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()),
}
account, err := transaction.GetAccount(ctx, accountID)
err := transaction.UpdateAccountNetwork(ctx, accountID, newIPNet)
if err != nil {
return err
}
account.Network.Net = newIPNet
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthUpdate, accountID, "", "")
if err != nil {
return err
}
@@ -2094,10 +2092,6 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
takenIPs = append(takenIPs, newIP)
}
if err = transaction.SaveAccount(ctx, account); err != nil {
return err
}
for _, peer := range peers {
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)

View File

@@ -7,7 +7,6 @@ import (
"time"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context"
@@ -18,6 +17,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
type ExternalCacheManager nbcache.UserDataCache
@@ -120,7 +120,10 @@ type Manager interface {
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
GetStore() store.Store
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
}

View File

@@ -2891,7 +2891,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), NewJobManager(nil, store), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, err
}
@@ -3250,11 +3250,13 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
assert.Equal(t, 0, len(account2.Users))
assert.Equal(t, 0, len(account2.SetupKeys))
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
account, err = manager.Store.GetAccount(ctx, account.Id)
assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount)
_, err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
assert.Error(t, err, "should not be able to update a second account to primary")
}
@@ -3275,7 +3277,9 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
assert.False(t, account.IsDomainPrimaryAccount)
assert.Equal(t, domain, account.Domain)
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
account, err = manager.Store.GetAccount(ctx, account.Id)
assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount)

View File

@@ -178,6 +178,8 @@ const (
AccountNetworkRangeUpdated Activity = 87
PeerIPUpdated Activity = 88
JobCreatedByUser Activity = 89
AccountDeleted Activity = 99999
)
@@ -284,6 +286,8 @@ var activityMap = map[Activity]Code{
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
JobCreatedByUser: {"Create Job for peer", "peer.job.create"},
}
// StringCode returns a string code of the activity

View File

@@ -219,7 +219,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), NewJobManager(nil, store), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createDNSStore(t *testing.T) (store.Store, error) {

View File

@@ -2,7 +2,9 @@ package server
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/netip"
"strings"
@@ -19,6 +21,7 @@ import (
"google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/store"
@@ -44,7 +47,8 @@ type GRPCServer struct {
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager
config *types.Config
jobManager *JobManager
config *nbconfig.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
@@ -56,10 +60,11 @@ type GRPCServer struct {
// NewServer creates a new Management server
func NewServer(
ctx context.Context,
config *types.Config,
config *nbconfig.Config,
accountManager account.Manager,
settingsManager settings.Manager,
peersUpdateManager *PeersUpdateManager,
jobManager *JobManager,
secretsManager SecretsManager,
appMetrics telemetry.AppMetrics,
ephemeralManager *EphemeralManager,
@@ -73,10 +78,9 @@ func NewServer(
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
if err := appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(len(peersUpdateManager.peerChannels))
})
if err != nil {
}); err != nil {
return nil, err
}
}
@@ -85,6 +89,7 @@ func NewServer(
wgKey: key,
// peerKey -> event channel
peersUpdateManager: peersUpdateManager,
jobManager: jobManager,
accountManager: accountManager,
settingsManager: settingsManager,
config: config,
@@ -131,6 +136,46 @@ func getRealIP(ctx context.Context) net.IP {
return nil
}
func (s *GRPCServer) Job(srv proto.ManagementService_JobServer) error {
reqStart := time.Now()
ctx := srv.Context()
peerKey, err := s.handleHandshake(ctx, srv)
if err != nil {
return err
}
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
return status.Errorf(codes.PermissionDenied, "peer is not registered")
}
return err
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
if err != nil {
return status.Errorf(codes.Unauthenticated, "peer is not registered")
}
// Start background response handler
s.startResponseReceiver(ctx, srv)
// Prepare per-peer state
updates, err := s.jobManager.CreateJobChannel(ctx, accountID, peer.ID)
if err != nil {
return status.Errorf(codes.Internal, err.Error())
}
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
// Main loop: forward jobs to client
return s.sendJobsLoop(ctx, accountID, peerKey, peer, updates, srv)
}
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
@@ -146,7 +191,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
if err != nil {
return err
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
@@ -170,7 +214,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
@@ -190,10 +233,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
return err
}
// Prepare per-peer state
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
s.ephemeralManager.OnPeerConnected(ctx, peer)
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
if s.appMetrics != nil {
@@ -208,6 +250,76 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
func (s *GRPCServer) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
hello, err := srv.Recv()
if err != nil {
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "missing hello: %v", err)
}
jobReq := &proto.JobRequest{}
peerKey, err := s.parseRequest(ctx, hello, jobReq)
if err != nil {
return wgtypes.Key{}, err
}
return peerKey, nil
}
func (s *GRPCServer) startResponseReceiver(ctx context.Context, srv proto.ManagementService_JobServer) {
go func() {
for {
msg, err := srv.Recv()
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
return
}
log.WithContext(ctx).Warnf("recv job response error: %v", err)
return
}
jobResp := &proto.JobResponse{}
if _, err := s.parseRequest(ctx, msg, jobResp); err != nil {
log.WithContext(ctx).Warnf("invalid job response: %v", err)
continue
}
if err := s.jobManager.HandleResponse(ctx, jobResp); err != nil {
log.WithContext(ctx).Errorf("handle job response failed: %v", err)
}
}
}()
}
func (s *GRPCServer) sendJobsLoop(
ctx context.Context,
accountID string,
peerKey wgtypes.Key,
peer *nbpeer.Peer,
updates <-chan *JobEvent,
srv proto.ManagementService_JobServer,
) error {
for {
select {
case job, open := <-updates:
if !open {
log.WithContext(ctx).Debugf("jobs channel for peer %s was closed", peerKey.String())
s.jobManager.CloseChannel(ctx, accountID, peer.ID)
return nil
}
if err := s.sendJob(ctx, accountID, peerKey, peer, job, srv); err != nil {
s.jobManager.CloseChannel(ctx, accountID, peer.ID)
log.WithContext(ctx).Warnf("send job failed: %v", err)
}
case <-ctx.Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.jobManager.CloseChannel(ctx, accountID, peer.ID)
return ctx.Err()
}
}
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
@@ -225,7 +337,6 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
return err
}
@@ -248,7 +359,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
@@ -260,6 +371,27 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
return nil
}
// sendJob encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendJob(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, job *JobEvent, srv proto.ManagementService_JobServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, job.Request)
if err != nil {
log.WithContext(ctx).Errorf("failed to encrypt job for peer %s: %v", peerKey.String(), err)
s.jobManager.CloseChannel(ctx, accountID, peer.ID)
return nil
}
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
s.jobManager.CloseChannel(ctx, accountID, peer.ID)
return status.Errorf(codes.Internal, "failed sending job message")
}
log.WithContext(ctx).Debugf("sent an job to peer %s", peerKey.String())
return nil
}
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
@@ -567,24 +699,24 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
return userID, nil
}
func ToResponseProto(configProto types.Protocol) proto.HostConfig_Protocol {
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
switch configProto {
case types.UDP:
case nbconfig.UDP:
return proto.HostConfig_UDP
case types.DTLS:
case nbconfig.DTLS:
return proto.HostConfig_DTLS
case types.HTTP:
case nbconfig.HTTP:
return proto.HostConfig_HTTP
case types.HTTPS:
case nbconfig.HTTPS:
return proto.HostConfig_HTTPS
case types.TCP:
case nbconfig.TCP:
return proto.HostConfig_TCP
default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
}
}
func toNetbirdConfig(config *types.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil {
return nil
}
@@ -662,7 +794,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
}
}
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
NetworkMap: &proto.NetworkMap{
@@ -728,8 +860,8 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
var err error
var turnToken *Token
if s.config.TURNConfig != nil && s.config.TURNConfig.TimeBasedCredentials {
turnToken, err = s.secretsManager.GenerateTurnToken()
if err != nil {
@@ -799,7 +931,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
return nil, status.Error(codes.InvalidArgument, errMSG)
}
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(types.NONE) {
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(nbconfig.NONE) {
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
}

View File

@@ -14,11 +14,11 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
)
// Handler is a handler that returns peers of the account
@@ -32,6 +32,10 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS")
}
// NewHandler creates a new peers Handler
@@ -41,6 +45,99 @@ func NewHandler(accountManager account.Manager) *Handler {
}
}
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
req := &api.JobRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req)
if err != nil {
util.WriteError(ctx, err, w)
return
}
if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
util.WriteError(ctx, err, w)
return
}
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(ctx, w, resp)
}
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID)
if err != nil {
util.WriteError(ctx, err, w)
return
}
respBody := make([]*api.JobResponse, 0, len(jobs))
for _, job := range jobs {
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
return
}
respBody = append(respBody, resp)
}
util.WriteJSONObject(ctx, w, respBody)
}
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
jobID := vars["jobId"]
job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID)
if err != nil {
util.WriteError(ctx, err, w)
return
}
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(ctx, w, resp)
}
func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
peerToReturn := peer.Copy()
if peer.Status.Connected {
@@ -354,6 +451,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
}
return &api.Peer{
CreatedAt: peer.CreatedAt,
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
@@ -390,6 +488,7 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
}
return &api.PeerBatch{
CreatedAt: peer.CreatedAt,
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
@@ -419,6 +518,28 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
}
}
func toSingleJobResponse(job *types.Job) (*api.JobResponse, error) {
workload, err := job.BuildWorkloadResponse()
if err != nil {
return nil, err
}
var failed *string
if job.FailedReason != "" {
failed = &job.FailedReason
}
return &api.JobResponse{
Id: job.ID,
CreatedAt: job.CreatedAt,
CompletedAt: job.CompletedAt,
TriggeredBy: job.TriggeredBy,
Status: api.JobResponseStatus(job.Status),
FailedReason: failed,
Workload: *workload,
}, nil
}
func fqdn(peer *nbpeer.Peer, dnsDomain string) string {
fqdn := peer.FQDN(dnsDomain)
if fqdn == "" {

View File

@@ -119,6 +119,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
jobManager := server.NewJobManager(nil, store)
updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId)
done := make(chan struct{})
if validateUpdate {
@@ -138,7 +139,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}

View File

@@ -50,23 +50,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context,
defer unlock()
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
a, err := transaction.GetAccount(ctx, accountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return err
}
var extra *types.ExtraSettings
if a.Settings.Extra != nil {
extra = a.Settings.Extra
if settings.Extra != nil {
extra = settings.Extra
} else {
extra = &types.ExtraSettings{}
a.Settings.Extra = extra
settings.Extra = extra
}
extra.IntegratedValidator = validator
extra.IntegratedValidatorGroups = groups
return transaction.SaveAccount(ctx, a)
return transaction.SaveAccountSettings(ctx, accountID, settings)
})
}

View File

@@ -0,0 +1,223 @@
package server
import (
"context"
"fmt"
"sync"
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
log "github.com/sirupsen/logrus"
)
const jobChannelBuffer = 100
type JobEvent struct {
Job *types.Job
Request *proto.JobRequest
Response *proto.JobResponse
Done chan struct{} // closed when response arrives
StoreEvent func(meta map[string]any, peer *nbpeer.Peer)
}
type JobManager struct {
mu *sync.RWMutex
jobChannels map[string]chan *JobEvent // per-peer job streams
pending map[string]*JobEvent // jobID → event
responseWait time.Duration
metrics telemetry.AppMetrics
Store store.Store
}
func NewJobManager(metrics telemetry.AppMetrics, store store.Store) *JobManager {
return &JobManager{
jobChannels: make(map[string]chan *JobEvent),
pending: make(map[string]*JobEvent),
responseWait: 5 * time.Minute,
metrics: metrics,
mu: &sync.RWMutex{},
Store: store,
}
}
// CreateJobChannel creates or replaces a channel for a peer
func (jm *JobManager) CreateJobChannel(ctx context.Context, accountID, peerID string) (chan *JobEvent, error) {
// all pending jobs stored in db for this peer should be failed
if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, peerID, "Pending job cleanup: marked as failed automatically due to being stuck too long"); err != nil {
return nil, err
}
jm.mu.Lock()
defer jm.mu.Unlock()
if ch, ok := jm.jobChannels[peerID]; ok {
close(ch)
delete(jm.jobChannels, peerID)
}
ch := make(chan *JobEvent, jobChannelBuffer)
jm.jobChannels[peerID] = ch
return ch, nil
}
// SendJob sends a job to a peer and tracks it as pending
func (jm *JobManager) SendJob(ctx context.Context, job *types.Job, storeEvent func(meta map[string]any, peer *nbpeer.Peer)) error {
jm.mu.RLock()
ch, ok := jm.jobChannels[job.PeerID]
jm.mu.RUnlock()
if !ok {
return fmt.Errorf("peer %s has no channel", job.PeerID)
}
req, err := job.ToStreamJobRequest()
if err != nil {
return err
}
event := &JobEvent{
Job: job,
Request: req,
Done: make(chan struct{}),
StoreEvent: storeEvent,
}
jm.mu.Lock()
jm.pending[string(req.ID)] = event
jm.mu.Unlock()
select {
case ch <- event:
case <-time.After(5 * time.Second):
jm.cleanup(ctx, string(req.ID), "timed out")
return fmt.Errorf("job channel full for peer %s", job.PeerID)
}
select {
case <-event.Done:
return nil
case <-time.After(jm.responseWait):
jm.cleanup(ctx, string(req.ID), "timed out")
return fmt.Errorf("job %s timed out", req.ID)
case <-ctx.Done():
jm.cleanup(ctx, string(req.ID), ctx.Err().Error())
return ctx.Err()
}
}
// HandleResponse marks a job as finished and moves it to completed
func (jm *JobManager) HandleResponse(ctx context.Context, resp *proto.JobResponse) error {
jm.mu.Lock()
defer jm.mu.Unlock()
event, ok := jm.pending[string(resp.ID)]
if !ok {
return fmt.Errorf("job %s not found", resp.ID)
}
fmt.Printf("we got this %+v\n", resp)
//update or create the store for job response
err := jm.saveJob(ctx, event.Job, resp, event.StoreEvent)
if err == nil {
event.Response = resp
}
close(event.Done)
delete(jm.pending, string(resp.ID))
return err
}
// CloseChannel closes a peers channel and cleans up its jobs
func (jm *JobManager) CloseChannel(ctx context.Context, accountID, peerID string) {
jm.mu.Lock()
defer jm.mu.Unlock()
if ch, ok := jm.jobChannels[peerID]; ok {
close(ch)
jm.jobChannels[peerID] = nil
delete(jm.jobChannels, peerID)
}
for jobID, ev := range jm.pending {
if ev.Job.PeerID == peerID {
// if the client disconnect and there is pending job then marke it as failed
if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, peerID, "Time out"); err != nil {
log.WithContext(ctx).Errorf(err.Error())
}
delete(jm.pending, jobID)
}
}
}
// cleanup removes a pending job safely
func (jm *JobManager) cleanup(ctx context.Context, jobID string, reason string) {
jm.mu.Lock()
defer jm.mu.Unlock()
if ev, ok := jm.pending[jobID]; ok {
close(ev.Done)
if err := jm.Store.MarkPendingJobsAsFailed(ctx, ev.Job.AccountID, ev.Job.PeerID, reason); err != nil {
log.WithContext(ctx).Errorf(err.Error())
}
delete(jm.pending, jobID)
}
}
func (jm *JobManager) IsPeerConnected(peerID string) bool {
jm.mu.RLock()
defer jm.mu.RUnlock()
_, ok := jm.jobChannels[peerID]
return ok
}
func (jm *JobManager) IsPeerHasPendingJobs(peerID string) bool {
jm.mu.RLock()
defer jm.mu.RUnlock()
for _, ev := range jm.pending {
if ev.Job.PeerID == peerID {
return true
}
}
return false
}
func (jm *JobManager) saveJob(ctx context.Context, job *types.Job, response *proto.JobResponse, StoreEvent func(meta map[string]any, peer *nbpeer.Peer)) error {
var peer *nbpeer.Peer
var err error
var eventsToStore func()
// persist job in DB only if send succeeded
err = jm.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, job.AccountID, job.PeerID)
if err != nil {
return err
}
if err := transaction.CreateOrUpdatePeerJob(ctx, job, response); err != nil {
return err
}
jobMeta := map[string]any{
"job_id": job.ID,
"for_peer_id": job.PeerID,
"job_type": job.Workload.Type,
"job_status": job.Status,
"job_workload": job.Workload,
}
eventsToStore = func() {
StoreEvent(jobMeta, peer)
}
return nil
})
if err != nil {
return err
}
eventsToStore()
return nil
}

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -95,21 +96,21 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
func Test_SyncProtocol(t *testing.T) {
dir := t.TempDir()
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*types.Host{{
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &config.Config{
Stuns: []*config.Host{{
Proto: "udp",
URI: "stun:stun.netbird.io:3468",
}},
TURNConfig: &types.TURNConfig{
TURNConfig: &config.TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*types.Host{{
Turns: []*config.Host{{
Proto: "udp",
URI: "turn:stun.netbird.io:3468",
}},
},
Signal: &types.Host{
Signal: &config.Host{
Proto: "http",
URI: "signal.netbird.io:10000",
},
@@ -332,7 +333,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testCases := []struct {
name string
inputFlow *types.DeviceAuthorizationFlow
inputFlow *config.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string
@@ -347,9 +348,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
},
{
name: "Testing Invalid Device Flow Provider Config",
inputFlow: &types.DeviceAuthorizationFlow{
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "NoNe",
ProviderConfig: types.ProviderConfig{
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
@@ -358,9 +359,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
},
{
name: "Testing Full Device Flow Config",
inputFlow: &types.DeviceAuthorizationFlow{
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "hosted",
ProviderConfig: types.ProviderConfig{
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
@@ -381,7 +382,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &GRPCServer{
wgKey: testingServerKey,
config: &types.Config{
config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
@@ -412,7 +413,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}
}
func startManagementForTest(t *testing.T, testFile string, config *types.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
@@ -426,6 +427,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
}
peersUpdateManager := NewPeersUpdateManager(nil)
jobManager := NewJobManager(nil, store)
eventStore := &activity.InMemoryEventStore{}
ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
@@ -449,7 +451,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
accountManager, err := BuildManager(ctx, store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
@@ -460,7 +462,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
if err != nil {
return nil, nil, "", cleanup, err
}
@@ -515,21 +517,21 @@ func testSyncStatusRace(t *testing.T) {
t.Skip()
dir := t.TempDir()
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*types.Host{{
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &config.Config{
Stuns: []*config.Host{{
Proto: "udp",
URI: "stun:stun.netbird.io:3468",
}},
TURNConfig: &types.TURNConfig{
TURNConfig: &config.TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*types.Host{{
Turns: []*config.Host{{
Proto: "udp",
URI: "turn:stun.netbird.io:3468",
}},
},
Signal: &types.Host{
Signal: &config.Host{
Proto: "http",
URI: "signal.netbird.io:10000",
},
@@ -687,21 +689,21 @@ func Test_LoginPerformance(t *testing.T) {
t.Helper()
dir := t.TempDir()
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*types.Host{{
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &config.Config{
Stuns: []*config.Host{{
Proto: "udp",
URI: "stun:stun.netbird.io:3468",
}},
TURNConfig: &types.TURNConfig{
TURNConfig: &config.TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*types.Host{{
Turns: []*config.Host{{
Proto: "udp",
URI: "turn:stun.netbird.io:3468",
}},
},
Signal: &types.Host{
Signal: &config.Host{
Proto: "http",
URI: "signal.netbird.io:10000",
},

View File

@@ -20,7 +20,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
@@ -30,6 +30,7 @@ import (
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
@@ -60,7 +61,7 @@ func setupTest(t *testing.T) *testSuite {
t.Fatalf("failed to create temp directory: %v", err)
}
config := &types.Config{}
config := &config.Config{}
_, err = util.ReadJson("testdata/management.json", config)
if err != nil {
t.Fatalf("failed to read management.json: %v", err)
@@ -158,7 +159,7 @@ func createRawClient(t *testing.T, addr string) (mgmtProto.ManagementServiceClie
func startServer(
t *testing.T,
config *types.Config,
config *config.Config,
dataDir string,
testFile string,
) (*grpc.Server, net.Listener) {
@@ -175,6 +176,7 @@ func startServer(
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
jobManager := server.NewJobManager(nil, str)
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -201,6 +203,7 @@ func startServer(
context.Background(),
str,
peersUpdateManager,
jobManager,
nil,
"",
"netbird.selfhosted",
@@ -225,6 +228,7 @@ func startServer(
accountManager,
settingsMockManager,
peersUpdateManager,
jobManager,
secretsManager,
nil,
nil,

View File

@@ -10,7 +10,6 @@ import (
"google.golang.org/grpc/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
@@ -21,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
var _ account.Manager = (*MockAccountManager)(nil)
@@ -114,7 +114,7 @@ type MockAccountManager struct {
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
GetStoreFunc func() store.Store
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
@@ -123,8 +123,32 @@ type MockAccountManager struct {
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
CreatePeerJobFunc func(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobsFunc func(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByIDFunc func(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
}
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
if am.CreatePeerJobFunc != nil {
return am.CreatePeerJobFunc(ctx, accountID, peerID, userID, job)
}
return status.Errorf(codes.Unimplemented, "method CreateJob is not implemented")
}
func (am *MockAccountManager) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) {
if am.CreatePeerJobFunc != nil {
return am.GetAllPeerJobsFunc(ctx, accountID, userID, peerID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAllJobs is not implemented")
}
func (am *MockAccountManager) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) {
if am.CreatePeerJobFunc != nil {
return am.GetPeerJobByIDFunc(ctx, accountID, userID, peerID, jobID)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateJob is not implemented")
}
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
if am.SaveGroupFunc != nil {
return am.SaveGroupFunc(ctx, accountID, userID, group, true)
@@ -933,11 +957,11 @@ func (am *MockAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Cont
return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented")
}
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error {
if am.UpdateToPrimaryAccountFunc != nil {
return am.UpdateToPrimaryAccountFunc(ctx, accountId)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented")
return status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented")
}
func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) {

View File

@@ -785,7 +785,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes()
permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), NewJobManager(nil, store), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createNSStore(t *testing.T) (store.Store, error) {

View File

@@ -333,6 +333,98 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return peer, nil
}
func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
// todo: Create permissions for job
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil {
return err
}
if peerAccountID != accountID {
return status.NewPeerNotPartOfAccountError()
}
// check if peer connected
if !am.jobManager.IsPeerConnected(peerID) {
return status.Errorf(status.BadRequest, "peer not connected")
}
// check if already has pending jobs
if am.jobManager.IsPeerHasPendingJobs(peerID) {
return status.Errorf(status.BadRequest, "peer already hase pending job")
}
// try sending job first
if err := am.jobManager.SendJob(ctx, job, func(meta map[string]any, peer *nbpeer.Peer) {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.JobCreatedByUser, meta)
}); err != nil {
return status.Errorf(status.Internal, "failed to send job: %v", err)
}
return nil
}
func (am *DefaultAccountManager) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) {
// todo: Create permissions for job
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil {
return nil, err
}
if peerAccountID != accountID {
return []*types.Job{}, nil
}
accountJobs, err := am.Store.GetPeerJobs(accountID, peerID)
if err != nil {
return nil, err
}
return accountJobs, nil
}
func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) {
// todo: Create permissions for job
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil {
return nil, err
}
if peerAccountID != accountID {
return &types.Job{}, nil
}
job, err := am.Store.GetPeerJobByID(accountID, jobID)
if err != nil {
return nil, err
}
return job, nil
}
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)

View File

@@ -25,6 +25,7 @@ import (
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
@@ -1063,16 +1064,16 @@ func TestToSyncResponse(t *testing.T) {
t.Fatal(err)
}
config := &types.Config{
Signal: &types.Host{
config := &config.Config{
Signal: &config.Host{
Proto: "https",
URI: "signal.uri",
Username: "",
Password: "",
},
Stuns: []*types.Host{{URI: "stun.uri", Proto: types.UDP}},
TURNConfig: &types.TURNConfig{
Turns: []*types.Host{{URI: "turn.uri", Proto: types.UDP, Username: "turn-user", Password: "turn-pass"}},
Stuns: []*config.Host{{URI: "stun.uri", Proto: config.UDP}},
TURNConfig: &config.TURNConfig{
Turns: []*config.Host{{URI: "turn.uri", Proto: config.UDP, Username: "turn-user", Password: "turn-pass"}},
},
}
peer := &nbpeer.Peer{
@@ -1273,7 +1274,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), NewJobManager(nil, s), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1353,7 +1354,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
AnyTimes()
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), NewJobManager(nil, s), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1501,7 +1502,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), NewJobManager(nil, s), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1575,7 +1576,7 @@ func Test_LoginPeer(t *testing.T) {
AnyTimes()
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), NewJobManager(nil, s), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"

View File

@@ -290,6 +290,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return transaction.DeleteRoute(ctx, accountID, string(routeID))
})
if err != nil {
return fmt.Errorf("failed to delete route %s: %w", routeID, err)
}
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())

View File

@@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), NewJobManager(nil, store), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createRouterStore(t *testing.T) (store.Store, error) {

View File

@@ -34,18 +34,20 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
mysqlKeyQueryCondition = "`key` = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
mysqlKeyQueryCondition = "`key` = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
accountAndPeerIDQueryCondition = "account_id = ? and peer_id = ?"
accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
)
// SqlStore represents an account storage backed by a Sql DB persisted to disk
@@ -106,6 +108,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -124,6 +127,71 @@ func GetKeyQueryCondition(s *SqlStore) string {
return keyQueryCondition
}
// SaveJob persists a job in DB
func (s *SqlStore) CreateOrUpdatePeerJob(ctx context.Context, job *types.Job, jobResponse *proto.JobResponse) error {
if err := job.ApplyResponse(jobResponse); err != nil {
return status.Errorf(status.Internal, err.Error())
}
fmt.Printf("new job or update %v\n", job)
result := s.db.
Model(&types.Job{}).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "id"}},
DoUpdates: clause.AssignmentColumns([]string{"completed_at", "status", "failed_reason", "workload_workload_type", "workload_parameters", "workload_result"}),
}).Create(job)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create job in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to create job in store")
}
return nil
}
// job was pending for too long and has been cancelled
func (s *SqlStore) MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error {
now := time.Now().UTC()
result := s.db.
Model(&types.Job{}).
Where(accountAndPeerIDQueryCondition+"AND status = ?", accountID, peerID, types.JobStatusPending).
Updates(types.Job{
Status: types.JobStatusFailed,
FailedReason: reason,
CompletedAt: &now,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark pending jobs as Failed job in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark pending job as Failed in store")
}
return nil
}
// GetJobByID fetches job by ID
func (s *SqlStore) GetPeerJobByID(accountID, jobID string) (*types.Job, error) {
var job types.Job
err := s.db.
Where(accountAndIDQueryCondition, accountID, jobID).
First(&job).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "job %s not found", jobID)
}
return &job, err
}
// get all jobs
func (s *SqlStore) GetPeerJobs(accountID, peerID string) ([]*types.Job, error) {
var jobs []*types.Job
err := s.db.
Where(accountAndPeerIDQueryCondition, accountID, peerID).
Order("created_at DESC").
Find(&jobs).Error
if err != nil {
return nil, err
}
return jobs, nil
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring global lock")
@@ -2832,3 +2900,57 @@ func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFu
}()
return ctx, cancel
}
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
var info types.PrimaryAccountInfo
result := s.db.Model(&types.Account{}).
Select("is_domain_primary_account, domain").
Where(idQueryCondition, accountID).
Take(&info)
if result.Error != nil {
return false, "", status.Errorf(status.Internal, "failed to get account info: %v", result.Error)
}
return info.IsDomainPrimaryAccount, info.Domain, nil
}
func (s *SqlStore) MarkAccountPrimary(ctx context.Context, accountID string) error {
result := s.db.Model(&types.Account{}).
Where(idQueryCondition, accountID).
Update("is_domain_primary_account", true)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark account as primary: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark account as primary")
}
if result.RowsAffected == 0 {
return status.NewAccountNotFoundError(accountID)
}
return nil
}
type accountNetworkPatch struct {
Network *types.Network `gorm:"embedded;embeddedPrefix:network_"`
}
func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error {
patch := accountNetworkPatch{
Network: &types.Network{Net: ipNet},
}
result := s.db.WithContext(ctx).
Model(&types.Account{}).
Where(idQueryCondition, accountID).
Updates(&patch)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update account network: %v", result.Error)
return status.Errorf(status.Internal, "failed to update account network")
}
if result.RowsAffected == 0 {
return status.NewAccountNotFoundError(accountID)
}
return nil
}

View File

@@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/management/server/migration"
@@ -202,6 +203,13 @@ type Store interface {
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)
IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error)
MarkAccountPrimary(ctx context.Context, accountID string) error
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
GetPeerJobByID(accountID, jobID string) (*types.Job, error)
GetPeerJobs(accountID, peerID string) ([]*types.Job, error)
MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
CreateOrUpdatePeerJob(ctx context.Context, job *types.Job, jobResponse *proto.JobResponse) error
}
const (

View File

@@ -12,9 +12,9 @@ import (
log "github.com/sirupsen/logrus"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
@@ -33,8 +33,8 @@ type SecretsManager interface {
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct {
mux sync.Mutex
turnCfg *types.TURNConfig
relayCfg *types.Relay
turnCfg *nbconfig.TURNConfig
relayCfg *nbconfig.Relay
turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager
@@ -46,7 +46,7 @@ type TimeBasedAuthSecretsManager struct {
type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager,
turnCfg: turnCfg,

View File

@@ -13,6 +13,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
@@ -20,8 +21,8 @@ import (
"github.com/netbirdio/netbird/util"
)
var TurnTestHost = &types.Host{
Proto: types.UDP,
var TurnTestHost = &config.Host{
Proto: config.UDP,
URI: "turn:turn.netbird.io:77777",
Username: "username",
Password: "",
@@ -32,7 +33,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
secret := "some_secret"
peersManager := NewPeersUpdateManager(nil)
rc := &types.Relay{
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
@@ -43,10 +44,10 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*types.Host{TurnTestHost},
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
@@ -83,7 +84,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer)
rc := &types.Relay{
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
@@ -95,10 +96,10 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*types.Host{TurnTestHost},
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
@@ -187,7 +188,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
peersManager := NewPeersUpdateManager(nil)
peer := "some_peer"
rc := &types.Relay{
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
@@ -198,10 +199,10 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*types.Host{TurnTestHost},
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)

View File

@@ -16,16 +16,16 @@ import (
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -89,6 +89,12 @@ type Account struct {
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
}
// this class is used by gorm only
type PrimaryAccountInfo struct {
IsDomainPrimaryAccount bool
Domain string
}
// Subclass used in gorm to only load network and not whole account
type AccountNetwork struct {
Network *Network `gorm:"embedded;embeddedPrefix:network_"`

View File

@@ -0,0 +1,216 @@
package types
import (
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
type JobStatus string
const (
JobStatusPending JobStatus = "pending"
JobStatusSucceeded JobStatus = "succeeded"
JobStatusFailed JobStatus = "failed"
)
type JobType string
const (
JobTypeBundle JobType = "bundle"
)
type Job struct {
// ID is the primary identifier
ID string `gorm:"primaryKey"`
// CreatedAt when job was created (UTC)
CreatedAt time.Time `gorm:"autoCreateTime"`
// CompletedAt when job finished, null if still running
CompletedAt *time.Time
// TriggeredBy user that triggered this job
TriggeredBy string `gorm:"index"`
PeerID string `gorm:"index"`
AccountID string `gorm:"index"`
// Status of the job: pending, succeeded, failed
Status JobStatus `gorm:"index;type:varchar(50)"`
// FailedReason describes why the job failed (if failed)
FailedReason string
Workload Workload `gorm:"embedded;embeddedPrefix:workload_"`
}
type Workload struct {
Type JobType `gorm:"column:workload_type;index;type:varchar(50)"`
Parameters json.RawMessage `gorm:"type:json"`
Result json.RawMessage `gorm:"type:json"`
}
// NewJob creates a new job with default fields and validation
func NewJob(triggeredBy, accountID, peerID string, req *api.JobRequest) (*Job, error) {
if req == nil {
return nil, status.Errorf(status.BadRequest, "job request cannot be nil")
}
// Determine job type
jobTypeStr, err := req.Workload.Discriminator()
if err != nil {
return nil, status.Errorf(status.BadRequest, "could not determine job type: %v", err)
}
jobType := JobType(jobTypeStr)
if jobType == "" {
return nil, status.Errorf(status.BadRequest, "job type is required")
}
var workload Workload
switch jobType {
case JobTypeBundle:
if err := validateAndBuildBundleParams(req.Workload, &workload); err != nil {
return nil, status.Errorf(status.BadRequest, "%v", err)
}
default:
return nil, status.Errorf(status.BadRequest, "unsupported job type: %s", jobType)
}
return &Job{
ID: uuid.New().String(),
TriggeredBy: triggeredBy,
PeerID: peerID,
AccountID: accountID,
Status: JobStatusPending,
CreatedAt: time.Now().UTC(),
Workload: workload,
}, nil
}
func (j *Job) BuildWorkloadResponse() (*api.WorkloadResponse, error) {
var wl api.WorkloadResponse
switch j.Workload.Type {
case JobTypeBundle:
if err := j.buildBundleResponse(&wl); err != nil {
return nil, status.Errorf(status.InvalidArgument, err.Error())
}
return &wl, nil
default:
return nil, status.Errorf(status.InvalidArgument, "unknown job type: %v", j.Workload.Type)
}
}
func (j *Job) buildBundleResponse(wl *api.WorkloadResponse) error {
var p api.BundleParameters
if err := json.Unmarshal(j.Workload.Parameters, &p); err != nil {
return fmt.Errorf("invalid parameters for bundle job: %w", err)
}
var r api.BundleResult
if err := json.Unmarshal(j.Workload.Result, &r); err != nil {
return fmt.Errorf("invalid result for bundle job: %w", err)
}
if err := wl.FromBundleWorkloadResponse(api.BundleWorkloadResponse{
Type: api.WorkloadTypeBundle,
Parameters: p,
Result: r,
}); err != nil {
return fmt.Errorf("unknown job parameters: %v", err)
}
return nil
}
func validateAndBuildBundleParams(req api.WorkloadRequest, workload *Workload) error {
bundle, err := req.AsBundleWorkloadRequest()
if err != nil {
return fmt.Errorf("invalid parameters for bundle job")
}
// validate bundle_for_time <= 5 minutes
if bundle.Parameters.BundleForTime < 0 || bundle.Parameters.BundleForTime > 5 {
return fmt.Errorf("bundle_for_time must be between 0 and 5, got %d", bundle.Parameters.BundleForTime)
}
// validate log-file-count ≥ 1 and ≤ 1000
if bundle.Parameters.LogFileCount < 1 || bundle.Parameters.LogFileCount > 1000 {
return fmt.Errorf("log-file-count must be between 1 and 1000, got %d", bundle.Parameters.LogFileCount)
}
workload.Parameters, err = json.Marshal(bundle.Parameters)
if err != nil {
return fmt.Errorf("failed to marshal workload parameters: %w", err)
}
workload.Result = []byte("{}")
workload.Type = JobType(api.WorkloadTypeBundle)
return nil
}
// ApplyResponse validates and maps a proto.JobResponse into the Job fields.
func (j *Job) ApplyResponse(resp *proto.JobResponse) error {
if resp == nil {
return nil
}
now := time.Now().UTC()
j.CompletedAt = &now
switch resp.Status {
case proto.JobStatus_succeeded:
j.Status = JobStatusSucceeded
case proto.JobStatus_failed:
j.Status = JobStatusFailed
default:
j.Status = JobStatusPending
}
if len(resp.Reason) > 0 {
j.FailedReason = string(resp.Reason)
}
// Handle workload results (oneof)
var err error
switch r := resp.WorkloadResults.(type) {
case *proto.JobResponse_Bundle:
if j.Workload.Result, err = json.Marshal(r.Bundle); err != nil {
return fmt.Errorf("failed to marshal workload results: %w", err)
}
default:
return fmt.Errorf("unsupported workload response type: %T", r)
}
return nil
}
func (j *Job) ToStreamJobRequest() (*proto.JobRequest, error) {
switch j.Workload.Type {
case JobTypeBundle:
return j.buildStreamBundleResponse()
default:
return nil, status.Errorf(status.InvalidArgument, "unknown job type: %v", j.Workload.Type)
}
}
func (j *Job) buildStreamBundleResponse() (*proto.JobRequest, error) {
var p api.BundleParameters
if err := json.Unmarshal(j.Workload.Parameters, &p); err != nil {
return nil, fmt.Errorf("invalid parameters for bundle job: %w", err)
}
return &proto.JobRequest{
ID: []byte(j.ID),
WorkloadParameters: &proto.JobRequest_Bundle{
Bundle: &proto.BundleParameters{
BundleFor: p.BundleFor,
BundleForTime: int64(p.BundleForTime),
LogFileCount: int32(p.LogFileCount),
Anonymize: p.Anonymize,
},
},
}, nil
}

View File

@@ -111,7 +111,11 @@ type Route struct {
// EventMeta returns activity event meta related to the route
func (r *Route) EventMeta() map[string]any {
return map[string]any{"name": r.NetID, "network_range": r.Network.String(), "domains": r.Domains.SafeString(), "peer_id": r.Peer, "peer_groups": r.PeerGroups}
domains := ""
if r.Domains != nil {
domains = r.Domains.SafeString()
}
return map[string]any{"name": r.NetID, "network_range": r.Network.String(), "domains": domains, "peer_id": r.Peer, "peer_groups": r.PeerGroups}
}
// Copy copies a route object
@@ -181,7 +185,7 @@ func (r *Route) GetResourceID() ResID {
// If the route is dynamic, it returns the domains as comma-separated punycode-encoded string.
// If the route is not dynamic, it returns the network (prefix) string.
func (r *Route) NetString() string {
if r.IsDynamic() {
if r.IsDynamic() && r.Domains != nil {
return r.Domains.SafeString()
}
return r.Network.String()

View File

@@ -14,6 +14,7 @@ import (
type Client interface {
io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)

View File

@@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -27,9 +28,9 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
@@ -52,7 +53,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
level, _ := log.ParseLevel("debug")
log.SetLevel(level)
config := &types.Config{}
config := &config.Config{}
_, err := util.ReadJson("../../../management/server/testdata/management.json", config)
if err != nil {
t.Fatal(err)
@@ -70,6 +71,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
jobManager := mgmt.NewJobManager(nil, store)
eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
@@ -107,7 +109,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
Return(true, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
@@ -115,7 +117,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
groupsManager := groups.NewManagerMock()
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)
}

View File

@@ -12,6 +12,7 @@ import (
gstatus "google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
@@ -111,6 +112,25 @@ func (c *GrpcClient) ready() bool {
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
return c.withMgmtStream(ctx, func(ctx context.Context, serverPubKey wgtypes.Key) error {
return c.handleSyncStream(ctx, serverPubKey, sysInfo, msgHandler)
})
}
// Job wraps the real client's Job endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error {
return c.withMgmtStream(ctx, func(ctx context.Context, serverPubKey wgtypes.Key) error {
return c.handleJobStream(ctx, serverPubKey, msgHandler)
})
}
// withMgmtStream runs a streaming operation against the ManagementService
// It takes care of retries, connection readiness, and fetching server public key.
func (c *GrpcClient) withMgmtStream(
ctx context.Context,
handler func(ctx context.Context, serverPubKey wgtypes.Key) error,
) error {
operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState())
connState := c.conn.GetState()
@@ -128,7 +148,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return err
}
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler)
return handler(ctx, *serverPubKey)
}
err := backoff.Retry(operation, defaultBackoff(ctx))
@@ -139,12 +159,132 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return err
}
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
msgHandler func(msg *proto.SyncResponse) error) error {
func (c *GrpcClient) handleJobStream(
ctx context.Context,
serverPubKey wgtypes.Key,
msgHandler func(msg *proto.JobRequest) *proto.JobResponse,
) error {
stream, err := c.realClient.Job(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to open job stream: %v", err)
return err
}
// Handshake with the server
if err := c.sendHandshake(ctx, stream, serverPubKey); err != nil {
return err
}
log.WithContext(ctx).Debug("job stream handshake sent successfully")
// Main loop: receive, process, respond
for {
jobReq, err := c.receiveJobRequest(ctx, stream, serverPubKey)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
log.WithContext(ctx).Info("job stream closed by server")
return nil
}
log.WithContext(ctx).Errorf("error receiving job request: %v", err)
return err
}
if jobReq == nil || len(jobReq.ID) == 0 {
log.WithContext(ctx).Debug("received unknown or empty job request, skipping")
continue
}
jobResp := c.processJobRequest(ctx, jobReq, msgHandler)
if err := c.sendJobResponse(ctx, stream, serverPubKey, jobResp); err != nil {
return err
}
}
}
// sendHandshake sends the initial handshake message
func (c *GrpcClient) sendHandshake(ctx context.Context, stream proto.ManagementService_JobClient, serverPubKey wgtypes.Key) error {
handshakeReq := &proto.JobRequest{
ID: []byte(uuid.New().String()),
}
encHello, err := encryption.EncryptMessage(serverPubKey, c.key, handshakeReq)
if err != nil {
log.WithContext(ctx).Errorf("failed to encrypt handshake message: %v", err)
return err
}
return stream.Send(&proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encHello,
})
}
// receiveJobRequest waits for and decrypts a job request
func (c *GrpcClient) receiveJobRequest(
ctx context.Context,
stream proto.ManagementService_JobClient,
serverPubKey wgtypes.Key,
) (*proto.JobRequest, error) {
encryptedMsg, err := stream.Recv()
if err != nil {
return nil, err
}
jobReq := &proto.JobRequest{}
if err := encryption.DecryptMessage(serverPubKey, c.key, encryptedMsg.Body, jobReq); err != nil {
log.WithContext(ctx).Warnf("failed to decrypt job request: %v", err)
return nil, err
}
return jobReq, nil
}
// processJobRequest executes the handler and ensures a valid response
func (c *GrpcClient) processJobRequest(
ctx context.Context,
jobReq *proto.JobRequest,
msgHandler func(msg *proto.JobRequest) *proto.JobResponse,
) *proto.JobResponse {
jobResp := msgHandler(jobReq)
if jobResp == nil {
jobResp = &proto.JobResponse{
ID: jobReq.ID,
Status: proto.JobStatus_failed,
Reason: []byte("handler returned nil response"),
}
log.WithContext(ctx).Warnf("job handler returned nil for job %s", string(jobReq.ID))
}
return jobResp
}
// sendJobResponse encrypts and sends a job response
func (c *GrpcClient) sendJobResponse(
ctx context.Context,
stream proto.ManagementService_JobClient,
serverPubKey wgtypes.Key,
resp *proto.JobResponse,
) error {
encResp, err := encryption.EncryptMessage(serverPubKey, c.key, resp)
if err != nil {
log.WithContext(ctx).Errorf("failed to encrypt job response for job %s: %v", string(resp.ID), err)
return err
}
if err := stream.Send(&proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encResp,
}); err != nil {
log.WithContext(ctx).Errorf("failed to send job response for job %s: %v", string(resp.ID), err)
return err
}
log.WithContext(ctx).Debugf("job response sent successfully for job %s", string(resp.ID))
return nil
}
func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
ctx, cancelStream := context.WithCancel(ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, serverPubKey, sysInfo)
stream, err := c.connectToSyncStream(ctx, serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
@@ -157,7 +297,7 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key,
c.notifyConnected()
// blocking until error
err = c.receiveEvents(stream, serverPubKey, msgHandler)
err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler)
if err != nil {
c.notifyDisconnected(err)
s, _ := gstatus.FromError(err)
@@ -186,7 +326,7 @@ func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, err
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
stream, err := c.connectToSyncStream(ctx, *serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
return nil, err
@@ -219,7 +359,7 @@ func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, err
return decryptedResp.GetNetworkMap(), nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
func (c *GrpcClient) connectToSyncStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)}
myPrivateKey := c.key
@@ -238,7 +378,7 @@ func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.K
return sync, nil
}
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
func (c *GrpcClient) receiveUpdatesEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
for {
update, err := stream.Recv()
if err == io.EOF {

View File

@@ -20,6 +20,7 @@ type MockClient struct {
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(sysInfo *system.Info) error
LogoutFunc func() error
JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
}
func (m *MockClient) IsHealthy() bool {
@@ -40,6 +41,13 @@ func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return m.SyncFunc(ctx, sysInfo, msgHandler)
}
func (m *MockClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error {
if m.JobFunc == nil {
return nil
}
return m.JobFunc(ctx, msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
if m.GetServerPublicKeyFunc == nil {
return nil, nil

View File

@@ -11,6 +11,6 @@ fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
cd "$script_path"
go install github.com/deepmap/oapi-codegen/cmd/oapi-codegen@4a1477f6a8ba6ca8115cc23bb2fb67f0b9fca18e
go install github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen@latest
oapi-codegen --config cfg.yaml openapi.yml
cd "$old_pwd"
cd "$old_pwd"

View File

@@ -34,6 +34,111 @@ tags:
x-cloud-only: true
components:
schemas:
WorkloadType:
type: string
enum:
- bundle
BundleParameters:
type: object
properties:
bundle_for:
type: boolean
example: true
bundle_for_time:
type: integer
minimum: 0
example: 2
log_file_count:
type: integer
minimum: 0
example: 100
anonymize:
type: boolean
example: false
required:
- bundle_for
- bundle_for_time
- log_file_count
- anonymize
BundleResult:
type: object
properties:
upload_key:
type: string
example: "upload_key_123"
nullable: true
BundleWorkloadRequest:
type: object
properties:
type:
$ref: '#/components/schemas/WorkloadType'
parameters:
$ref: '#/components/schemas/BundleParameters'
required:
- type
- parameters
BundleWorkloadResponse:
type: object
properties:
type:
$ref: '#/components/schemas/WorkloadType'
parameters:
$ref: '#/components/schemas/BundleParameters'
result:
$ref: '#/components/schemas/BundleResult'
required:
- type
- parameters
- result
WorkloadRequest:
oneOf:
- $ref: '#/components/schemas/BundleWorkloadRequest'
discriminator:
propertyName: type
mapping:
bundle: '#/components/schemas/BundleWorkloadRequest'
WorkloadResponse:
oneOf:
- $ref: '#/components/schemas/BundleWorkloadResponse'
discriminator:
propertyName: type
mapping:
bundle: '#/components/schemas/BundleWorkloadResponse'
JobRequest:
type: object
properties:
workload:
$ref: '#/components/schemas/WorkloadRequest'
required:
- workload
JobResponse:
type: object
properties:
id:
type: string
created_at:
type: string
format: date-time
completed_at:
type: string
format: date-time
nullable: true
triggered_by:
type: string
status:
type: string
enum: [pending, succeeded, failed]
failed_reason:
type: string
nullable: true
workload:
$ref: '#/components/schemas/WorkloadResponse'
required:
- id
- created_at
- status
- triggered_by
- workload
Account:
type: object
properties:
@@ -369,6 +474,11 @@ components:
- $ref: '#/components/schemas/PeerMinimum'
- type: object
properties:
created_at:
description: Peer creation date (UTC)
type: string
format: date-time
example: "2023-05-05T09:00:35.477782Z"
ip:
description: Peer's IP address
type: string
@@ -471,6 +581,7 @@ components:
- connected
- connection_ip
- country_code
- created_at
- dns_label
- geoname_id
- groups
@@ -544,11 +655,17 @@ components:
- $ref: '#/components/schemas/Peer'
- type: object
properties:
created_at:
description: Peer creation date (UTC)
type: string
format: date-time
example: "2023-05-05T09:00:35.477782Z"
accessible_peers_count:
description: Number of accessible peers
type: integer
example: 5
required:
- created_at
- accessible_peers_count
SetupKeyBase:
type: object
@@ -2158,6 +2275,108 @@ security:
- BearerAuth: [ ]
- TokenAuth: [ ]
paths:
/api/peers/{peerId}/jobs:
get:
summary: List Jobs
description: Retrieve all jobs for a given peer
tags: [ Jobs ]
security:
- BearerAuth: []
- TokenAuth: []
parameters:
- in: path
name: peerId
required: true
schema:
type: string
description: The unique identifier of a peer
responses:
'200':
description: List of jobs
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/JobResponse'
'400':
$ref: '#/components/responses/bad_request'
'401':
$ref: '#/components/responses/requires_authentication'
'403':
$ref: '#/components/responses/forbidden'
'500':
$ref: '#/components/responses/internal_error'
post:
summary: Create Job
description: Create a new job for a given peer
tags: [ Jobs ]
security:
- BearerAuth: []
- TokenAuth: []
parameters:
- in: path
name: peerId
required: true
schema:
type: string
description: The unique identifier of a peer
requestBody:
description: Create job request
content:
application/json:
schema:
$ref: '#/components/schemas/JobRequest'
required: true
responses:
'201':
description: Job created
content:
application/json:
schema:
$ref: '#/components/schemas/JobResponse'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers/{peerId}/jobs/{jobId}:
get:
summary: Get Job
description: Retrieve details of a specific job
tags: [ Jobs ]
security:
- BearerAuth: []
- TokenAuth: []
parameters:
- in: path
name: peerId
required: true
schema:
type: string
- in: path
name: jobId
required: true
schema:
type: string
responses:
'200':
description: A Job object
content:
application/json:
schema:
$ref: '#/components/schemas/JobResponse'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/accounts:
get:
summary: List all Accounts

View File

@@ -1,10 +1,14 @@
// Package api provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/deepmap/oapi-codegen version v1.11.1-0.20220912230023-4a1477f6a8ba DO NOT EDIT.
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.0 DO NOT EDIT.
package api
import (
"encoding/json"
"errors"
"time"
"github.com/oapi-codegen/runtime"
)
const (
@@ -104,6 +108,13 @@ const (
IngressPortAllocationRequestPortRangeProtocolUdp IngressPortAllocationRequestPortRangeProtocol = "udp"
)
// Defines values for JobResponseStatus.
const (
JobResponseStatusFailed JobResponseStatus = "failed"
JobResponseStatusPending JobResponseStatus = "pending"
JobResponseStatusSucceeded JobResponseStatus = "succeeded"
)
// Defines values for NameserverNsType.
const (
NameserverNsTypeUdp NameserverNsType = "udp"
@@ -178,6 +189,11 @@ const (
UserStatusInvited UserStatus = "invited"
)
// Defines values for WorkloadType.
const (
WorkloadTypeBundle WorkloadType = "bundle"
)
// Defines values for GetApiEventsNetworkTrafficParamsType.
const (
GetApiEventsNetworkTrafficParamsTypeTYPEDROP GetApiEventsNetworkTrafficParamsType = "TYPE_DROP"
@@ -337,6 +353,32 @@ type AvailablePorts struct {
Udp int `json:"udp"`
}
// BundleParameters defines model for BundleParameters.
type BundleParameters struct {
Anonymize bool `json:"anonymize"`
BundleFor bool `json:"bundle_for"`
BundleForTime int `json:"bundle_for_time"`
LogFileCount int `json:"log_file_count"`
}
// BundleResult defines model for BundleResult.
type BundleResult struct {
UploadKey *string `json:"upload_key"`
}
// BundleWorkloadRequest defines model for BundleWorkloadRequest.
type BundleWorkloadRequest struct {
Parameters BundleParameters `json:"parameters"`
Type WorkloadType `json:"type"`
}
// BundleWorkloadResponse defines model for BundleWorkloadResponse.
type BundleWorkloadResponse struct {
Parameters BundleParameters `json:"parameters"`
Result BundleResult `json:"result"`
Type WorkloadType `json:"type"`
}
// Checks List of objects that perform the actual checks
type Checks struct {
// GeoLocationCheck Posture check for geo location
@@ -643,6 +685,25 @@ type IngressPortAllocationRequestPortRange struct {
// IngressPortAllocationRequestPortRangeProtocol The protocol accepted by the port range
type IngressPortAllocationRequestPortRangeProtocol string
// JobRequest defines model for JobRequest.
type JobRequest struct {
Workload WorkloadRequest `json:"workload"`
}
// JobResponse defines model for JobResponse.
type JobResponse struct {
CompletedAt *time.Time `json:"completed_at"`
CreatedAt time.Time `json:"created_at"`
FailedReason *string `json:"failed_reason"`
Id string `json:"id"`
Status JobResponseStatus `json:"status"`
TriggeredBy string `json:"triggered_by"`
Workload WorkloadResponse `json:"workload"`
}
// JobResponseStatus defines model for JobResponse.Status.
type JobResponseStatus string
// Location Describe geographical location information
type Location struct {
// CityName Commonly used English name of the city
@@ -1030,6 +1091,9 @@ type Peer struct {
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// CreatedAt Peer creation date (UTC)
CreatedAt time.Time `json:"created_at"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
@@ -1114,6 +1178,9 @@ type PeerBatch struct {
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// CreatedAt Peer creation date (UTC)
CreatedAt time.Time `json:"created_at"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
@@ -1814,6 +1881,19 @@ type UserRequest struct {
Role string `json:"role"`
}
// WorkloadRequest defines model for WorkloadRequest.
type WorkloadRequest struct {
union json.RawMessage
}
// WorkloadResponse defines model for WorkloadResponse.
type WorkloadResponse struct {
union json.RawMessage
}
// WorkloadType defines model for WorkloadType.
type WorkloadType string
// GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic.
type GetApiEventsNetworkTrafficParams struct {
// Page Page number
@@ -1931,6 +2011,9 @@ type PostApiPeersPeerIdIngressPortsJSONRequestBody = IngressPortAllocationReques
// PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody defines body for PutApiPeersPeerIdIngressPortsAllocationId for application/json ContentType.
type PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody = IngressPortAllocationRequest
// PostApiPeersPeerIdJobsJSONRequestBody defines body for PostApiPeersPeerIdJobs for application/json ContentType.
type PostApiPeersPeerIdJobsJSONRequestBody = JobRequest
// PostApiPoliciesJSONRequestBody defines body for PostApiPolicies for application/json ContentType.
type PostApiPoliciesJSONRequestBody = PolicyUpdate
@@ -1963,3 +2046,121 @@ type PutApiUsersUserIdJSONRequestBody = UserRequest
// PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType.
type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest
// AsBundleWorkloadRequest returns the union data inside the WorkloadRequest as a BundleWorkloadRequest
func (t WorkloadRequest) AsBundleWorkloadRequest() (BundleWorkloadRequest, error) {
var body BundleWorkloadRequest
err := json.Unmarshal(t.union, &body)
return body, err
}
// FromBundleWorkloadRequest overwrites any union data inside the WorkloadRequest as the provided BundleWorkloadRequest
func (t *WorkloadRequest) FromBundleWorkloadRequest(v BundleWorkloadRequest) error {
v.Type = "bundle"
b, err := json.Marshal(v)
t.union = b
return err
}
// MergeBundleWorkloadRequest performs a merge with any union data inside the WorkloadRequest, using the provided BundleWorkloadRequest
func (t *WorkloadRequest) MergeBundleWorkloadRequest(v BundleWorkloadRequest) error {
v.Type = "bundle"
b, err := json.Marshal(v)
if err != nil {
return err
}
merged, err := runtime.JSONMerge(t.union, b)
t.union = merged
return err
}
func (t WorkloadRequest) Discriminator() (string, error) {
var discriminator struct {
Discriminator string `json:"type"`
}
err := json.Unmarshal(t.union, &discriminator)
return discriminator.Discriminator, err
}
func (t WorkloadRequest) ValueByDiscriminator() (interface{}, error) {
discriminator, err := t.Discriminator()
if err != nil {
return nil, err
}
switch discriminator {
case "bundle":
return t.AsBundleWorkloadRequest()
default:
return nil, errors.New("unknown discriminator value: " + discriminator)
}
}
func (t WorkloadRequest) MarshalJSON() ([]byte, error) {
b, err := t.union.MarshalJSON()
return b, err
}
func (t *WorkloadRequest) UnmarshalJSON(b []byte) error {
err := t.union.UnmarshalJSON(b)
return err
}
// AsBundleWorkloadResponse returns the union data inside the WorkloadResponse as a BundleWorkloadResponse
func (t WorkloadResponse) AsBundleWorkloadResponse() (BundleWorkloadResponse, error) {
var body BundleWorkloadResponse
err := json.Unmarshal(t.union, &body)
return body, err
}
// FromBundleWorkloadResponse overwrites any union data inside the WorkloadResponse as the provided BundleWorkloadResponse
func (t *WorkloadResponse) FromBundleWorkloadResponse(v BundleWorkloadResponse) error {
v.Type = "bundle"
b, err := json.Marshal(v)
t.union = b
return err
}
// MergeBundleWorkloadResponse performs a merge with any union data inside the WorkloadResponse, using the provided BundleWorkloadResponse
func (t *WorkloadResponse) MergeBundleWorkloadResponse(v BundleWorkloadResponse) error {
v.Type = "bundle"
b, err := json.Marshal(v)
if err != nil {
return err
}
merged, err := runtime.JSONMerge(t.union, b)
t.union = merged
return err
}
func (t WorkloadResponse) Discriminator() (string, error) {
var discriminator struct {
Discriminator string `json:"type"`
}
err := json.Unmarshal(t.union, &discriminator)
return discriminator.Discriminator, err
}
func (t WorkloadResponse) ValueByDiscriminator() (interface{}, error) {
discriminator, err := t.Discriminator()
if err != nil {
return nil, err
}
switch discriminator {
case "bundle":
return t.AsBundleWorkloadResponse()
default:
return nil, errors.New("unknown discriminator value: " + discriminator)
}
}
func (t WorkloadResponse) MarshalJSON() ([]byte, error) {
b, err := t.union.MarshalJSON()
return b, err
}
func (t *WorkloadResponse) UnmarshalJSON(b []byte) error {
err := t.union.UnmarshalJSON(b)
return err
}

File diff suppressed because it is too large Load Diff

View File

@@ -48,6 +48,9 @@ service ManagementService {
// Logout logs out the peer and removes it from the management server
rpc Logout(EncryptedMessage) returns (Empty) {}
// Executes a job on a target peer (e.g., debug bundle)
rpc Job(stream EncryptedMessage) returns (stream EncryptedMessage) {}
}
message EncryptedMessage {
@@ -60,6 +63,42 @@ message EncryptedMessage {
int32 version = 3;
}
message JobRequest {
bytes ID = 1;
oneof workload_parameters {
BundleParameters bundle = 10;
//OtherParameters other = 11;
}
}
enum JobStatus {
unknown_status = 0; //placeholder
succeeded = 1;
failed = 2;
}
message JobResponse{
bytes ID = 1;
JobStatus status=2;
bytes Reason=3;
oneof workload_results {
BundleResult bundle = 10;
//OtherResult other = 11;
}
}
message BundleParameters {
bool bundle_for = 1;
int64 bundle_for_time = 2;
int32 log_file_count = 3;
bool anonymize = 4;
}
message BundleResult {
string upload_key = 1;
}
message SyncRequest {
// Meta data of the peer
PeerSystemMeta meta = 1;

View File

@@ -50,6 +50,8 @@ type ManagementServiceClient interface {
SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
// Logout logs out the peer and removes it from the management server
Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
// Executes a job on a target peer (e.g., debug bundle)
Job(ctx context.Context, opts ...grpc.CallOption) (ManagementService_JobClient, error)
}
type managementServiceClient struct {
@@ -155,6 +157,37 @@ func (c *managementServiceClient) Logout(ctx context.Context, in *EncryptedMessa
return out, nil
}
func (c *managementServiceClient) Job(ctx context.Context, opts ...grpc.CallOption) (ManagementService_JobClient, error) {
stream, err := c.cc.NewStream(ctx, &ManagementService_ServiceDesc.Streams[1], "/management.ManagementService/Job", opts...)
if err != nil {
return nil, err
}
x := &managementServiceJobClient{stream}
return x, nil
}
type ManagementService_JobClient interface {
Send(*EncryptedMessage) error
Recv() (*EncryptedMessage, error)
grpc.ClientStream
}
type managementServiceJobClient struct {
grpc.ClientStream
}
func (x *managementServiceJobClient) Send(m *EncryptedMessage) error {
return x.ClientStream.SendMsg(m)
}
func (x *managementServiceJobClient) Recv() (*EncryptedMessage, error) {
m := new(EncryptedMessage)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// ManagementServiceServer is the server API for ManagementService service.
// All implementations must embed UnimplementedManagementServiceServer
// for forward compatibility
@@ -191,6 +224,8 @@ type ManagementServiceServer interface {
SyncMeta(context.Context, *EncryptedMessage) (*Empty, error)
// Logout logs out the peer and removes it from the management server
Logout(context.Context, *EncryptedMessage) (*Empty, error)
// Executes a job on a target peer (e.g., debug bundle)
Job(ManagementService_JobServer) error
mustEmbedUnimplementedManagementServiceServer()
}
@@ -222,6 +257,9 @@ func (UnimplementedManagementServiceServer) SyncMeta(context.Context, *Encrypted
func (UnimplementedManagementServiceServer) Logout(context.Context, *EncryptedMessage) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented")
}
func (UnimplementedManagementServiceServer) Job(ManagementService_JobServer) error {
return status.Errorf(codes.Unimplemented, "method Job not implemented")
}
func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {}
// UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -382,6 +420,32 @@ func _ManagementService_Logout_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler)
}
func _ManagementService_Job_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(ManagementServiceServer).Job(&managementServiceJobServer{stream})
}
type ManagementService_JobServer interface {
Send(*EncryptedMessage) error
Recv() (*EncryptedMessage, error)
grpc.ServerStream
}
type managementServiceJobServer struct {
grpc.ServerStream
}
func (x *managementServiceJobServer) Send(m *EncryptedMessage) error {
return x.ServerStream.SendMsg(m)
}
func (x *managementServiceJobServer) Recv() (*EncryptedMessage, error) {
m := new(EncryptedMessage)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -424,6 +488,12 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{
Handler: _ManagementService_Sync_Handler,
ServerStreams: true,
},
{
StreamName: "Job",
Handler: _ManagementService_Job_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "management.proto",
}

View File

@@ -52,7 +52,7 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
}
// MarshalCredential marshal a Credential instance and returns a Message object
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string) (*proto.Message, error) {
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string, sessionID []byte) (*proto.Message, error) {
return &proto.Message{
Key: myKey.PublicKey().String(),
RemoteKey: remoteKey,
@@ -66,6 +66,7 @@ func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credenti
RosenpassServerAddr: rosenpassAddr,
},
RelayServerAddress: relaySrvAddress,
SessionId: sessionID,
},
}, nil
}

View File

@@ -45,19 +45,10 @@ type GrpcClient struct {
connStateCallbackLock sync.RWMutex
onReconnectedListenerFn func()
}
func (c *GrpcClient) StreamConnected() bool {
return c.status == StreamConnected
}
func (c *GrpcClient) GetStatus() Status {
return c.status
}
// Close Closes underlying connections to the Signal Exchange
func (c *GrpcClient) Close() error {
return c.signalConn.Close()
decryptionWorker *Worker
decryptionWorkerCancel context.CancelFunc
decryptionWg sync.WaitGroup
}
// NewClient creates a new Signal client
@@ -93,6 +84,25 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
}, nil
}
func (c *GrpcClient) StreamConnected() bool {
return c.status == StreamConnected
}
func (c *GrpcClient) GetStatus() Status {
return c.status
}
// Close Closes underlying connections to the Signal Exchange
func (c *GrpcClient) Close() error {
if c.decryptionWorkerCancel != nil {
c.decryptionWorkerCancel()
}
c.decryptionWg.Wait()
c.decryptionWorker = nil
return c.signalConn.Close()
}
// SetConnStateListener set the ConnStateNotifier
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallbackLock.Lock()
@@ -148,8 +158,12 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes
log.Infof("connected to the Signal Service stream")
c.notifyConnected()
// Start worker pool if not already started
c.startEncryptionWorker(msgHandler)
// start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream, msgHandler)
err = c.receive(stream)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
log.Debugf("signal connection context has been canceled, this usually indicates shutdown")
@@ -174,6 +188,7 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes
return nil
}
func (c *GrpcClient) notifyStreamDisconnected() {
c.mux.Lock()
defer c.mux.Unlock()
@@ -382,11 +397,11 @@ func (c *GrpcClient) Send(msg *proto.Message) error {
}
// receive receives messages from other peers coming through the Signal Exchange
func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
msgHandler func(msg *proto.Message) error) error {
// and distributes them to worker threads for processing
func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient) error {
for {
msg, err := stream.Recv()
// Handle errors immediately
switch s, ok := status.FromError(err); {
case ok && s.Code() == codes.Canceled:
log.Debugf("stream canceled (usually indicates shutdown)")
@@ -398,24 +413,37 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
log.Debugf("Signal Service stream closed by server")
return err
case err != nil:
log.Errorf("Stream receive error: %v", err)
return err
}
log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
decryptedMessage, err := c.decryptMessage(msg)
if err != nil {
log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
if msg == nil {
continue
}
err = msgHandler(decryptedMessage)
if err != nil {
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
// todo send something??
if err := c.decryptionWorker.AddMsg(c.ctx, msg); err != nil {
log.Errorf("failed to add message to decryption worker: %v", err)
}
}
}
func (c *GrpcClient) startEncryptionWorker(handler func(msg *proto.Message) error) {
if c.decryptionWorker != nil {
return
}
c.decryptionWorker = NewWorker(c.decryptMessage, handler)
workerCtx, workerCancel := context.WithCancel(context.Background())
c.decryptionWorkerCancel = workerCancel
c.decryptionWg.Add(1)
go func() {
defer workerCancel()
c.decryptionWorker.Work(workerCtx)
c.decryptionWg.Done()
}()
}
func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()

View File

@@ -0,0 +1,55 @@
package client
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/signal/proto"
)
type Worker struct {
decryptMessage func(msg *proto.EncryptedMessage) (*proto.Message, error)
handler func(msg *proto.Message) error
encryptedMsgPool chan *proto.EncryptedMessage
}
func NewWorker(decryptFn func(msg *proto.EncryptedMessage) (*proto.Message, error), handlerFn func(msg *proto.Message) error) *Worker {
return &Worker{
decryptMessage: decryptFn,
handler: handlerFn,
encryptedMsgPool: make(chan *proto.EncryptedMessage, 1),
}
}
func (w *Worker) AddMsg(ctx context.Context, msg *proto.EncryptedMessage) error {
// this is blocker because do not want to drop messages here
select {
case w.encryptedMsgPool <- msg:
case <-ctx.Done():
}
return nil
}
func (w *Worker) Work(ctx context.Context) {
for {
select {
case msg := <-w.encryptedMsgPool:
decryptedMessage, err := w.decryptMessage(msg)
if err != nil {
log.Errorf("failed to decrypt message: %v", err)
continue
}
if err := w.handler(decryptedMessage); err != nil {
log.Errorf("failed to handle message: %v", err)
continue
}
case <-ctx.Done():
log.Infof("Message worker stopping due to context cancellation")
return
}
}
}

View File

@@ -230,6 +230,7 @@ type Body struct {
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
// relayServerAddress is url of the relay server
RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"`
SessionId []byte `protobuf:"bytes,10,opt,name=sessionId,proto3,oneof" json:"sessionId,omitempty"`
}
func (x *Body) Reset() {
@@ -320,6 +321,13 @@ func (x *Body) GetRelayServerAddress() string {
return ""
}
func (x *Body) GetSessionId() []byte {
if x != nil {
return x.SessionId
}
return nil
}
// Mode indicates a connection mode
type Mode struct {
state protoimpl.MessageState
@@ -443,7 +451,7 @@ var file_signalexchange_proto_rawDesc = []byte{
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xb3, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xe4, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
@@ -466,34 +474,37 @@ var file_signalexchange_proto_rawDesc = []byte{
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01,
0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41,
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x12, 0x0b,
0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x22, 0x2e, 0x0a, 0x04, 0x4d,
0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01,
0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, 0x52,
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28,
0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65,
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61,
0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65,
0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18,
0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a,
0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c,
0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43,
0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e,
0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20,
0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e,
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f,
0x6e, 0x49, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x09, 0x73, 0x65, 0x73,
0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70,
0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06,
0x41, 0x4e, 0x53, 0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44,
0x49, 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10,
0x04, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x42, 0x0c,
0x0a, 0x0a, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x22, 0x2e, 0x0a, 0x04,
0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01,
0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01,
0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f,
0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b,
0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73,
0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72,
0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73,
0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e,
0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c,
0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65,
0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61,
0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d,
0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e,
0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45,
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a,
0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -601,6 +612,7 @@ func file_signalexchange_proto_init() {
}
}
}
file_signalexchange_proto_msgTypes[2].OneofWrappers = []interface{}{}
file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{}
type x struct{}
out := protoimpl.TypeBuilder{

View File

@@ -64,6 +64,8 @@ message Body {
// relayServerAddress is url of the relay server
string relayServerAddress = 8;
optional bytes sessionId = 10;
}
// Mode indicates a connection mode