Compare commits

...

33 Commits

Author SHA1 Message Date
aliamerj
e953089f0b fix snyk hopefully 2025-08-29 16:13:00 +03:00
aliamerj
a9c179ee61 fix snyk 2025-08-29 16:01:31 +03:00
aliamerj
55873e38f7 fix code-gen 2025-08-29 15:28:04 +03:00
aliamerj
073990c10f update codegen 2025-08-29 15:00:15 +03:00
aliamerj
6c80e6adee fix 1 2025-08-29 14:23:04 +03:00
aliamerj
95dff6bbb2 update package 2025-08-29 14:19:22 +03:00
aliamerj
0b9c1fd0a3 running go mod tidy 2025-08-29 14:05:22 +03:00
aliamerj
e19cd11461 use RawJson for both parameters and results 2025-08-29 13:56:28 +03:00
aliamerj
2db19c27f2 fix sonar issue 2025-08-29 10:50:55 +03:00
aliamerj
373f978baa get rid of any/interface type in job database 2025-08-29 10:17:44 +03:00
aliamerj
d693c4f5a6 fix error handle in create job 2025-08-29 08:50:25 +03:00
aliamerj
6c0a46dfd5 apply feedback 2 2025-08-28 00:19:53 +03:00
aliamerj
9043938233 clean switch case 2025-08-27 13:00:43 +03:00
aliamerj
556c4c7777 fix api object 2025-08-27 12:46:20 +03:00
aliamerj
d1aec58108 fix lint 2025-08-26 13:09:54 +03:00
aliamerj
667dfdfcc3 change api and apply new schema 2025-08-26 13:00:28 +03:00
aliamerj
1fe6295d44 fix typo 2025-08-25 15:44:06 +03:00
aliamerj
cc1338f92d apply feedbacks 1 2025-08-22 20:57:50 +03:00
aliamerj
b7f0088fe3 fix MarkPendingJobsAsFailed 2025-08-21 15:04:11 +03:00
aliamerj
78bc18cb9e clean up 2025-08-20 22:19:37 +03:00
aliamerj
0befefe085 fix lint 2025-08-20 20:42:04 +03:00
aliamerj
f4757165c1 implement remote debug api 2025-08-20 20:15:46 +03:00
dependabot[bot]
9685411246 [misc] Bump golang.org/x/oauth2 from 0.24.0 to 0.27.0 (#4176)
Bumps [golang.org/x/oauth2](https://github.com/golang/oauth2) from 0.24.0 to 0.27.0
2025-08-19 16:26:46 +03:00
hakansa
d00a226556 [management] Add CreatedAt field to Peer and PeerBatch models (#4371)
[management] Add CreatedAt field to Peer and PeerBatch models (#4371)
2025-08-19 16:02:11 +03:00
Pascal Fischer
5d361b5421 [management] add nil handling for route domains (#4366) 2025-08-19 11:35:03 +02:00
dependabot[bot]
a889c4108b [misc] Bump github.com/containerd/containerd from 1.7.16 to 1.7.27 (#3527)
Bumps [github.com/containerd/containerd](https://github.com/containerd/containerd) from 1.7.16 to 1.7.27
2025-08-18 21:57:21 +03:00
Zoltan Papp
12cad854b2 [client] Fix/ice handshake (#4281)
In this PR, speed up the GRPC message processing, force the recreation of the ICE agent when getting a new, remote offer (do not wait for local STUN timeout).
2025-08-18 20:09:50 +02:00
Pascal Fischer
6a3846a8b7 [management] Remove save account calls (#4349) 2025-08-18 12:37:20 +02:00
Viktor Liu
7cd5dcae59 [client] Fix rule order for deny rules in peer ACLs (#4147) 2025-08-18 11:17:00 +02:00
Pascal Fischer
0e62325d46 [management] fail on geo location init failure (#4362) 2025-08-18 10:53:55 +02:00
Pascal Fischer
b3056d0937 [management] Use DI containers for server bootstrapping (#4343) 2025-08-15 17:14:48 +02:00
Zoltan Papp
ab853ac2a5 [server] Add MySQL initialization script and update Docker configuration (#4345) 2025-08-14 17:53:59 +02:00
Misha Bragin
e97f853909 Improve wording in the NetBird client app (#4316) 2025-08-13 22:03:48 +02:00
73 changed files with 2877 additions and 1012 deletions

View File

@@ -83,6 +83,15 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Setup MySQL privileges
if: matrix.store == 'mysql'
run: |
sleep 10
mysql -h 127.0.0.1 -u root -pmysqlroot -e "
GRANT SYSTEM_VARIABLES_ADMIN ON *.* TO 'netbird'@'%';
FLUSH PRIVILEGES;
"
- name: cp setup.env - name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/ run: cp infrastructure_files/tests/setup.env infrastructure_files/

View File

@@ -33,7 +33,7 @@ var (
var debugCmd = &cobra.Command{ var debugCmd = &cobra.Command{
Use: "debug", Use: "debug",
Short: "Debugging commands", Short: "Debugging commands",
Long: "Provides commands for debugging and logging control within the NetBird daemon.", Long: "Commands for debugging and logging within the NetBird daemon.",
} }
var debugBundleCmd = &cobra.Command{ var debugBundleCmd = &cobra.Command{

View File

@@ -14,7 +14,8 @@ import (
var downCmd = &cobra.Command{ var downCmd = &cobra.Command{
Use: "down", Use: "down",
Short: "down netbird connections", Short: "Disconnect from the NetBird network",
Long: "Disconnect the NetBird client from the network and management service. This will terminate all active connections with the remote peers.",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)

View File

@@ -31,7 +31,8 @@ func init() {
var loginCmd = &cobra.Command{ var loginCmd = &cobra.Command{
Use: "login", Use: "login",
Short: "login to the NetBird Management Service (first run)", Short: "Log in to the NetBird network",
Long: "Log in to the NetBird network using a setup key or SSO",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setEnvAndFlags(cmd); err != nil { if err := setEnvAndFlags(cmd); err != nil {
return fmt.Errorf("set env and flags: %v", err) return fmt.Errorf("set env and flags: %v", err)

View File

@@ -14,7 +14,8 @@ import (
var logoutCmd = &cobra.Command{ var logoutCmd = &cobra.Command{
Use: "deregister", Use: "deregister",
Aliases: []string{"logout"}, Aliases: []string{"logout"},
Short: "deregister from the NetBird Management Service and delete peer", Short: "Deregister from the NetBird management service and delete this peer",
Long: "This command will deregister the current peer from the NetBird management service and all associated configuration. Use with caution as this will remove the peer from the network.",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)

View File

@@ -15,7 +15,7 @@ var appendFlag bool
var networksCMD = &cobra.Command{ var networksCMD = &cobra.Command{
Use: "networks", Use: "networks",
Aliases: []string{"routes"}, Aliases: []string{"routes"},
Short: "Manage networks", Short: "Manage connections to NetBird Networks and Resources",
Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`, Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`,
} }

View File

@@ -16,13 +16,13 @@ import (
var profileCmd = &cobra.Command{ var profileCmd = &cobra.Command{
Use: "profile", Use: "profile",
Short: "manage NetBird profiles", Short: "Manage NetBird client profiles",
Long: `Manage NetBird profiles, allowing you to list, switch, and remove profiles.`, Long: `Commands to list, add, remove, and switch profiles. Profiles allow you to maintain different accounts in one client app.`,
} }
var profileListCmd = &cobra.Command{ var profileListCmd = &cobra.Command{
Use: "list", Use: "list",
Short: "list all profiles", Short: "List all profiles",
Long: `List all available profiles in the NetBird client.`, Long: `List all available profiles in the NetBird client.`,
Aliases: []string{"ls"}, Aliases: []string{"ls"},
RunE: listProfilesFunc, RunE: listProfilesFunc,
@@ -30,7 +30,7 @@ var profileListCmd = &cobra.Command{
var profileAddCmd = &cobra.Command{ var profileAddCmd = &cobra.Command{
Use: "add <profile_name>", Use: "add <profile_name>",
Short: "add a new profile", Short: "Add a new profile",
Long: `Add a new profile to the NetBird client. The profile name must be unique.`, Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: addProfileFunc, RunE: addProfileFunc,
@@ -38,16 +38,16 @@ var profileAddCmd = &cobra.Command{
var profileRemoveCmd = &cobra.Command{ var profileRemoveCmd = &cobra.Command{
Use: "remove <profile_name>", Use: "remove <profile_name>",
Short: "remove a profile", Short: "Remove a profile",
Long: `Remove a profile from the NetBird client. The profile must not be active.`, Long: `Remove a profile from the NetBird client. The profile must not be inactive.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: removeProfileFunc, RunE: removeProfileFunc,
} }
var profileSelectCmd = &cobra.Command{ var profileSelectCmd = &cobra.Command{
Use: "select <profile_name>", Use: "select <profile_name>",
Short: "select a profile", Short: "Select a profile",
Long: `Select a profile to be the active profile in the NetBird client. The profile must exist.`, Long: `Make the specified profile active. This will switch the client to use the selected profile's configuration.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: selectProfileFunc, RunE: selectProfileFunc,
} }

View File

@@ -19,7 +19,7 @@ import (
var serviceCmd = &cobra.Command{ var serviceCmd = &cobra.Command{
Use: "service", Use: "service",
Short: "manages NetBird service", Short: "Manage the NetBird daemon service",
} }
var ( var (

View File

@@ -107,7 +107,7 @@ func createServiceConfigForInstall() (*service.Config, error) {
var installCmd = &cobra.Command{ var installCmd = &cobra.Command{
Use: "install", Use: "install",
Short: "installs NetBird service", Short: "Install NetBird service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil { if err := setupServiceCommand(cmd); err != nil {
return err return err

View File

@@ -40,7 +40,7 @@ var sshCmd = &cobra.Command{
return nil return nil
}, },
Short: "connect to a remote SSH server", Short: "Connect to a remote SSH server",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd) SetFlagsFromEnvVars(cmd)

View File

@@ -32,7 +32,8 @@ var (
var statusCmd = &cobra.Command{ var statusCmd = &cobra.Command{
Use: "status", Use: "status",
Short: "status of the Netbird Service", Short: "Display NetBird client status",
Long: "Display the current status of the NetBird client, including connection status, peer information, and network details.",
RunE: statusFunc, RunE: statusFunc,
} }

View File

@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -35,7 +36,7 @@ import (
func startTestingServices(t *testing.T) string { func startTestingServices(t *testing.T) string {
t.Helper() t.Helper()
config := &types.Config{} config := &config.Config{}
_, err := util.ReadJson("../testdata/management.json", config) _, err := util.ReadJson("../testdata/management.json", config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -70,7 +71,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis return s, lis
} }
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *config.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")

View File

@@ -53,7 +53,8 @@ var (
upCmd = &cobra.Command{ upCmd = &cobra.Command{
Use: "up", Use: "up",
Short: "install, login and start NetBird client", Short: "Connect to the NetBird network",
Long: "Connect to the NetBird network using the provided setup key or SSO auth. This command will bring up the WireGuard interface, connect to the management server, and establish peer-to-peer connections with other peers in the network if required.",
RunE: upFunc, RunE: upFunc,
} }
) )

View File

@@ -9,7 +9,7 @@ import (
var ( var (
versionCmd = &cobra.Command{ versionCmd = &cobra.Command{
Use: "version", Use: "version",
Short: "prints NetBird version", Short: "Print the NetBird's client application version",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
cmd.Println(version.NetbirdVersion()) cmd.Println(version.NetbirdVersion())

View File

@@ -85,7 +85,7 @@ func (m *aclManager) AddPeerFiltering(
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
chain := chainNameInputRules chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPort, dPort) ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName) specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs) mangleSpecs := slices.Clone(specs)
@@ -135,7 +135,14 @@ func (m *aclManager) AddPeerFiltering(
return nil, fmt.Errorf("rule already exists") return nil, fmt.Errorf("rule already exists")
} }
if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil { // Insert DROP rules at the beginning, append ACCEPT rules at the end
if action == firewall.ActionDrop {
// Insert at the beginning of the chain (position 1)
err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...)
} else {
err = m.iptablesClient.Append(tableFilter, chain, specs...)
}
if err != nil {
return nil, err return nil, err
} }
@@ -388,17 +395,25 @@ func actionToStr(action firewall.Action) string {
return "DROP" return "DROP"
} }
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string { func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string {
switch { if ipsetName == "" {
case ipsetName == "":
return "" return ""
}
// Include action in the ipset name to prevent squashing rules with different actions
actionSuffix := ""
if action == firewall.ActionDrop {
actionSuffix = "-drop"
}
switch {
case sPort != nil && dPort != nil: case sPort != nil && dPort != nil:
return ipsetName + "-sport-dport" return ipsetName + "-sport-dport" + actionSuffix
case sPort != nil: case sPort != nil:
return ipsetName + "-sport" return ipsetName + "-sport" + actionSuffix
case dPort != nil: case dPort != nil:
return ipsetName + "-dport" return ipsetName + "-dport" + actionSuffix
default: default:
return ipsetName return ipsetName + actionSuffix
} }
} }

View File

@@ -3,6 +3,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net/netip" "net/netip"
"strings"
"testing" "testing"
"time" "time"
@@ -15,7 +16,7 @@ import (
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "wg-test"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
@@ -109,10 +110,84 @@ func TestIptablesManager(t *testing.T) {
}) })
} }
func TestIptablesManagerDenyRules(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
manager, err := Create(ifaceMock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
defer func() {
err := manager.Close(nil)
require.NoError(t, err)
}()
t.Run("add deny rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{22}}
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh")
require.NoError(t, err, "failed to add deny rule")
require.NotEmpty(t, rule, "deny rule should not be empty")
// Verify the rule was added by checking iptables
for _, r := range rule {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
})
t.Run("deny rule precedence test", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.4")
port := &fw.Port{Values: []uint16{80}}
// Add accept rule first
_, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http")
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for same IP/port - this should take precedence
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http")
require.NoError(t, err, "failed to add deny rule")
// Inspect the actual iptables rules to verify deny rule comes before accept rule
rules, err := ipv4Client.List("filter", chainNameInputRules)
require.NoError(t, err, "failed to list iptables rules")
// Debug: print all rules
t.Logf("All iptables rules in chain %s:", chainNameInputRules)
for i, rule := range rules {
t.Logf(" [%d] %s", i, rule)
}
var denyRuleIndex, acceptRuleIndex int = -1, -1
for i, rule := range rules {
if strings.Contains(rule, "DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)
if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") {
denyRuleIndex = i
}
}
if strings.Contains(rule, "ACCEPT") {
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") {
acceptRuleIndex = i
}
}
}
require.NotEqual(t, -1, denyRuleIndex, "deny rule should exist in iptables")
require.NotEqual(t, -1, acceptRuleIndex, "accept rule should exist in iptables")
require.Less(t, denyRuleIndex, acceptRuleIndex,
"deny rule should come before accept rule in iptables chain (deny at index %d, accept at index %d)",
denyRuleIndex, acceptRuleIndex)
})
}
func TestIptablesManagerIPSet(t *testing.T) { func TestIptablesManagerIPSet(t *testing.T) {
mock := &iFaceMock{ mock := &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "wg-test"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
@@ -176,7 +251,7 @@ func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName strin
func TestIptablesCreatePerformance(t *testing.T) { func TestIptablesCreatePerformance(t *testing.T) {
mock := &iFaceMock{ mock := &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "wg-test"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{

View File

@@ -341,30 +341,38 @@ func (m *AclManager) addIOFiltering(
userData := []byte(ruleId) userData := []byte(ruleId)
chain := m.chainInputRules chain := m.chainInputRules
nftRule := m.rConn.AddRule(&nftables.Rule{ rule := &nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chain, Chain: chain,
Exprs: mainExpressions, Exprs: mainExpressions,
UserData: userData, UserData: userData,
}) }
// Insert DROP rules at the beginning, append ACCEPT rules at the end
var nftRule *nftables.Rule
if action == firewall.ActionDrop {
nftRule = m.rConn.InsertRule(rule)
} else {
nftRule = m.rConn.AddRule(rule)
}
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err) return nil, fmt.Errorf(flushError, err)
} }
rule := &Rule{ ruleStruct := &Rule{
nftRule: nftRule, nftRule: nftRule,
mangleRule: m.createPreroutingRule(expressions, userData), mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset, nftSet: ipset,
ruleID: ruleId, ruleID: ruleId,
ip: ip, ip: ip,
} }
m.rules[ruleId] = rule m.rules[ruleId] = ruleStruct
if ipset != nil { if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name) m.ipsetStore.AddReferenceToIpset(ipset.Name)
} }
return rule, nil return ruleStruct, nil
} }
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule { func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {

View File

@@ -2,6 +2,7 @@ package nftables
import ( import (
"bytes" "bytes"
"encoding/binary"
"fmt" "fmt"
"net/netip" "net/netip"
"os/exec" "os/exec"
@@ -20,7 +21,7 @@ import (
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "wg-test"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
@@ -103,9 +104,8 @@ func TestNftablesManager(t *testing.T) {
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
}, },
} }
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) // Since DROP rules are inserted at position 0, the DROP rule comes first
expectedDropExprs := []expr.Any{
expectedExprs2 := []expr.Any{
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
@@ -141,7 +141,12 @@ func TestNftablesManager(t *testing.T) {
}, },
&expr.Verdict{Kind: expr.VerdictDrop}, &expr.Verdict{Kind: expr.VerdictDrop},
} }
require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions")
// Compare DROP rule at position 0 (inserted first due to InsertRule)
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedDropExprs)
// Compare connection tracking rule at position 1 (pushed down by DROP rule insertion)
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
for _, r := range rule { for _, r := range rule {
err = manager.DeletePeerRule(r) err = manager.DeletePeerRule(r)
@@ -160,10 +165,90 @@ func TestNftablesManager(t *testing.T) {
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
} }
func TestNftablesManagerRuleOrder(t *testing.T) {
// This test verifies rule insertion order in nftables peer ACLs
// We add accept rule first, then deny rule to test ordering behavior
manager, err := Create(ifaceMock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
defer func() {
err = manager.Close(nil)
require.NoError(t, err)
}()
ip := netip.MustParseAddr("100.96.0.2").Unmap()
testClient := &nftables.Conn{}
// Add accept rule first
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http")
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for the same traffic
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http")
require.NoError(t, err, "failed to add deny rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
t.Logf("Found %d rules in nftables chain", len(rules))
// Find the accept and deny rules and verify deny comes before accept
var acceptRuleIndex, denyRuleIndex int = -1, -1
for i, rule := range rules {
hasAcceptHTTPSet := false
hasDenyHTTPSet := false
hasPort80 := false
var action string
for _, e := range rule.Exprs {
// Check for set lookup
if lookup, ok := e.(*expr.Lookup); ok {
if lookup.SetName == "accept-http" {
hasAcceptHTTPSet = true
} else if lookup.SetName == "deny-http" {
hasDenyHTTPSet = true
}
}
// Check for port 80
if cmp, ok := e.(*expr.Cmp); ok {
if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
hasPort80 = true
}
}
// Check for verdict
if verdict, ok := e.(*expr.Verdict); ok {
if verdict.Kind == expr.VerdictAccept {
action = "ACCEPT"
} else if verdict.Kind == expr.VerdictDrop {
action = "DROP"
}
}
}
if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" {
t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i)
acceptRuleIndex = i
} else if hasDenyHTTPSet && hasPort80 && action == "DROP" {
t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i)
denyRuleIndex = i
}
}
require.NotEqual(t, -1, acceptRuleIndex, "accept rule should exist in nftables")
require.NotEqual(t, -1, denyRuleIndex, "deny rule should exist in nftables")
require.Less(t, denyRuleIndex, acceptRuleIndex,
"deny rule should come before accept rule in nftables chain (deny at index %d, accept at index %d)",
denyRuleIndex, acceptRuleIndex)
}
func TestNFtablesCreatePerformance(t *testing.T) { func TestNFtablesCreatePerformance(t *testing.T) {
mock := &iFaceMock{ mock := &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "wg-test"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{

View File

@@ -18,6 +18,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {

View File

@@ -27,6 +27,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {

View File

@@ -70,14 +70,13 @@ func (r RouteRules) Sort() {
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
// outgoingRules is used for hooks only outgoingRules map[netip.Addr]RuleSet
outgoingRules map[netip.Addr]RuleSet incomingDenyRules map[netip.Addr]RuleSet
// incomingRules is used for filtering and hooks incomingRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet routeRules RouteRules
routeRules RouteRules decoders sync.Pool
decoders sync.Pool wgIface common.IFaceMapper
wgIface common.IFaceMapper nativeFirewall firewall.Manager
nativeFirewall firewall.Manager
mutex sync.RWMutex mutex sync.RWMutex
@@ -186,6 +185,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}, },
nativeFirewall: nativeFirewall, nativeFirewall: nativeFirewall,
outgoingRules: make(map[netip.Addr]RuleSet), outgoingRules: make(map[netip.Addr]RuleSet),
incomingDenyRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[netip.Addr]RuleSet), incomingRules: make(map[netip.Addr]RuleSet),
wgIface: iface, wgIface: iface,
localipmanager: newLocalIPManager(), localipmanager: newLocalIPManager(),
@@ -417,10 +417,17 @@ func (m *Manager) AddPeerFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
if _, ok := m.incomingRules[r.ip]; !ok { var targetMap map[netip.Addr]RuleSet
m.incomingRules[r.ip] = make(RuleSet) if r.drop {
targetMap = m.incomingDenyRules
} else {
targetMap = m.incomingRules
} }
m.incomingRules[r.ip][r.id] = r
if _, ok := targetMap[r.ip]; !ok {
targetMap[r.ip] = make(RuleSet)
}
targetMap[r.ip][r.id] = r
m.mutex.Unlock() m.mutex.Unlock()
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
@@ -507,10 +514,24 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
if _, ok := m.incomingRules[r.ip][r.id]; !ok { var sourceMap map[netip.Addr]RuleSet
if r.drop {
sourceMap = m.incomingDenyRules
} else {
sourceMap = m.incomingRules
}
if ruleset, ok := sourceMap[r.ip]; ok {
if _, exists := ruleset[r.id]; !exists {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(ruleset, r.id)
if len(ruleset) == 0 {
delete(sourceMap, r.ip)
}
} else {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id) return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
} }
delete(m.incomingRules[r.ip], r.id)
return nil return nil
} }
@@ -572,7 +593,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil return nil
} }
// FilterOutBound filters outgoing packets // FilterOutbound filters outgoing packets
func (m *Manager) FilterOutbound(packetData []byte, size int) bool { func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
return m.filterOutbound(packetData, size) return m.filterOutbound(packetData, size)
} }
@@ -761,7 +782,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// handleLocalTraffic handles local traffic. // handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
if blocked { if blocked {
_, pnum := getProtocolFromPacket(d) _, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
@@ -971,26 +992,28 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
icmpType == layers.ICMPv4TypeTimeExceeded icmpType == layers.ICMPv4TypeTimeExceeded
} }
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) { func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
if m.isSpecialICMP(d) { if m.isSpecialICMP(d) {
return nil, false return nil, false
} }
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok {
return mgmtId, filter return mgmtId, filter
} }
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter return mgmtId, filter
} }
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter
}
// Default policy: DROP ALL
return nil, true return nil, true
} }
@@ -1013,6 +1036,7 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) { func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1] payloadLayer := d.decoded[1]
for _, rule := range rules { for _, rule := range rules {
if rule.matchByIP && ip.Compare(rule.ip) != 0 { if rule.matchByIP && ip.Compare(rule.ip) != 0 {
continue continue
@@ -1045,6 +1069,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true return rule.mgmtId, rule.drop, true
} }
} }
return nil, false, false return nil, false, false
} }
@@ -1116,6 +1141,7 @@ func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook fu
m.mutex.Lock() m.mutex.Lock()
if in { if in {
// Incoming UDP hooks are stored in allow rules map
if _, ok := m.incomingRules[r.ip]; !ok { if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule) m.incomingRules[r.ip] = make(map[string]PeerRule)
} }
@@ -1136,6 +1162,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
// Check incoming hooks (stored in allow rules)
for _, arr := range m.incomingRules { for _, arr := range m.incomingRules {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
@@ -1144,6 +1171,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
} }
} }
} }
// Check outgoing hooks
for _, arr := range m.outgoingRules { for _, arr := range m.outgoingRules {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {

View File

@@ -458,6 +458,31 @@ func TestPeerACLFiltering(t *testing.T) {
ruleAction: fw.ActionDrop, ruleAction: fw.ActionDrop,
shouldBeBlocked: true, shouldBeBlocked: true,
}, },
{
name: "Peer ACL - Drop rule should override accept all rule",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 22,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{22}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Peer ACL - Drop all traffic from specific IP",
srcIP: "100.10.0.99",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.99",
ruleProto: fw.ProtocolALL,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
} }
t.Run("Implicit DROP (no rules)", func(t *testing.T) { t.Run("Implicit DROP (no rules)", func(t *testing.T) {
@@ -468,13 +493,11 @@ func TestPeerACLFiltering(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop { if tc.ruleAction == fw.ActionDrop {
// add general accept rule to test drop rule // add general accept rule for the same IP to test drop rule precedence
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
rules, err := manager.AddPeerFiltering( rules, err := manager.AddPeerFiltering(
nil, nil,
net.ParseIP("0.0.0.0"), net.ParseIP(tc.ruleIP),
fw.ProtocolALL, fw.ProtocolALL,
nil, nil,
nil, nil,

View File

@@ -136,9 +136,22 @@ func TestManagerDeleteRule(t *testing.T) {
return return
} }
// Check rules exist in appropriate maps
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; !ok { peerRule, ok := r.(*PeerRule)
t.Errorf("rule2 is not in the incomingRules") if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule exists in deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
} else {
_, found = m.incomingRules[ip][r.ID()]
}
if !found {
t.Errorf("rule2 is not in the expected rules map")
} }
} }
@@ -150,9 +163,22 @@ func TestManagerDeleteRule(t *testing.T) {
} }
} }
// Check rules are removed from appropriate maps
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; ok { peerRule, ok := r.(*PeerRule)
t.Errorf("rule2 is not in the incomingRules") if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule is removed from deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
} else {
_, found = m.incomingRules[ip][r.ID()]
}
if found {
t.Errorf("rule2 should be removed from the rules map")
} }
} }
} }
@@ -196,16 +222,17 @@ func TestAddUDPPacketHook(t *testing.T) {
var addedRule PeerRule var addedRule PeerRule
if tt.in { if tt.in {
// Incoming UDP hooks are stored in allow rules map
if len(manager.incomingRules[tt.ip]) != 1 { if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
return return
} }
for _, rule := range manager.incomingRules[tt.ip] { for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule addedRule = rule
} }
} else { } else {
if len(manager.outgoingRules) != 1 { if len(manager.outgoingRules[tt.ip]) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
return return
} }
for _, rule := range manager.outgoingRules[tt.ip] { for _, rule := range manager.outgoingRules[tt.ip] {
@@ -261,8 +288,8 @@ func TestManagerReset(t *testing.T) {
return return
} }
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 { if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
t.Errorf("rules is not empty") t.Errorf("rules are not empty")
} }
} }

View File

@@ -314,7 +314,7 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
trace.AddResult(StageRouting, "Packet destined for local delivery", true) trace.AddResult(StageRouting, "Packet destined for local delivery", true)
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) ruleId, blocked := m.peerACLsBlock(srcIP, d, packetData)
strRuleId := "<no id>" strRuleId := "<no id>"
if ruleId != nil { if ruleId != nil {

View File

@@ -55,11 +55,11 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client" mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
sProto "github.com/netbirdio/netbird/shared/signal/proto" sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@@ -254,6 +254,7 @@ func NewEngine(
} }
engine.stateManager = statemanager.New(path) engine.stateManager = statemanager.New(path)
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine return engine
} }
@@ -1330,52 +1331,17 @@ func (e *Engine) receiveSignalEvents() {
} }
switch msg.GetBody().Type { switch msg.GetBody().Type {
case sProto.Body_OFFER: case sProto.Body_OFFER, sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg) offerAnswer, err := convertToOfferAnswer(msg)
if err != nil { if err != nil {
return err return err
} }
var rosenpassPubKey []byte if msg.Body.Type == sProto.Body_OFFER {
rosenpassAddr := "" conn.OnRemoteOffer(*offerAnswer)
if msg.GetBody().GetRosenpassConfig() != nil { } else {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey() conn.OnRemoteAnswer(*offerAnswer)
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
} }
conn.OnRemoteOffer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {
return err
}
var rosenpassPubKey []byte
rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
}
conn.OnRemoteAnswer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_CANDIDATE: case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload) candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
if err != nil { if err != nil {
@@ -2073,3 +2039,44 @@ func createFile(path string) error {
} }
return file.Close() return file.Close()
} }
func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {
return nil, err
}
var (
rosenpassPubKey []byte
rosenpassAddr string
)
if cfg := msg.GetBody().GetRosenpassConfig(); cfg != nil {
rosenpassPubKey = cfg.GetRosenpassPubKey()
rosenpassAddr = cfg.GetRosenpassServerAddr()
}
// Handle optional SessionID
var sessionID *peer.ICESessionID
if sessionBytes := msg.GetBody().GetSessionId(); sessionBytes != nil {
if id, err := peer.ICESessionIDFromBytes(sessionBytes); err != nil {
log.Warnf("Invalid session ID in message: %v", err)
sessionID = nil // Set to nil if conversion fails
} else {
sessionID = &id
}
}
offerAnswer := peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
SessionID: sessionID,
}
return &offerAnswer, nil
}

View File

@@ -27,6 +27,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
@@ -44,8 +45,6 @@ import (
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -55,8 +54,10 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server" signalServer "github.com/netbirdio/netbird/signal/server"
@@ -1514,15 +1515,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) { func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper() t.Helper()
config := &types.Config{ config := &config.Config{
Stuns: []*types.Host{}, Stuns: []*config.Host{},
TURNConfig: &types.TURNConfig{}, TURNConfig: &config.TURNConfig{},
Relay: &types.Relay{ Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"}, Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour}, CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222", Secret: "222222222222222222",
}, },
Signal: &types.Host{ Signal: &config.Host{
Proto: "http", Proto: "http",
URI: "localhost:10000", URI: "localhost:10000",
}, },

View File

@@ -24,8 +24,8 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/id" "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker" "github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
@@ -200,19 +200,11 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.wg.Add(1) conn.wg.Add(1)
go func() { go func() {
defer conn.wg.Done() defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx) conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done(conn.ctx) conn.semaphore.Done(conn.ctx)
conn.dumpState.SendOffer() conn.guard.Start(conn.ctx, conn.onGuardEvent)
if err := conn.handshaker.sendOffer(); err != nil {
conn.Log.Errorf("failed to send initial offer: %v", err)
}
conn.wg.Add(1)
go func() {
conn.guard.Start(conn.ctx, conn.onGuardEvent)
conn.wg.Done()
}()
}() }()
conn.opened = true conn.opened = true
return nil return nil
@@ -274,10 +266,10 @@ func (conn *Conn) Close(signalToRemote bool) {
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool { func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
conn.dumpState.RemoteAnswer() conn.dumpState.RemoteAnswer()
conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteAnswer(answer) conn.handshaker.OnRemoteAnswer(answer)
} }
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
@@ -296,10 +288,10 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
conn.onDisconnected = handler conn.onDisconnected = handler
} }
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool { func (conn *Conn) OnRemoteOffer(offer OfferAnswer) {
conn.dumpState.RemoteOffer() conn.dumpState.RemoteOffer()
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteOffer(offer) conn.handshaker.OnRemoteOffer(offer)
} }
// WgConfig returns the WireGuard config // WgConfig returns the WireGuard config
@@ -548,7 +540,6 @@ func (conn *Conn) onRelayDisconnected() {
} }
func (conn *Conn) onGuardEvent() { func (conn *Conn) onGuardEvent() {
conn.Log.Debugf("send offer to peer")
conn.dumpState.SendOffer() conn.dumpState.SendOffer()
if err := conn.handshaker.SendOffer(); err != nil { if err := conn.handshaker.SendOffer(); err != nil {
conn.Log.Errorf("failed to send offer: %v", err) conn.Log.Errorf("failed to send offer: %v", err)
@@ -672,7 +663,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
} }
}() }()
if conn.statusICE.Get() == worker.StatusDisconnected { if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false return false
} }

View File

@@ -1,9 +1,9 @@
package peer package peer
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"sync"
"testing" "testing"
"time" "time"
@@ -79,31 +79,30 @@ func TestConn_OnRemoteOffer(t *testing.T) {
return return
} }
wg := sync.WaitGroup{} onNewOffeChan := make(chan struct{})
wg.Add(2)
go func() {
<-conn.handshaker.remoteOffersCh
wg.Done()
}()
go func() { conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
for { onNewOffeChan <- struct{}{}
accepted := conn.OnRemoteOffer(OfferAnswer{ })
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
if accepted {
wg.Done()
return
}
}
}()
wg.Wait() conn.OnRemoteOffer(OfferAnswer{
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
select {
case <-onNewOffeChan:
// success
case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out")
}
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
@@ -119,31 +118,29 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
return return
} }
wg := sync.WaitGroup{} onNewOffeChan := make(chan struct{})
wg.Add(2)
go func() {
<-conn.handshaker.remoteAnswerCh
wg.Done()
}()
go func() { conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
for { onNewOffeChan <- struct{}{}
accepted := conn.OnRemoteAnswer(OfferAnswer{ })
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
if accepted {
wg.Done()
return
}
}
}()
wg.Wait() conn.OnRemoteAnswer(OfferAnswer{
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
select {
case <-onNewOffeChan:
// success
case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out")
}
} }
func TestConn_presharedKey(t *testing.T) { func TestConn_presharedKey(t *testing.T) {

View File

@@ -19,7 +19,6 @@ type isConnectedFunc func() bool
// - Relayed connection disconnected // - Relayed connection disconnected
// - ICE candidate changes // - ICE candidate changes
type Guard struct { type Guard struct {
Reconnect chan struct{}
log *log.Entry log *log.Entry
isConnectedOnAllWay isConnectedFunc isConnectedOnAllWay isConnectedFunc
timeout time.Duration timeout time.Duration
@@ -30,7 +29,6 @@ type Guard struct {
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{ return &Guard{
Reconnect: make(chan struct{}, 1),
log: log, log: log,
isConnectedOnAllWay: isConnectedFn, isConnectedOnAllWay: isConnectedFn,
timeout: timeout, timeout: timeout,
@@ -41,6 +39,7 @@ func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Durati
} }
func (g *Guard) Start(ctx context.Context, eventCallback func()) { func (g *Guard) Start(ctx context.Context, eventCallback func()) {
g.log.Infof("starting guard for reconnection with MaxInterval: %s", g.timeout)
g.reconnectLoopWithRetry(ctx, eventCallback) g.reconnectLoopWithRetry(ctx, eventCallback)
} }
@@ -61,17 +60,14 @@ func (g *Guard) SetICEConnDisconnected() {
// reconnectLoopWithRetry periodically check the connection status. // reconnectLoopWithRetry periodically check the connection status.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported // Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
waitForInitialConnectionTry(ctx)
srReconnectedChan := g.srWatcher.NewListener() srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan) defer g.srWatcher.RemoveListener(srReconnectedChan)
ticker := g.prepareExponentTicker(ctx) ticker := g.initialTicker(ctx)
defer ticker.Stop() defer ticker.Stop()
tickerChannel := ticker.C tickerChannel := ticker.C
g.log.Infof("start reconnect loop...")
for { for {
select { select {
case t := <-tickerChannel: case t := <-tickerChannel:
@@ -85,7 +81,6 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
if !g.isConnectedOnAllWay() { if !g.isConnectedOnAllWay() {
callback() callback()
} }
case <-g.relayedConnDisconnected: case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, reset reconnection ticker") g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop() ticker.Stop()
@@ -111,6 +106,20 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
} }
} }
// initialTicker give chance to the peer to establish the initial connection.
func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 3 * time.Second,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: g.timeout,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
return backoff.NewTicker(bo)
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{ bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
@@ -126,13 +135,3 @@ func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
return ticker return ticker
} }
// Give chance to the peer to establish the initial connection.
// With it, we can decrease to send necessary offer
func waitForInitialConnectionTry(ctx context.Context) {
select {
case <-ctx.Done():
return
case <-time.After(3 * time.Second):
}
}

View File

@@ -39,6 +39,15 @@ type OfferAnswer struct {
// relay server address // relay server address
RelaySrvAddress string RelaySrvAddress string
// SessionID is the unique identifier of the session, used to discard old messages
SessionID *ICESessionID
}
func (oa *OfferAnswer) SessionIDString() string {
if oa.SessionID == nil {
return "unknown"
}
return oa.SessionID.String()
} }
type Handshaker struct { type Handshaker struct {
@@ -74,21 +83,25 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
func (h *Handshaker) Listen(ctx context.Context) { func (h *Handshaker) Listen(ctx context.Context) {
for { for {
h.log.Info("wait for remote offer confirmation") select {
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation(ctx) case remoteOfferAnswer := <-h.remoteOffersCh:
if err != nil { // received confirmation from the remote peer -> ready to proceed
var connectionClosedError *ConnectionClosedError if err := h.sendAnswer(); err != nil {
if errors.As(err, &connectionClosedError) { h.log.Errorf("failed to send remote offer confirmation: %s", err)
h.log.Info("exit from handshaker") continue
return
} }
h.log.Errorf("failed to received remote offer confirmation: %s", err) for _, listener := range h.onNewOfferListeners {
continue listener(&remoteOfferAnswer)
} }
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
h.log.Infof("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort) case remoteOfferAnswer := <-h.remoteAnswerCh:
for _, listener := range h.onNewOfferListeners { h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
go listener(remoteOfferAnswer) for _, listener := range h.onNewOfferListeners {
listener(&remoteOfferAnswer)
}
case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers")
return
} }
} }
} }
@@ -101,43 +114,27 @@ func (h *Handshaker) SendOffer() error {
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool { func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) {
select { select {
case h.remoteOffersCh <- offer: case h.remoteOffersCh <- offer:
return true return
default: default:
h.log.Warnf("OnRemoteOffer skipping message because is not ready") h.log.Warnf("skipping remote offer message because receiver not ready")
// connection might not be ready yet to receive so we ignore the message // connection might not be ready yet to receive so we ignore the message
return false return
} }
} }
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool { func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) {
select { select {
case h.remoteAnswerCh <- answer: case h.remoteAnswerCh <- answer:
return true return
default: default:
// connection might not be ready yet to receive so we ignore the message // connection might not be ready yet to receive so we ignore the message
h.log.Debugf("OnRemoteAnswer skipping message because is not ready") h.log.Warnf("skipping remote answer message because receiver not ready")
return false return
}
}
func (h *Handshaker) waitForRemoteOfferConfirmation(ctx context.Context) (*OfferAnswer, error) {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed
if err := h.sendAnswer(); err != nil {
return nil, err
}
return &remoteOfferAnswer, nil
case remoteOfferAnswer := <-h.remoteAnswerCh:
return &remoteOfferAnswer, nil
case <-ctx.Done():
// closed externally
return nil, NewConnectionClosedError(h.config.Key)
} }
} }
@@ -147,43 +144,34 @@ func (h *Handshaker) sendOffer() error {
return ErrSignalIsNotReady return ErrSignalIsNotReady
} }
iceUFrag, icePwd := h.ice.GetLocalUserCredentials() offer := h.buildOfferAnswer()
offer := OfferAnswer{ h.log.Infof("sending offer with serial: %s", offer.SessionIDString())
IceCredentials: IceCredentials{iceUFrag, icePwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr,
}
addr, err := h.relay.RelayInstanceAddress()
if err == nil {
offer.RelaySrvAddress = addr
}
return h.signaler.SignalOffer(offer, h.config.Key) return h.signaler.SignalOffer(offer, h.config.Key)
} }
func (h *Handshaker) sendAnswer() error { func (h *Handshaker) sendAnswer() error {
h.log.Infof("sending answer") answer := h.buildOfferAnswer()
uFrag, pwd := h.ice.GetLocalUserCredentials() h.log.Infof("sending answer with serial: %s", answer.SessionIDString())
return h.signaler.SignalAnswer(answer, h.config.Key)
}
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer := OfferAnswer{ answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd}, IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort, WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(), Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey, RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr, RosenpassAddr: h.config.RosenpassConfig.Addr,
SessionID: &sid,
} }
addr, err := h.relay.RelayInstanceAddress()
if err == nil { if addr, err := h.relay.RelayInstanceAddress(); err == nil {
answer.RelaySrvAddress = addr answer.RelaySrvAddress = addr
} }
err = h.signaler.SignalAnswer(answer, h.config.Key) return answer
if err != nil {
return err
}
return nil
} }

View File

@@ -0,0 +1,47 @@
package peer
import (
"crypto/rand"
"encoding/hex"
"fmt"
"io"
)
const sessionIDSize = 5
type ICESessionID string
// NewICESessionID generates a new session ID for distinguishing sessions
func NewICESessionID() (ICESessionID, error) {
b := make([]byte, sessionIDSize)
if _, err := io.ReadFull(rand.Reader, b); err != nil {
return "", fmt.Errorf("failed to generate session ID: %w", err)
}
return ICESessionID(hex.EncodeToString(b)), nil
}
func ICESessionIDFromBytes(b []byte) (ICESessionID, error) {
if len(b) != sessionIDSize {
return "", fmt.Errorf("invalid session ID length: %d", len(b))
}
return ICESessionID(hex.EncodeToString(b)), nil
}
// Bytes returns the raw bytes of the session ID for protobuf serialization
func (id ICESessionID) Bytes() ([]byte, error) {
if len(id) == 0 {
return nil, fmt.Errorf("ICE session ID is empty")
}
b, err := hex.DecodeString(string(id))
if err != nil {
return nil, fmt.Errorf("invalid ICE session ID encoding: %w", err)
}
if len(b) != sessionIDSize {
return nil, fmt.Errorf("invalid ICE session ID length: expected %d bytes, got %d", sessionIDSize, len(b))
}
return b, nil
}
func (id ICESessionID) String() string {
return string(id)
}

View File

@@ -2,6 +2,7 @@ package peer
import ( import (
"github.com/pion/ice/v3" "github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
@@ -45,6 +46,10 @@ func (s *Signaler) Ready() bool {
// SignalOfferAnswer signals either an offer or an answer to remote peer // SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error { func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
}
msg, err := signal.MarshalCredential( msg, err := signal.MarshalCredential(
s.wgPrivateKey, s.wgPrivateKey,
offerAnswer.WgListenPort, offerAnswer.WgListenPort,
@@ -56,13 +61,13 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
bodyType, bodyType,
offerAnswer.RosenpassPubKey, offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassAddr, offerAnswer.RosenpassAddr,
offerAnswer.RelaySrvAddress) offerAnswer.RelaySrvAddress,
sessionIDBytes)
if err != nil { if err != nil {
return err return err
} }
err = s.signal.Send(msg) if err = s.signal.Send(msg); err != nil {
if err != nil {
return err return err
} }

View File

@@ -42,8 +42,18 @@ type WorkerICE struct {
statusRecorder *Status statusRecorder *Status
hasRelayOnLocally bool hasRelayOnLocally bool
agent *ice.Agent agent *ice.Agent
muxAgent sync.Mutex agentDialerCancel context.CancelFunc
agentConnecting bool // while it is true, drop all incoming offers
lastSuccess time.Time // with this avoid the too frequent ICE agent recreation
// remoteSessionID represents the peer's session identifier from the latest remote offer.
remoteSessionID ICESessionID
// sessionID is used to track the current session ID of the ICE agent
// increase by one when disconnecting the agent
// with it the remote peer can discard the already deprecated offer/answer
// Without it the remote peer may recreate a workable ICE connection
sessionID ICESessionID
muxAgent sync.Mutex
StunTurn []*stun.URI StunTurn []*stun.URI
@@ -57,6 +67,11 @@ type WorkerICE struct {
} }
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) { func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
sessionID, err := NewICESessionID()
if err != nil {
return nil, err
}
w := &WorkerICE{ w := &WorkerICE{
ctx: ctx, ctx: ctx,
log: log, log: log,
@@ -67,6 +82,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
hasRelayOnLocally: hasRelayOnLocally, hasRelayOnLocally: hasRelayOnLocally,
lastKnownState: ice.ConnectionStateDisconnected, lastKnownState: ice.ConnectionStateDisconnected,
sessionID: sessionID,
} }
localUfrag, localPwd, err := icemaker.GenerateICECredentials() localUfrag, localPwd, err := icemaker.GenerateICECredentials()
@@ -79,15 +95,35 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
} }
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE") w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Lock() w.muxAgent.Lock()
if w.agent != nil { if w.agentConnecting {
w.log.Debugf("agent already exists, skipping the offer") w.log.Debugf("agent connection is in progress, skipping the offer")
w.muxAgent.Unlock() w.muxAgent.Unlock()
return return
} }
if w.agent != nil {
// backward compatibility with old clients that do not send session ID
if remoteOfferAnswer.SessionID == nil {
w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Unlock()
return
}
w.log.Debugf("agent already exists, recreate the connection")
w.agentDialerCancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
// todo consider to switch to Relay connection while establishing a new ICE connection
}
var preferredCandidateTypes []ice.CandidateType var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
preferredCandidateTypes = icemaker.CandidateTypesP2P() preferredCandidateTypes = icemaker.CandidateTypesP2P()
@@ -96,36 +132,124 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
} }
w.log.Debugf("recreate ICE agent") w.log.Debugf("recreate ICE agent")
agentCtx, agentCancel := context.WithCancel(w.ctx) dialerCtx, dialerCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(agentCancel, preferredCandidateTypes) agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
if err != nil { if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err) w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock() w.muxAgent.Unlock()
return return
} }
w.sentExtraSrflx = false
w.agent = agent w.agent = agent
w.agentDialerCancel = dialerCancel
w.agentConnecting = true
w.muxAgent.Unlock() w.muxAgent.Unlock()
w.log.Debugf("gather candidates") go w.connect(dialerCtx, agent, remoteOfferAnswer)
err = w.agent.GatherCandidates() }
if err != nil {
w.log.Debugf("failed to gather candidates: %s", err) // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String())
if w.agent == nil {
w.log.Warnf("ICE Agent is not initialized yet")
return
}
if candidateViaRoutes(candidate, haRoutes) {
return
}
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
}
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
return w.localUfrag, w.localPwd
}
func (w *WorkerICE) InProgress() bool {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.agentConnecting
}
func (w *WorkerICE) Close() {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent == nil {
return
}
w.agentDialerCancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) {
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
}
if err := agent.OnCandidate(w.onICECandidate); err != nil {
return nil, err
}
if err := agent.OnConnectionStateChange(w.onConnectionStateChange(agent, dialerCancel)); err != nil {
return nil, err
}
if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil {
return nil, err
}
if err := agent.OnSuccessfulSelectedPairBindingResponse(w.onSuccessfulSelectedPairBindingResponse); err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil
}
func (w *WorkerICE) SessionID() ICESessionID {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.sessionID
}
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
func (w *WorkerICE) connect(ctx context.Context, agent *ice.Agent, remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("gather candidates")
if err := agent.GatherCandidates(); err != nil {
w.log.Warnf("failed to gather candidates: %s", err)
w.closeAgent(agent, w.agentDialerCancel)
return return
} }
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
w.log.Debugf("turn agent dial") w.log.Debugf("turn agent dial")
remoteConn, err := w.turnAgentDial(agentCtx, remoteOfferAnswer) remoteConn, err := w.turnAgentDial(ctx, remoteOfferAnswer)
if err != nil { if err != nil {
w.log.Debugf("failed to dial the remote peer: %s", err) w.log.Debugf("failed to dial the remote peer: %s", err)
w.closeAgent(agent, w.agentDialerCancel)
return return
} }
w.log.Debugf("agent dial succeeded") w.log.Debugf("agent dial succeeded")
pair, err := w.agent.GetSelectedCandidatePair() pair, err := agent.GetSelectedCandidatePair()
if err != nil { if err != nil {
w.closeAgent(agent, w.agentDialerCancel)
return return
} }
@@ -152,114 +276,38 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RelayedOnLocal: isRelayCandidate(pair.Local), RelayedOnLocal: isRelayCandidate(pair.Local),
} }
w.log.Debugf("on ICE conn is ready to use") w.log.Debugf("on ICE conn is ready to use")
go w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. w.log.Infof("connection succeeded with offer session: %s", remoteOfferAnswer.SessionIDString())
func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
w.muxAgent.Lock() w.muxAgent.Lock()
defer w.muxAgent.Unlock() w.agentConnecting = false
w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String()) w.lastSuccess = time.Now()
if w.agent == nil { if remoteOfferAnswer.SessionID != nil {
w.log.Warnf("ICE Agent is not initialized yet") w.remoteSessionID = *remoteOfferAnswer.SessionID
return
} }
w.muxAgent.Unlock()
if candidateViaRoutes(candidate, haRoutes) { // todo: the potential problem is a race between the onConnectionStateChange
return w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
}
err := w.agent.AddRemoteCandidate(candidate)
if err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
} }
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) { func (w *WorkerICE) closeAgent(agent *ice.Agent, cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.localUfrag, w.localPwd
}
func (w *WorkerICE) Close() {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent == nil {
return
}
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) {
w.sentExtraSrflx = false
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
}
err = agent.OnCandidate(w.onICECandidate)
if err != nil {
return nil, err
}
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
}
w.closeAgent(agentCancel)
default:
return
}
})
if err != nil {
return nil, err
}
err = agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair)
if err != nil {
return nil, err
}
err = agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
err := w.statusRecorder.UpdateLatency(w.config.Key, p.Latency())
if err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
})
if err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil
}
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
cancel() cancel()
if w.agent == nil { if err := agent.Close(); err != nil {
return
}
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
w.agent = nil
w.muxAgent.Lock()
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
if w.agent == agent {
w.agent = nil
w.agentConnecting = false
}
w.muxAgent.Unlock()
} }
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
@@ -331,6 +379,32 @@ func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidat
w.config.Key) w.config.Key)
} }
func (w *WorkerICE) onConnectionStateChange(agent *ice.Agent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
return func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
}
w.closeAgent(agent, dialerCancel)
default:
return
}
}
}
func (w *WorkerICE) onSuccessfulSelectedPairBindingResponse(pair *ice.CandidatePair) {
if err := w.statusRecorder.UpdateLatency(w.config.Key, pair.Latency()); err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
}
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true return true

View File

@@ -14,6 +14,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -32,7 +33,6 @@ import (
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server" signalServer "github.com/netbirdio/netbird/signal/server"
@@ -267,10 +267,10 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Helper() t.Helper()
dataDir := t.TempDir() dataDir := t.TempDir()
config := &types.Config{ config := &config.Config{
Stuns: []*types.Host{}, Stuns: []*config.Host{},
TURNConfig: &types.TURNConfig{}, TURNConfig: &config.TURNConfig{},
Signal: &types.Host{ Signal: &config.Host{
Proto: "http", Proto: "http",
URI: signalAddr, URI: signalAddr,
}, },

7
go.mod
View File

@@ -65,6 +65,7 @@ require (
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -102,7 +103,7 @@ require (
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/net v0.39.0 golang.org/x/net v0.39.0
golang.org/x/oauth2 v0.24.0 golang.org/x/oauth2 v0.27.0
golang.org/x/sync v0.13.0 golang.org/x/sync v0.13.0
golang.org/x/term v0.31.0 golang.org/x/term v0.31.0
google.golang.org/api v0.177.0 google.golang.org/api v0.177.0
@@ -125,6 +126,7 @@ require (
github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
@@ -144,7 +146,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/containerd v1.7.26 // indirect github.com/containerd/containerd v1.7.27 // indirect
github.com/containerd/log v0.1.0 // indirect github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect
@@ -220,6 +222,7 @@ require (
github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/rymdport/portal v0.3.0 // indirect github.com/rymdport/portal v0.3.0 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect

20
go.sum
View File

@@ -66,11 +66,14 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/Microsoft/hcsshim v0.12.3 h1:LS9NXqXhMoqNCplK1ApmVSfB4UnVLRDWRapB6EIlxE0= github.com/Microsoft/hcsshim v0.12.3 h1:LS9NXqXhMoqNCplK1ApmVSfB4UnVLRDWRapB6EIlxE0=
github.com/Microsoft/hcsshim v0.12.3/go.mod h1:Iyl1WVpZzr+UkzjekHZbV8o5Z9ZkxNGx6CtY2Qg/JVQ= github.com/Microsoft/hcsshim v0.12.3/go.mod h1:Iyl1WVpZzr+UkzjekHZbV8o5Z9ZkxNGx6CtY2Qg/JVQ=
github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo= github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I= github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ=
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
@@ -116,6 +119,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqOes/6LfM= github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqOes/6LfM=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -142,8 +146,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/containerd v1.7.26 h1:3cs8K2RHlMQaPifLqgRyI4VBkoldNdEw62cb7qQga7k= github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
github.com/containerd/containerd v1.7.26/go.mod h1:m4JU0E+h0ebbo9yXD7Hyt+sWnc8tChm7MudCjj4jRvQ= github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A= github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
@@ -416,6 +420,7 @@ github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/X
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e h1:LvL4XsI70QxOGHed6yhQtAU34Kx3Qq2wwBzGFKY8zKk= github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e h1:LvL4XsI70QxOGHed6yhQtAU34Kx3Qq2wwBzGFKY8zKk=
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw= github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
@@ -516,6 +521,8 @@ github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7g
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/oapi-codegen/runtime v1.1.2 h1:P2+CubHq8fO4Q6fV1tqDBZHCwpVpvPg7oKiYzQgXIyI=
github.com/oapi-codegen/runtime v1.1.2/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg=
github.com/okta/okta-sdk-golang/v2 v2.18.0 h1:cfDasMb7CShbZvOrF6n+DnLevWwiHgedWMGJ8M8xKDc= github.com/okta/okta-sdk-golang/v2 v2.18.0 h1:cfDasMb7CShbZvOrF6n+DnLevWwiHgedWMGJ8M8xKDc=
github.com/okta/okta-sdk-golang/v2 v2.18.0/go.mod h1:dz30v3ctAiMb7jpsCngGfQUAEGm1/NsWT92uTbNDQIs= github.com/okta/okta-sdk-golang/v2 v2.18.0/go.mod h1:dz30v3ctAiMb7jpsCngGfQUAEGm1/NsWT92uTbNDQIs=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
@@ -588,8 +595,8 @@ github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so= github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=
github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM= github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM=
github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
@@ -627,6 +634,7 @@ github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.8.1/go.mod h1:o0Pch8wJ9BVSWGQMbra6iw0oQ5oktSIBaujf1rJH9Ns= github.com/spf13/viper v1.8.1/go.mod h1:o0Pch8wJ9BVSWGQMbra6iw0oQ5oktSIBaujf1rJH9Ns=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE= github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q= github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ= github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=
@@ -868,8 +876,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ
golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M=
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@@ -2,88 +2,40 @@ package cmd
import ( import (
"context" "context"
"crypto/tls"
"encoding/json" "encoding/json"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"net"
"net/http" "net/http"
"net/netip"
"net/url" "net/url"
"os" "os"
"os/signal"
"path" "path"
"slices"
"strings" "strings"
"time" "syscall"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/internals/server"
"github.com/netbirdio/netbird/management/server" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/auth"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
) )
// ManagementLegacyPort is the port that was used before by the Management gRPC server. var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server {
// It is used for backward compatibility now. return server.NewServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
const ManagementLegacyPort = 33073 }
func SetNewServer(fn func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server) {
newServer = fn
}
var ( var (
mgmtPort int config *nbconfig.Config
mgmtMetricsPort int
mgmtLetsencryptDomain string
mgmtSingleAccModeDomain string
certFile string
certKey string
config *types.Config
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
mgmtCmd = &cobra.Command{ mgmtCmd = &cobra.Command{
Use: "management", Use: "management",
@@ -102,9 +54,9 @@ var (
// detect whether user specified a port // detect whether user specified a port
userPort := cmd.Flag("port").Changed userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(ctx, types.MgmtConfigPath) config, err = loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil { if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", types.MgmtConfigPath, err) return fmt.Errorf("failed reading provided config file: %s: %v", nbconfig.MgmtConfigPath, err)
} }
if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed { if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed {
@@ -151,356 +103,38 @@ var (
return fmt.Errorf("failed creating datadir: %s: %v", config.Datadir, err) return fmt.Errorf("failed creating datadir: %s: %v", config.Datadir, err)
} }
} }
appMetrics, err := telemetry.NewDefaultAppMetrics(cmd.Context())
if err != nil {
return err
}
err = appMetrics.Expose(ctx, mgmtMetricsPort, "/metrics")
if err != nil {
return err
}
integrationMetrics, err := integrations.InitIntegrationMetrics(ctx, appMetrics)
if err != nil {
return err
}
store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics, false)
if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager(appMetrics)
var idpManager idp.Manager
if config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(ctx, *config.IdpManagerConfig, appMetrics)
if err != nil {
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
}
}
if disableSingleAccMode { if disableSingleAccMode {
mgmtSingleAccModeDomain = "" mgmtSingleAccModeDomain = ""
} }
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
return fmt.Errorf("initialize database: %s", err)
}
if config.DataStoreEncryptionKey != key { srv := newServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
log.WithContext(ctx).Infof("update config with activity store key") go func() {
config.DataStoreEncryptionKey = key if err := srv.Start(cmd.Context()); err != nil {
err := updateMgmtConfig(ctx, types.MgmtConfigPath, config) log.Fatalf("Server error: %v", err)
}
}()
stopChan := make(chan os.Signal, 1)
signal.Notify(stopChan, os.Interrupt, syscall.SIGTERM)
select {
case <-stopChan:
log.Info("Received shutdown signal, stopping server...")
err = srv.Stop()
if err != nil { if err != nil {
return fmt.Errorf("write out store encryption key: %s", err) log.Errorf("Failed to stop server gracefully: %v", err)
} }
case err := <-srv.Errors():
log.Fatalf("Server stopped unexpectedly: %v", err)
} }
geo, err := geolocation.NewGeolocation(ctx, config.Datadir, !disableGeoliteUpdate)
if err != nil {
log.WithContext(ctx).Warnf("could not initialize geolocation service. proceeding without geolocation support: %v", err)
} else {
log.WithContext(ctx).Infof("geolocation service has been initialized from %s", config.Datadir)
}
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
if err != nil {
return fmt.Errorf("initialize integrated peer validator: %v", err)
}
permissionsManager := integrations.InitPermissionsManager(store)
userManager := users.NewManager(store)
extraSettingsManager := integrations.NewManager(eventStore)
settingsManager := settings.NewManager(store, userManager, extraSettingsManager, permissionsManager)
peersManager := peers.NewManager(store, permissionsManager)
proxyController := integrations.NewController(store)
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
if err != nil {
return fmt.Errorf("build default manager: %v", err)
}
groupsManager := groups.NewManager(store, permissionsManager, accountManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager, groupsManager)
trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.WithContext(ctx).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers
}
trustedHTTPProxies := config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.WithContext(ctx).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
}
realipOpts := []realip.Option{
realip.WithTrustedPeers(trustedPeers),
realip.WithTrustedProxies(trustedHTTPProxies),
realip.WithTrustedProxiesCount(trustedProxiesCount),
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
}
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
var certManager *autocert.Manager
var tlsConfig *tls.Config
tlsEnabled := false
if config.HttpConfig.LetsEncryptDomain != "" {
certManager, err = encryption.CreateCertManager(config.Datadir, config.HttpConfig.LetsEncryptDomain)
if err != nil {
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
tlsEnabled = true
} else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey)
if err != nil {
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
return err
}
transportCredentials := credentials.NewTLS(tlsConfig)
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
tlsEnabled = true
}
authManager := auth.NewManager(store,
config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience,
config.HttpConfig.AuthKeysLocation,
config.HttpConfig.AuthUserIDClaim,
config.GetAuthAudiences(),
config.HttpConfig.IdpSignKeyRefreshEnabled)
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager)
routersManager := routers.NewManager(store, permissionsManager, accountManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, authManager, appMetrics, integratedPeerValidator, proxyController, permissionsManager, peersManager, settingsManager)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
ephemeralManager := server.NewEphemeralManager(store, accountManager)
ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
installationID, err := getInstallationID(ctx, store)
if err != nil {
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
return err
}
if !disableMetrics {
idpManager := "disabled"
if config.IdpManagerConfig != nil && config.IdpManagerConfig.ManagerType != "" {
idpManager = config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager)
go metricsWorker.Run(ctx)
}
var compatListener net.Listener
if mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
compatListener, err = serveGRPC(ctx, gRPCAPIHandler, ManagementLegacyPort)
if err != nil {
return err
}
log.WithContext(ctx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := handlerFunc(gRPCAPIHandler, httpAPIHandler)
var listener net.Listener
if certManager != nil {
// a call to certManager.Listener() always creates a new listener so we do it once
cml := certManager.Listener()
if mgmtPort == 443 {
// CertManager, HTTP and gRPC API all on the same port
rootHandler = certManager.HTTPHandler(rootHandler)
listener = cml
} else {
listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), certManager.TLSConfig())
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err)
}
log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
serveHTTP(ctx, cml, certManager.HTTPHandler(nil))
}
} else if tlsConfig != nil {
listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), tlsConfig)
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err)
}
} else {
listener, err = net.Listen("tcp", fmt.Sprintf(":%d", mgmtPort))
if err != nil {
return fmt.Errorf("failed creating TCP listener on port %d: %v", mgmtPort, err)
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
update := version.NewUpdate("nb/management")
update.SetDaemonVersion(version.NetbirdVersion())
update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
})
defer update.StopWatch()
SetupCloseHandler()
<-stopCh
integratedPeerValidator.Stop(ctx)
if geo != nil {
_ = geo.Stop()
}
ephemeralManager.Stop()
_ = appMetrics.Close()
_ = listener.Close()
if certManager != nil {
_ = certManager.Listener().Close()
}
gRPCAPIHandler.Stop()
_ = store.Close(ctx)
_ = eventStore.Close(ctx)
log.WithContext(ctx).Infof("stopped Management Service")
return nil return nil
}, },
} }
) )
func unaryInterceptor( func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
ctx context.Context, loadedConfig := &nbconfig.Config{}
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := uuid.New().String()
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource)
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(ctx, req)
}
func streamInterceptor(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := uuid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource)
//nolint
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(srv, wrapped)
}
func notifyStop(ctx context.Context, msg string) {
select {
case stopCh <- 1:
log.WithContext(ctx).Error(msg)
default:
// stop has been already called, nothing to report
}
}
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {
return installationID, nil
}
installationID = strings.ToUpper(uuid.New().String())
err := store.SaveInstallationID(ctx, installationID)
if err != nil {
return "", err
}
return installationID, nil
}
func serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
}
go func() {
err := grpcServer.Serve(listener)
if err != nil {
notifyStop(ctx, fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
}
}()
return listener, nil
}
func serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
go func() {
err := http.Serve(httpListener, handler)
if err != nil {
notifyStop(ctx, fmt.Sprintf("failed running HTTP server: %v", err))
}
}()
}
func serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
go func() {
var err error
if tlsEnabled {
err = http.Serve(listener, handler)
} else {
// the following magic is needed to support HTTP2 without TLS
// and still share a single port between gRPC and HTTP APIs
h1s := &http.Server{
Handler: h2c.NewHandler(handler, &http2.Server{}),
}
err = h1s.Serve(listener)
}
if err != nil {
select {
case stopCh <- 1:
log.WithContext(ctx).Errorf("failed to serve HTTP and gRPC server: %v", err)
default:
// stop has been already called, nothing to report
}
}
}()
}
func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")
if request.ProtoMajor == 2 && grpcHeader {
gRPCHandler.ServeHTTP(writer, request)
} else {
httpHandler.ServeHTTP(writer, request)
}
})
}
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config, error) {
loadedConfig := &types.Config{}
_, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig) _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -535,7 +169,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config,
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation) oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(types.NONE)) { if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE)) {
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint) oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
@@ -552,7 +186,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config,
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" { if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = types.DefaultDeviceAuthFlowScope loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = nbconfig.DefaultDeviceAuthFlowScope
} }
} }
@@ -573,10 +207,6 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config,
return loadedConfig, err return loadedConfig, err
} }
func updateMgmtConfig(ctx context.Context, path string, config *types.Config) error {
return util.DirectWriteJson(ctx, path, config)
}
// OIDCConfigResponse used for parsing OIDC config response // OIDCConfigResponse used for parsing OIDC config response
type OIDCConfigResponse struct { type OIDCConfigResponse struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
@@ -619,25 +249,6 @@ func fetchOIDCConfig(ctx context.Context, oidcEndpoint string) (OIDCConfigRespon
return config, nil return config, nil
} }
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
// Load server's certificate and private key
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
if err != nil {
return nil, err
}
// NewDefaultAppMetrics the credentials and return it
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert,
NextProtos: []string{
"h2", "http/1.1", // enable HTTP/2
},
}
return config, nil
}
func handleRebrand(cmd *cobra.Command) error { func handleRebrand(cmd *cobra.Command) error {
var err error var err error
if logFile == defaultLogFile { if logFile == defaultLogFile {
@@ -649,7 +260,7 @@ func handleRebrand(cmd *cobra.Command) error {
} }
} }
} }
if types.MgmtConfigPath == defaultMgmtConfig { if nbconfig.MgmtConfigPath == defaultMgmtConfig {
if migrateToNetbird(oldDefaultMgmtConfig, defaultMgmtConfig) { if migrateToNetbird(oldDefaultMgmtConfig, defaultMgmtConfig) {
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultMgmtConfigDir, defaultMgmtConfigDir) cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultMgmtConfigDir, defaultMgmtConfigDir)
err = cpDir(oldDefaultMgmtConfigDir, defaultMgmtConfigDir) err = cpDir(oldDefaultMgmtConfigDir, defaultMgmtConfigDir)

View File

@@ -2,12 +2,10 @@ package cmd
import ( import (
"fmt" "fmt"
"os"
"os/signal"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/management/server/types" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -27,6 +25,12 @@ var (
disableGeoliteUpdate bool disableGeoliteUpdate bool
idpSignKeyRefreshEnabled bool idpSignKeyRefreshEnabled bool
userDeleteFromIDPEnabled bool userDeleteFromIDPEnabled bool
mgmtPort int
mgmtMetricsPort int
mgmtLetsencryptDomain string
mgmtSingleAccModeDomain string
certFile string
certKey string
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird-mgmt", Use: "netbird-mgmt",
@@ -42,8 +46,6 @@ var (
Long: "", Long: "",
SilenceUsage: true, SilenceUsage: true,
} }
// Execution control channel for stopCh signal
stopCh chan int
) )
// Execute executes the root command. // Execute executes the root command.
@@ -52,11 +54,10 @@ func Execute() error {
} }
func init() { func init() {
stopCh = make(chan int)
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics") mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location") mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&types.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file") mgmtCmd.Flags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain) mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain)
mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.") mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.")
@@ -80,15 +81,3 @@ func init() {
rootCmd.AddCommand(migrationCmd) rootCmd.AddCommand(migrationCmd)
} }
// SetupCloseHandler handles SIGTERM signal and exits with success
func SetupCloseHandler() {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
go func() {
for range c {
fmt.Println("\r- Ctrl+C pressed in Terminal")
stopCh <- 0
}
}()
}

View File

@@ -0,0 +1,204 @@
package server
// @note this file includes all the lower level dependencies, db, http and grpc BaseServer, metrics, logger, etc.
import (
"context"
"crypto/tls"
"net/http"
"net/netip"
"slices"
"time"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
func (s *BaseServer) Metrics() telemetry.AppMetrics {
return Create(s, func() telemetry.AppMetrics {
appMetrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
log.Fatalf("error while creating app metrics: %s", err)
}
return appMetrics
})
}
func (s *BaseServer) Store() store.Store {
return Create(s, func() store.Store {
store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false)
if err != nil {
log.Fatalf("failed to create store: %v", err)
}
return store
})
}
func (s *BaseServer) EventStore() activity.Store {
return Create(s, func() activity.Store {
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
if err != nil {
log.Fatalf("failed to initialize integration metrics: %v", err)
}
eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
log.Fatalf("failed to initialize event store: %v", err)
}
if s.config.DataStoreEncryptionKey != key {
log.WithContext(context.Background()).Infof("update config with activity store key")
s.config.DataStoreEncryptionKey = key
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config)
if err != nil {
log.Fatalf("failed to update config with activity store: %v", err)
}
}
return eventStore
})
}
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
return httpAPIHandler
})
}
func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server {
trustedPeers := s.config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers
}
trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
}
realipOpts := []realip.Option{
realip.WithTrustedPeers(trustedPeers),
realip.WithTrustedProxies(trustedHTTPProxies),
realip.WithTrustedProxiesCount(trustedProxiesCount),
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
}
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
if s.config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
if err != nil {
log.Fatalf("failed to create certificate manager: %v", err)
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
if err != nil {
log.Fatalf("cannot load TLS credentials: %v", err)
}
transportCredentials := credentials.NewTLS(tlsConfig)
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator())
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
return gRPCAPIHandler
})
}
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
// Load server's certificate and private key
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
if err != nil {
return nil, err
}
// NewDefaultAppMetrics the credentials and return it
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert,
NextProtos: []string{
"h2", "http/1.1", // enable HTTP/2
},
}
return config, nil
}
func unaryInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := uuid.New().String()
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource)
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(ctx, req)
}
func streamInterceptor(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := uuid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource)
//nolint
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(srv, wrapped)
}

View File

@@ -1,10 +1,11 @@
package types package config
import ( import (
"net/netip" "net/netip"
"github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -166,7 +167,7 @@ type ProviderConfig struct {
// StoreConfig contains Store configuration // StoreConfig contains Store configuration
type StoreConfig struct { type StoreConfig struct {
Engine Engine Engine types.Engine
} }
// ReverseProxy contains reverse proxy configuration in front of management. // ReverseProxy contains reverse proxy configuration in front of management.

View File

@@ -0,0 +1,55 @@
package server
import "fmt"
// Create a dependency and add it to the BaseServer's container. A string key identifier will be based on its type definition.
func Create[T any](s Server, createFunc func() T) T {
result, _ := maybeCreate(s, createFunc)
return result
}
// CreateNamed is the same as Create but will suffix the dependency string key identifier with a custom name.
// Useful if you want to have multiple named instances of the same object type.
func CreateNamed[T any](s Server, name string, createFunc func() T) T {
result, _ := maybeCreateNamed(s, name, createFunc)
return result
}
// Inject lets you override a specific service from outside the BaseServer itself.
// This is useful for tests
func Inject[T any](c Server, thing T) {
_, _ = maybeCreate(c, func() T {
return thing
})
}
// InjectNamed is like Inject() but with a custom name.
func InjectNamed[T any](c Server, name string, thing T) {
_, _ = maybeCreateKeyed(c, name, func() T {
return thing
})
}
func maybeCreate[T any](s Server, createFunc func() T) (result T, isNew bool) {
key := fmt.Sprintf("%T", (*T)(nil))[1:]
return maybeCreateKeyed(s, key, createFunc)
}
func maybeCreateNamed[T any](s Server, name string, createFunc func() T) (result T, isNew bool) {
key := fmt.Sprintf("%T:%s", (*T)(nil), name)[1:]
return maybeCreateKeyed(s, key, createFunc)
}
func maybeCreateKeyed[T any](s Server, key string, createFunc func() T) (result T, isNew bool) {
if t, ok := s.GetContainer(key); ok {
return t.(T), false
}
t := createFunc()
s.SetContainer(key, t)
return t, true
}

View File

@@ -0,0 +1,59 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
)
func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
return Create(s, func() *server.PeersUpdateManager {
return server.NewPeersUpdateManager(s.Metrics())
})
}
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator {
integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore())
if err != nil {
log.Errorf("failed to create integrated peer validator: %v", err)
}
return integratedPeerValidator
})
}
func (s *BaseServer) ProxyController() port_forwarding.Controller {
return Create(s, func() port_forwarding.Controller {
return integrations.NewController(s.Store())
})
}
func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager {
return Create(s, func() *server.TimeBasedAuthSecretsManager {
return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
})
}
func (s *BaseServer) AuthManager() auth.Manager {
return Create(s, func() auth.Manager {
return auth.NewManager(s.Store(),
s.config.HttpConfig.AuthIssuer,
s.config.HttpConfig.AuthAudience,
s.config.HttpConfig.AuthKeysLocation,
s.config.HttpConfig.AuthUserIDClaim,
s.config.GetAuthAudiences(),
s.config.HttpConfig.IdpSignKeyRefreshEnabled)
})
}
func (s *BaseServer) EphemeralManager() *server.EphemeralManager {
return Create(s, func() *server.EphemeralManager {
return server.NewEphemeralManager(s.Store(), s.AccountManager())
})
}

View File

@@ -0,0 +1,108 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users"
)
func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
return Create(s, func() geolocation.Geolocation {
geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate)
if err != nil {
log.Fatalf("could not initialize geolocation service: %v", err)
}
log.Infof("geolocation service has been initialized from %s", s.config.Datadir)
return geo
})
}
func (s *BaseServer) PermissionsManager() permissions.Manager {
return Create(s, func() permissions.Manager {
return integrations.InitPermissionsManager(s.Store())
})
}
func (s *BaseServer) UsersManager() users.Manager {
return Create(s, func() users.Manager {
return users.NewManager(s.Store())
})
}
func (s *BaseServer) SettingsManager() settings.Manager {
return Create(s, func() settings.Manager {
extraSettingsManager := integrations.NewManager(s.EventStore())
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager())
})
}
func (s *BaseServer) PeersManager() peers.Manager {
return Create(s, func() peers.Manager {
return peers.NewManager(s.Store(), s.PermissionsManager())
})
}
func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain,
s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}
return accountManager
})
}
func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
if s.config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics())
if err != nil {
log.Fatalf("failed to create IDP manager: %v", err)
}
}
return idpManager
})
}
func (s *BaseServer) GroupsManager() groups.Manager {
return Create(s, func() groups.Manager {
return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
})
}
func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager())
})
}
func (s *BaseServer) RoutesManager() routers.Manager {
return Create(s, func() routers.Manager {
return routers.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
})
}
func (s *BaseServer) NetworksManager() networks.Manager {
return Create(s, func() networks.Manager {
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
})
}

View File

@@ -0,0 +1,341 @@
package server
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/encryption"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now.
const ManagementLegacyPort = 33073
type Server interface {
Start(ctx context.Context) error
Stop() error
Errors() <-chan error
GetContainer(key string) (any, bool)
SetContainer(key string, container any)
}
// Server holds the HTTP BaseServer instance.
// Add any additional fields you need, such as database connections, config, etc.
type BaseServer struct {
// config holds the server configuration
config *nbconfig.Config
// container of dependencies, each dependency is identified by a unique string.
container map[string]any
// AfterInit is a function that will be called after the server is initialized
afterInit []func(s *BaseServer)
disableMetrics bool
dnsDomain string
disableGeoliteUpdate bool
userDeleteFromIDPEnabled bool
mgmtSingleAccModeDomain string
mgmtMetricsPort int
mgmtPort int
listener net.Listener
certManager *autocert.Manager
update *version.Update
errCh chan error
wg sync.WaitGroup
cancel context.CancelFunc
}
// NewServer initializes and configures a new Server instance
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
return &BaseServer{
config: config,
container: make(map[string]any),
dnsDomain: dnsDomain,
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
disableMetrics: disableMetrics,
disableGeoliteUpdate: disableGeoliteUpdate,
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
mgmtPort: mgmtPort,
mgmtMetricsPort: mgmtMetricsPort,
}
}
func (s *BaseServer) AfterInit(fn func(s *BaseServer)) {
s.afterInit = append(s.afterInit, fn)
}
// Start begins listening for HTTP requests on the configured address
func (s *BaseServer) Start(ctx context.Context) error {
srvCtx, cancel := context.WithCancel(ctx)
s.cancel = cancel
s.errCh = make(chan error, 4)
s.PeersManager()
s.GeoLocationManager()
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics")
if err != nil {
return fmt.Errorf("failed to expose metrics: %v", err)
}
s.EphemeralManager().LoadInitialPeers(srvCtx)
var tlsConfig *tls.Config
tlsEnabled := false
if s.config.HttpConfig.LetsEncryptDomain != "" {
s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
if err != nil {
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
}
tlsEnabled = true
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
if err != nil {
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
return err
}
tlsEnabled = true
}
installationID, err := getInstallationID(srvCtx, s.Store())
if err != nil {
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
return err
}
if !s.disableMetrics {
idpManager := "disabled"
if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" {
idpManager = s.config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager)
go metricsWorker.Run(srvCtx)
}
var compatListener net.Listener
if s.mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
compatListener, err = s.serveGRPC(srvCtx, s.GRPCServer(), ManagementLegacyPort)
if err != nil {
return err
}
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := handlerFunc(s.GRPCServer(), s.APIHandler())
switch {
case s.certManager != nil:
// a call to certManager.Listener() always creates a new listener so we do it once
cml := s.certManager.Listener()
if s.mgmtPort == 443 {
// CertManager, HTTP and gRPC API all on the same port
rootHandler = s.certManager.HTTPHandler(rootHandler)
s.listener = cml
} else {
s.listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", s.mgmtPort), s.certManager.TLSConfig())
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", s.mgmtPort, err)
}
log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
s.serveHTTP(ctx, cml, s.certManager.HTTPHandler(nil))
}
case tlsConfig != nil:
s.listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", s.mgmtPort), tlsConfig)
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", s.mgmtPort, err)
}
default:
s.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", s.mgmtPort))
if err != nil {
return fmt.Errorf("failed creating TCP listener on port %d: %v", s.mgmtPort, err)
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
s.update = version.NewUpdate("nb/management")
s.update.SetDaemonVersion(version.NetbirdVersion())
s.update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
})
return nil
}
// Stop attempts a graceful shutdown, waiting up to 5 seconds for active connections to finish
func (s *BaseServer) Stop() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.IntegratedValidator().Stop(ctx)
if s.GeoLocationManager() != nil {
_ = s.GeoLocationManager().Stop()
}
s.EphemeralManager().Stop()
_ = s.Metrics().Close()
if s.listener != nil {
_ = s.listener.Close()
}
if s.certManager != nil {
_ = s.certManager.Listener().Close()
}
s.GRPCServer().Stop()
_ = s.Store().Close(ctx)
_ = s.EventStore().Close(ctx)
if s.update != nil {
s.update.StopWatch()
}
select {
case <-s.Errors():
log.WithContext(ctx).Infof("stopped Management Service")
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Done returns a channel that is closed when the server stops
func (s *BaseServer) Errors() <-chan error {
return s.errCh
}
// GetContainer retrieves a dependency from the BaseServer's container by its key
func (s *BaseServer) GetContainer(key string) (any, bool) {
container, exists := s.container[key]
return container, exists
}
// SetContainer stores a dependency in the BaseServer's container with the specified key
func (s *BaseServer) SetContainer(key string, container any) {
if _, exists := s.container[key]; exists {
log.Tracef("container with key %s already exists", key)
return
}
s.container[key] = container
log.Tracef("container with key %s set successfully", key)
}
func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) error {
return util.DirectWriteJson(ctx, path, config)
}
func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")
if request.ProtoMajor == 2 && grpcHeader {
gRPCHandler.ServeHTTP(writer, request)
} else {
httpHandler.ServeHTTP(writer, request)
}
})
}
func (s *BaseServer) serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
err := grpcServer.Serve(listener)
if ctx.Err() != nil {
return
}
select {
case s.errCh <- err:
default:
}
}()
return listener, nil
}
func (s *BaseServer) serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
err := http.Serve(httpListener, handler)
if ctx.Err() != nil {
return
}
select {
case s.errCh <- err:
default:
}
}()
}
func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
var err error
if tlsEnabled {
err = http.Serve(listener, handler)
} else {
// the following magic is needed to support HTTP2 without TLS
// and still share a single port between gRPC and HTTP APIs
h1s := &http.Server{
Handler: h2c.NewHandler(handler, &http2.Server{}),
}
err = h1s.Serve(listener)
}
if ctx.Err() != nil {
return
}
select {
case s.errCh <- err:
default:
}
}()
}
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {
return installationID, nil
}
installationID = strings.ToUpper(uuid.New().String())
err := store.SaveInstallationID(ctx, installationID)
if err != nil {
return "", err
}
return installationID, nil
}

View File

@@ -1952,20 +1952,19 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain") return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain")
} }
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error {
var account *types.Account
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error var err error
account, err = transaction.GetAccount(ctx, accountId) ok, domain, err := transaction.IsPrimaryAccount(ctx, accountId)
if err != nil { if err != nil {
return err return err
} }
if account.IsDomainPrimaryAccount { if ok {
return nil return nil
} }
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain) existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
// error is not a not found error // error is not a not found error
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
@@ -1981,9 +1980,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return status.Errorf(status.Internal, "cannot update account to primary") return status.Errorf(status.Internal, "cannot update account to primary")
} }
account.IsDomainPrimaryAccount = true if err := transaction.MarkAccountPrimary(ctx, accountId); err != nil {
if err := transaction.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).WithFields(log.Fields{ log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId, "accountId": accountId,
}).Errorf("failed to update account to primary: %v", err) }).Errorf("failed to update account to primary: %v", err)
@@ -1993,10 +1990,10 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return nil return nil
}) })
if err != nil { if err != nil {
return nil, err return err
} }
return account, nil return nil
} }
// propagateUserGroupMemberships propagates all account users' group memberships to their peers. // propagateUserGroupMemberships propagates all account users' group memberships to their peers.
@@ -2067,14 +2064,12 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()), Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()),
} }
account, err := transaction.GetAccount(ctx, accountID) err := transaction.UpdateAccountNetwork(ctx, accountID, newIPNet)
if err != nil { if err != nil {
return err return err
} }
account.Network.Net = newIPNet peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthUpdate, accountID, "", "")
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return err return err
} }
@@ -2094,10 +2089,6 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
takenIPs = append(takenIPs, newIP) takenIPs = append(takenIPs, newIP)
} }
if err = transaction.SaveAccount(ctx, account); err != nil {
return err
}
for _, peer := range peers { for _, peer := range peers {
if err = transaction.SavePeer(ctx, accountID, peer); err != nil { if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err) return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)

View File

@@ -7,7 +7,6 @@ import (
"time" "time"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
@@ -18,6 +17,7 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
type ExternalCacheManager nbcache.UserDataCache type ExternalCacheManager nbcache.UserDataCache
@@ -120,7 +120,10 @@ type Manager interface {
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
GetStore() store.Store GetStore() store.Store
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) UpdateToPrimaryAccount(ctx context.Context, accountId string) error
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
} }

View File

@@ -3250,11 +3250,13 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
assert.Equal(t, 0, len(account2.Users)) assert.Equal(t, 0, len(account2.Users))
assert.Equal(t, 0, len(account2.SetupKeys)) assert.Equal(t, 0, len(account2.SetupKeys))
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id) err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
account, err = manager.Store.GetAccount(ctx, account.Id)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount) assert.True(t, account.IsDomainPrimaryAccount)
_, err = manager.UpdateToPrimaryAccount(ctx, account2.Id) err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
assert.Error(t, err, "should not be able to update a second account to primary") assert.Error(t, err, "should not be able to update a second account to primary")
} }
@@ -3275,7 +3277,9 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
assert.False(t, account.IsDomainPrimaryAccount) assert.False(t, account.IsDomainPrimaryAccount)
assert.Equal(t, domain, account.Domain) assert.Equal(t, domain, account.Domain)
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id) err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
account, err = manager.Store.GetAccount(ctx, account.Id)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount) assert.True(t, account.IsDomainPrimaryAccount)

View File

@@ -178,6 +178,8 @@ const (
AccountNetworkRangeUpdated Activity = 87 AccountNetworkRangeUpdated Activity = 87
PeerIPUpdated Activity = 88 PeerIPUpdated Activity = 88
JobCreatedByUser Activity = 89
AccountDeleted Activity = 99999 AccountDeleted Activity = 99999
) )
@@ -284,6 +286,8 @@ var activityMap = map[Activity]Code{
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"}, AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"}, PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
JobCreatedByUser: {"Create Job for peer", "peer.job.create"},
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -44,7 +45,7 @@ type GRPCServer struct {
wgKey wgtypes.Key wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
config *types.Config config *nbconfig.Config
secretsManager SecretsManager secretsManager SecretsManager
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
@@ -56,7 +57,7 @@ type GRPCServer struct {
// NewServer creates a new Management server // NewServer creates a new Management server
func NewServer( func NewServer(
ctx context.Context, ctx context.Context,
config *types.Config, config *nbconfig.Config,
accountManager account.Manager, accountManager account.Manager,
settingsManager settings.Manager, settingsManager settings.Manager,
peersUpdateManager *PeersUpdateManager, peersUpdateManager *PeersUpdateManager,
@@ -567,24 +568,24 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
return userID, nil return userID, nil
} }
func ToResponseProto(configProto types.Protocol) proto.HostConfig_Protocol { func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
switch configProto { switch configProto {
case types.UDP: case nbconfig.UDP:
return proto.HostConfig_UDP return proto.HostConfig_UDP
case types.DTLS: case nbconfig.DTLS:
return proto.HostConfig_DTLS return proto.HostConfig_DTLS
case types.HTTP: case nbconfig.HTTP:
return proto.HostConfig_HTTP return proto.HostConfig_HTTP
case types.HTTPS: case nbconfig.HTTPS:
return proto.HostConfig_HTTPS return proto.HostConfig_HTTPS
case types.TCP: case nbconfig.TCP:
return proto.HostConfig_TCP return proto.HostConfig_TCP
default: default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto)) panic(fmt.Errorf("unexpected config protocol type %v", configProto))
} }
} }
func toNetbirdConfig(config *types.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil { if config == nil {
return nil return nil
} }
@@ -662,7 +663,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
} }
} }
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse { func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
response := &proto.SyncResponse{ response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
@@ -799,7 +800,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
return nil, status.Error(codes.InvalidArgument, errMSG) return nil, status.Error(codes.InvalidArgument, errMSG)
} }
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(types.NONE) { if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(nbconfig.NONE) {
return nil, status.Error(codes.NotFound, "no device authorization flow information available") return nil, status.Error(codes.NotFound, "no device authorization flow information available")
} }

View File

@@ -14,11 +14,11 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
) )
// Handler is a handler that returns peers of the account // Handler is a handler that returns peers of the account
@@ -32,6 +32,10 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS") Methods("GET", "PUT", "DELETE", "OPTIONS")
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS")
} }
// NewHandler creates a new peers Handler // NewHandler creates a new peers Handler
@@ -41,6 +45,99 @@ func NewHandler(accountManager account.Manager) *Handler {
} }
} }
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
req := &api.JobRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req)
if err != nil {
util.WriteError(ctx, err, w)
return
}
if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
util.WriteError(ctx, err, w)
return
}
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(ctx, w, resp)
}
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID)
if err != nil {
util.WriteError(ctx, err, w)
return
}
respBody := make([]*api.JobResponse, 0, len(jobs))
for _, job := range jobs {
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
return
}
respBody = append(respBody, resp)
}
util.WriteJSONObject(ctx, w, respBody)
}
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
jobID := vars["jobId"]
job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID)
if err != nil {
util.WriteError(ctx, err, w)
return
}
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(ctx, w, resp)
}
func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
peerToReturn := peer.Copy() peerToReturn := peer.Copy()
if peer.Status.Connected { if peer.Status.Connected {
@@ -354,6 +451,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
} }
return &api.Peer{ return &api.Peer{
CreatedAt: peer.CreatedAt,
Id: peer.ID, Id: peer.ID,
Name: peer.Name, Name: peer.Name,
Ip: peer.IP.String(), Ip: peer.IP.String(),
@@ -390,6 +488,7 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
} }
return &api.PeerBatch{ return &api.PeerBatch{
CreatedAt: peer.CreatedAt,
Id: peer.ID, Id: peer.ID,
Name: peer.Name, Name: peer.Name,
Ip: peer.IP.String(), Ip: peer.IP.String(),
@@ -419,6 +518,28 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
} }
} }
func toSingleJobResponse(job *types.Job) (*api.JobResponse, error) {
workload, err := job.BuildWorkloadResponse()
if err != nil {
return nil, err
}
var failed *string
if job.FailedReason != "" {
failed = &job.FailedReason
}
return &api.JobResponse{
Id: job.ID,
CreatedAt: job.CreatedAt,
CompletedAt: job.CompletedAt,
TriggeredBy: job.TriggeredBy,
Status: api.JobResponseStatus(job.Status),
FailedReason: failed,
Workload: *workload,
}, nil
}
func fqdn(peer *nbpeer.Peer, dnsDomain string) string { func fqdn(peer *nbpeer.Peer, dnsDomain string) string {
fqdn := peer.FQDN(dnsDomain) fqdn := peer.FQDN(dnsDomain)
if fqdn == "" { if fqdn == "" {

View File

@@ -50,23 +50,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context,
defer unlock() defer unlock()
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
a, err := transaction.GetAccount(ctx, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
return err return err
} }
var extra *types.ExtraSettings var extra *types.ExtraSettings
if a.Settings.Extra != nil { if settings.Extra != nil {
extra = a.Settings.Extra extra = settings.Extra
} else { } else {
extra = &types.ExtraSettings{} extra = &types.ExtraSettings{}
a.Settings.Extra = extra settings.Extra = extra
} }
extra.IntegratedValidator = validator extra.IntegratedValidator = validator
extra.IntegratedValidatorGroups = groups extra.IntegratedValidatorGroups = groups
return transaction.SaveAccount(ctx, a) return transaction.SaveAccountSettings(ctx, accountID, settings)
}) })
} }

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -95,21 +96,21 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
func Test_SyncProtocol(t *testing.T) { func Test_SyncProtocol(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{ mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &config.Config{
Stuns: []*types.Host{{ Stuns: []*config.Host{{
Proto: "udp", Proto: "udp",
URI: "stun:stun.netbird.io:3468", URI: "stun:stun.netbird.io:3468",
}}, }},
TURNConfig: &types.TURNConfig{ TURNConfig: &config.TURNConfig{
TimeBasedCredentials: false, TimeBasedCredentials: false,
CredentialsTTL: util.Duration{}, CredentialsTTL: util.Duration{},
Secret: "whatever", Secret: "whatever",
Turns: []*types.Host{{ Turns: []*config.Host{{
Proto: "udp", Proto: "udp",
URI: "turn:stun.netbird.io:3468", URI: "turn:stun.netbird.io:3468",
}}, }},
}, },
Signal: &types.Host{ Signal: &config.Host{
Proto: "http", Proto: "http",
URI: "signal.netbird.io:10000", URI: "signal.netbird.io:10000",
}, },
@@ -332,7 +333,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
inputFlow *types.DeviceAuthorizationFlow inputFlow *config.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string expectedErrMSG string
@@ -347,9 +348,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}, },
{ {
name: "Testing Invalid Device Flow Provider Config", name: "Testing Invalid Device Flow Provider Config",
inputFlow: &types.DeviceAuthorizationFlow{ inputFlow: &config.DeviceAuthorizationFlow{
Provider: "NoNe", Provider: "NoNe",
ProviderConfig: types.ProviderConfig{ ProviderConfig: config.ProviderConfig{
ClientID: "test", ClientID: "test",
}, },
}, },
@@ -358,9 +359,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}, },
{ {
name: "Testing Full Device Flow Config", name: "Testing Full Device Flow Config",
inputFlow: &types.DeviceAuthorizationFlow{ inputFlow: &config.DeviceAuthorizationFlow{
Provider: "hosted", Provider: "hosted",
ProviderConfig: types.ProviderConfig{ ProviderConfig: config.ProviderConfig{
ClientID: "test", ClientID: "test",
}, },
}, },
@@ -381,7 +382,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &GRPCServer{ mgmtServer := &GRPCServer{
wgKey: testingServerKey, wgKey: testingServerKey,
config: &types.Config{ config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow, DeviceAuthorizationFlow: testCase.inputFlow,
}, },
} }
@@ -412,7 +413,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
} }
} }
func startManagementForTest(t *testing.T, testFile string, config *types.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", "localhost:0") lis, err := net.Listen("tcp", "localhost:0")
if err != nil { if err != nil {
@@ -515,21 +516,21 @@ func testSyncStatusRace(t *testing.T) {
t.Skip() t.Skip()
dir := t.TempDir() dir := t.TempDir()
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{ mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &config.Config{
Stuns: []*types.Host{{ Stuns: []*config.Host{{
Proto: "udp", Proto: "udp",
URI: "stun:stun.netbird.io:3468", URI: "stun:stun.netbird.io:3468",
}}, }},
TURNConfig: &types.TURNConfig{ TURNConfig: &config.TURNConfig{
TimeBasedCredentials: false, TimeBasedCredentials: false,
CredentialsTTL: util.Duration{}, CredentialsTTL: util.Duration{},
Secret: "whatever", Secret: "whatever",
Turns: []*types.Host{{ Turns: []*config.Host{{
Proto: "udp", Proto: "udp",
URI: "turn:stun.netbird.io:3468", URI: "turn:stun.netbird.io:3468",
}}, }},
}, },
Signal: &types.Host{ Signal: &config.Host{
Proto: "http", Proto: "http",
URI: "signal.netbird.io:10000", URI: "signal.netbird.io:10000",
}, },
@@ -687,21 +688,21 @@ func Test_LoginPerformance(t *testing.T) {
t.Helper() t.Helper()
dir := t.TempDir() dir := t.TempDir()
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{ mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &config.Config{
Stuns: []*types.Host{{ Stuns: []*config.Host{{
Proto: "udp", Proto: "udp",
URI: "stun:stun.netbird.io:3468", URI: "stun:stun.netbird.io:3468",
}}, }},
TURNConfig: &types.TURNConfig{ TURNConfig: &config.TURNConfig{
TimeBasedCredentials: false, TimeBasedCredentials: false,
CredentialsTTL: util.Duration{}, CredentialsTTL: util.Duration{},
Secret: "whatever", Secret: "whatever",
Turns: []*types.Host{{ Turns: []*config.Host{{
Proto: "udp", Proto: "udp",
URI: "turn:stun.netbird.io:3468", URI: "turn:stun.netbird.io:3468",
}}, }},
}, },
Signal: &types.Host{ Signal: &config.Host{
Proto: "http", Proto: "http",
URI: "signal.netbird.io:10000", URI: "signal.netbird.io:10000",
}, },

View File

@@ -20,7 +20,7 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
@@ -30,6 +30,7 @@ import (
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -60,7 +61,7 @@ func setupTest(t *testing.T) *testSuite {
t.Fatalf("failed to create temp directory: %v", err) t.Fatalf("failed to create temp directory: %v", err)
} }
config := &types.Config{} config := &config.Config{}
_, err = util.ReadJson("testdata/management.json", config) _, err = util.ReadJson("testdata/management.json", config)
if err != nil { if err != nil {
t.Fatalf("failed to read management.json: %v", err) t.Fatalf("failed to read management.json: %v", err)
@@ -158,7 +159,7 @@ func createRawClient(t *testing.T, addr string) (mgmtProto.ManagementServiceClie
func startServer( func startServer(
t *testing.T, t *testing.T,
config *types.Config, config *config.Config,
dataDir string, dataDir string,
testFile string, testFile string,
) (*grpc.Server, net.Listener) { ) (*grpc.Server, net.Listener) {

View File

@@ -10,7 +10,6 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
@@ -21,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
var _ account.Manager = (*MockAccountManager)(nil) var _ account.Manager = (*MockAccountManager)(nil)
@@ -114,7 +114,7 @@ type MockAccountManager struct {
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
GetStoreFunc func() store.Store GetStoreFunc func() store.Store
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
@@ -123,6 +123,29 @@ type MockAccountManager struct {
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
UpdateAccountPeersFunc func(ctx context.Context, accountID string) UpdateAccountPeersFunc func(ctx context.Context, accountID string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
CreatePeerJobFunc func(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobsFunc func(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByIDFunc func(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
}
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
if am.CreatePeerJobFunc != nil {
return am.CreatePeerJobFunc(ctx, accountID, peerID, userID, job)
}
return status.Errorf(codes.Unimplemented, "method CreateJob is not implemented")
}
func (am *MockAccountManager) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) {
if am.CreatePeerJobFunc != nil {
return am.GetAllPeerJobsFunc(ctx, accountID, userID, peerID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAllJobs is not implemented")
}
func (am *MockAccountManager) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) {
if am.CreatePeerJobFunc != nil {
return am.GetPeerJobByIDFunc(ctx, accountID, userID, peerID, jobID)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateJob is not implemented")
} }
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
@@ -933,11 +956,11 @@ func (am *MockAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Cont
return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented") return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented")
} }
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error {
if am.UpdateToPrimaryAccountFunc != nil { if am.UpdateToPrimaryAccountFunc != nil {
return am.UpdateToPrimaryAccountFunc(ctx, accountId) return am.UpdateToPrimaryAccountFunc(ctx, accountId)
} }
return nil, status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented") return status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented")
} }
func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) { func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) {

View File

@@ -333,6 +333,130 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return peer, nil return peer, nil
} }
func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
// todo: Create permissions for job
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil {
return err
}
if peerAccountID != accountID {
return status.NewPeerNotPartOfAccountError()
}
// check if peer connected
// todo: implement jobManager.IsPeerConnected
// if !am.jobManager.IsPeerConnected(ctx, peerID) {
// return status.NewJobFailedError("peer not connected")
// }
// check if already has pending jobs
// todo: implement jobManager.GetPendingJobsByPeerID
// if pending := am.jobManager.GetPendingJobsByPeerID(ctx, peerID); len(pending) > 0 {
// return status.NewJobAlreadyPendingError(peerID)
// }
// try sending job first
// todo: implement am.jobManager.SendJob
// if err := am.jobManager.SendJob(ctx, peerID, job); err != nil {
// return status.NewJobFailedError(fmt.Sprintf("failed to send job: %v", err))
// }
var peer *nbpeer.Peer
var eventsToStore func()
// persist job in DB only if send succeeded
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return err
}
if err := transaction.CreatePeerJob(ctx, job); err != nil {
return err
}
jobMeta := map[string]any{
"job_id": job.ID,
"for_peer_id": job.PeerID,
"job_type": job.Workload.Type,
"job_status": job.Status,
"job_workload": job.Workload,
}
eventsToStore = func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.JobCreatedByUser, jobMeta)
}
return nil
})
if err != nil {
return err
}
eventsToStore()
return nil
}
func (am *DefaultAccountManager) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) {
// todo: Create permissions for job
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil {
return nil, err
}
if peerAccountID != accountID {
return []*types.Job{}, nil
}
accountJobs, err := am.Store.GetPeerJobs(ctx, accountID, peerID)
if err != nil {
return nil, err
}
return accountJobs, nil
}
func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) {
// todo: Create permissions for job
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil {
return nil, err
}
if peerAccountID != accountID {
return &types.Job{}, nil
}
job, err := am.Store.GetPeerJobByID(ctx, accountID, jobID)
if err != nil {
return nil, err
}
return job, nil
}
// DeletePeer removes peer from the account by its IP // DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)

View File

@@ -25,6 +25,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
@@ -1063,16 +1064,16 @@ func TestToSyncResponse(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
config := &types.Config{ config := &config.Config{
Signal: &types.Host{ Signal: &config.Host{
Proto: "https", Proto: "https",
URI: "signal.uri", URI: "signal.uri",
Username: "", Username: "",
Password: "", Password: "",
}, },
Stuns: []*types.Host{{URI: "stun.uri", Proto: types.UDP}}, Stuns: []*config.Host{{URI: "stun.uri", Proto: config.UDP}},
TURNConfig: &types.TURNConfig{ TURNConfig: &config.TURNConfig{
Turns: []*types.Host{{URI: "turn.uri", Proto: types.UDP, Username: "turn-user", Password: "turn-pass"}}, Turns: []*config.Host{{URI: "turn.uri", Proto: config.UDP, Username: "turn-user", Password: "turn-pass"}},
}, },
} }
peer := &nbpeer.Peer{ peer := &nbpeer.Peer{

View File

@@ -290,6 +290,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return transaction.DeleteRoute(ctx, accountID, string(routeID)) return transaction.DeleteRoute(ctx, accountID, string(routeID))
}) })
if err != nil {
return fmt.Errorf("failed to delete route %s: %w", routeID, err)
}
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())

View File

@@ -38,14 +38,15 @@ import (
) )
const ( const (
storeSqliteFileName = "store.db" storeSqliteFileName = "store.db"
idQueryCondition = "id = ?" idQueryCondition = "id = ?"
keyQueryCondition = "key = ?" keyQueryCondition = "key = ?"
mysqlKeyQueryCondition = "`key` = ?" mysqlKeyQueryCondition = "`key` = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?" accountAndIDQueryCondition = "account_id = ? and id = ?"
accountAndIDsQueryCondition = "account_id = ? AND id IN ?" accountAndPeerIDQueryCondition = "account_id = ? and peer_id = ?"
accountIDCondition = "account_id = ?" accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
peerNotFoundFMT = "peer %s not found" accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
) )
// SqlStore represents an account storage backed by a Sql DB persisted to disk // SqlStore represents an account storage backed by a Sql DB persisted to disk
@@ -106,6 +107,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{},
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err) return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -124,6 +126,79 @@ func GetKeyQueryCondition(s *SqlStore) string {
return keyQueryCondition return keyQueryCondition
} }
// SaveJob persists a job in DB
func (s *SqlStore) CreatePeerJob(ctx context.Context, job *types.Job) error {
result := s.db.Create(job)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create job in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to create job in store")
}
return nil
}
// job was pending for too long and has been cancelled
// todo call it when we first start the jobChannel to make sure no stuck jobs
func (s *SqlStore) MarkPendingJobsAsFailed(ctx context.Context, peerID string) error {
now := time.Now().UTC()
return s.db.
Model(&types.Job{}).
Where("peer_id = ? AND status = ?", types.JobStatusPending, peerID).
Updates(map[string]any{
"status": types.JobStatusFailed,
"failed_reason": "Pending job cleanup: marked as failed automatically due to being stuck too long",
"completed_at": now,
}).Error
}
// GetJobByID fetches job by ID
func (s *SqlStore) GetPeerJobByID(ctx context.Context, accountID, jobID string) (*types.Job, error) {
var job types.Job
err := s.db.
Where(accountAndIDQueryCondition, accountID, jobID).
First(&job).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "job %s not found", jobID)
}
return &job, err
}
// get all jobs
func (s *SqlStore) GetPeerJobs(ctx context.Context, accountID, peerID string) ([]*types.Job, error) {
var jobs []*types.Job
err := s.db.
Where(accountAndPeerIDQueryCondition, accountID, peerID).
Order("created_at DESC").
Find(&jobs).Error
if err != nil {
return nil, err
}
return jobs, nil
}
func (s *SqlStore) CompletePeerJob(accountID, jobID, result, failedReason string) error {
now := time.Now().UTC()
updates := map[string]any{
"completed_at": now,
}
if result != "" && failedReason == "" {
updates["status"] = types.JobStatusSucceeded
updates["result"] = result
updates["failed_reason"] = ""
} else {
updates["status"] = types.JobStatusFailed
updates["failed_reason"] = failedReason
}
return s.db.
Model(&types.Job{}).
Where(accountAndIDQueryCondition, accountID, jobID).
Updates(updates).Error
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring global lock") log.WithContext(ctx).Tracef("acquiring global lock")
@@ -2832,3 +2907,57 @@ func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFu
}() }()
return ctx, cancel return ctx, cancel
} }
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
var info types.PrimaryAccountInfo
result := s.db.Model(&types.Account{}).
Select("is_domain_primary_account, domain").
Where(idQueryCondition, accountID).
Take(&info)
if result.Error != nil {
return false, "", status.Errorf(status.Internal, "failed to get account info: %v", result.Error)
}
return info.IsDomainPrimaryAccount, info.Domain, nil
}
func (s *SqlStore) MarkAccountPrimary(ctx context.Context, accountID string) error {
result := s.db.Model(&types.Account{}).
Where(idQueryCondition, accountID).
Update("is_domain_primary_account", true)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark account as primary: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark account as primary")
}
if result.RowsAffected == 0 {
return status.NewAccountNotFoundError(accountID)
}
return nil
}
type accountNetworkPatch struct {
Network *types.Network `gorm:"embedded;embeddedPrefix:network_"`
}
func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error {
patch := accountNetworkPatch{
Network: &types.Network{Net: ipNet},
}
result := s.db.WithContext(ctx).
Model(&types.Account{}).
Where(idQueryCondition, accountID).
Updates(&patch)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update account network: %v", result.Error)
return status.Errorf(status.Internal, "failed to update account network")
}
if result.RowsAffected == 0 {
return status.NewAccountNotFoundError(accountID)
}
return nil
}

View File

@@ -202,6 +202,14 @@ type Store interface {
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)
IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error)
MarkAccountPrimary(ctx context.Context, accountID string) error
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
CreatePeerJob(ctx context.Context, job *types.Job) error
CompletePeerJob(accountID, jobID, result, failedReason string) error
GetPeerJobByID(ctx context.Context, accountID, jobID string) (*types.Job, error)
GetPeerJobs(ctx context.Context, accountID, peerID string) ([]*types.Job, error)
MarkPendingJobsAsFailed(ctx context.Context, peerID string) error
} }
const ( const (

View File

@@ -12,9 +12,9 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2" authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
@@ -33,8 +33,8 @@ type SecretsManager interface {
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server // TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct { type TimeBasedAuthSecretsManager struct {
mux sync.Mutex mux sync.Mutex
turnCfg *types.TURNConfig turnCfg *nbconfig.TURNConfig
relayCfg *types.Relay relayCfg *nbconfig.Relay
turnHmacToken *auth.TimedHMAC turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager updateManager *PeersUpdateManager
@@ -46,7 +46,7 @@ type TimeBasedAuthSecretsManager struct {
type Token auth.Token type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{ mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager, updateManager: updateManager,
turnCfg: turnCfg, turnCfg: turnCfg,

View File

@@ -13,6 +13,7 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -20,8 +21,8 @@ import (
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
var TurnTestHost = &types.Host{ var TurnTestHost = &config.Host{
Proto: types.UDP, Proto: config.UDP,
URI: "turn:turn.netbird.io:77777", URI: "turn:turn.netbird.io:77777",
Username: "username", Username: "username",
Password: "", Password: "",
@@ -32,7 +33,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
secret := "some_secret" secret := "some_secret"
peersManager := NewPeersUpdateManager(nil) peersManager := NewPeersUpdateManager(nil)
rc := &types.Relay{ rc := &config.Relay{
Addresses: []string{"localhost:0"}, Addresses: []string{"localhost:0"},
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
@@ -43,10 +44,10 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*types.Host{TurnTestHost}, Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager) }, rc, settingsMockManager, groupsManager)
@@ -83,7 +84,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
peer := "some_peer" peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer) updateChannel := peersManager.CreateChannel(context.Background(), peer)
rc := &types.Relay{ rc := &config.Relay{
Addresses: []string{"localhost:0"}, Addresses: []string{"localhost:0"},
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
@@ -95,10 +96,10 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*types.Host{TurnTestHost}, Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager) }, rc, settingsMockManager, groupsManager)
@@ -187,7 +188,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
peersManager := NewPeersUpdateManager(nil) peersManager := NewPeersUpdateManager(nil)
peer := "some_peer" peer := "some_peer"
rc := &types.Relay{ rc := &config.Relay{
Addresses: []string{"localhost:0"}, Addresses: []string{"localhost:0"},
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
@@ -198,10 +199,10 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*types.Host{TurnTestHost}, Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager) }, rc, settingsMockManager, groupsManager)

View File

@@ -16,16 +16,16 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -89,6 +89,12 @@ type Account struct {
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
} }
// this class is used by gorm only
type PrimaryAccountInfo struct {
IsDomainPrimaryAccount bool
Domain string
}
// Subclass used in gorm to only load network and not whole account // Subclass used in gorm to only load network and not whole account
type AccountNetwork struct { type AccountNetwork struct {
Network *Network `gorm:"embedded;embeddedPrefix:network_"` Network *Network `gorm:"embedded;embeddedPrefix:network_"`

View File

@@ -0,0 +1,155 @@
package types
import (
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
type JobStatus string
const (
JobStatusPending JobStatus = "pending"
JobStatusSucceeded JobStatus = "succeeded"
JobStatusFailed JobStatus = "failed"
)
type JobType string
const (
JobTypeBundle JobType = "bundle"
)
type Job struct {
// ID is the primary identifier
ID string `gorm:"primaryKey"`
// CreatedAt when job was created (UTC)
CreatedAt time.Time `gorm:"autoCreateTime"`
// CompletedAt when job finished, null if still running
CompletedAt *time.Time
// TriggeredBy user that triggered this job
TriggeredBy string `gorm:"index"`
PeerID string `gorm:"index"`
AccountID string `gorm:"index"`
// Status of the job: pending, succeeded, failed
Status JobStatus `gorm:"index;type:varchar(50)"`
// FailedReason describes why the job failed (if failed)
FailedReason string
Workload Workload `gorm:"embedded;embeddedPrefix:workload_"`
}
type Workload struct {
Type JobType `gorm:"column:workload_type;index;type:varchar(50)"`
Parameters json.RawMessage `gorm:"type:json"`
Result json.RawMessage `gorm:"type:json"`
}
// NewJob creates a new job with default fields and validation
func NewJob(triggeredBy, accountID, peerID string, req *api.JobRequest) (*Job, error) {
if req == nil {
return nil, status.Errorf(status.BadRequest, "job request cannot be nil")
}
// Determine job type
jobTypeStr, err := req.Workload.Discriminator()
if err != nil {
return nil, status.Errorf(status.BadRequest, "could not determine job type: %v", err)
}
jobType := JobType(jobTypeStr)
if jobType == "" {
return nil, status.Errorf(status.BadRequest, "job type is required")
}
var workload Workload
switch jobType {
case JobTypeBundle:
if err := validateAndBuildBundleParams(req.Workload, &workload); err != nil {
return nil, status.Errorf(status.BadRequest, "%v", err)
}
default:
return nil, status.Errorf(status.BadRequest, "unsupported job type: %s", jobType)
}
return &Job{
ID: uuid.New().String(),
TriggeredBy: triggeredBy,
PeerID: peerID,
AccountID: accountID,
Status: JobStatusPending,
CreatedAt: time.Now().UTC(),
Workload: workload,
}, nil
}
func (j *Job) BuildWorkloadResponse() (*api.WorkloadResponse, error) {
var wl api.WorkloadResponse
switch j.Workload.Type {
case JobTypeBundle:
if err := j.buildBundleResponse(&wl); err != nil {
return nil, status.Errorf(status.InvalidArgument, err.Error())
}
return &wl, nil
default:
return nil, status.Errorf(status.InvalidArgument, "unknown job type: %v", j.Workload.Type)
}
}
func (j *Job) buildBundleResponse(wl *api.WorkloadResponse) error {
var p api.BundleParameters
if err := json.Unmarshal(j.Workload.Parameters, &p); err != nil {
return fmt.Errorf("invalid parameters for bundle job: %w", err)
}
var r api.BundleResult
if err := json.Unmarshal(j.Workload.Result, &r); err != nil {
return fmt.Errorf("invalid result for bundle job: %w", err)
}
if err := wl.FromBundleWorkloadResponse(api.BundleWorkloadResponse{
Type: api.WorkloadTypeBundle,
Parameters: p,
Result: r,
}); err != nil {
return fmt.Errorf("unknown job parameters: %v", err)
}
return nil
}
func validateAndBuildBundleParams(req api.WorkloadRequest, workload *Workload) error {
bundle, err := req.AsBundleWorkloadRequest()
if err != nil {
return fmt.Errorf("invalid parameters for bundle job")
}
// validate bundle_for_time <= 5 minutes
if bundle.Parameters.BundleForTime < 0 || bundle.Parameters.BundleForTime > 5 {
return fmt.Errorf("bundle_for_time must be between 0 and 5, got %d", bundle.Parameters.BundleForTime)
}
// validate log-file-count ≥ 1 and ≤ 1000
if bundle.Parameters.LogFileCount < 1 || bundle.Parameters.LogFileCount > 1000 {
return fmt.Errorf("log-file-count must be between 1 and 1000, got %d", bundle.Parameters.LogFileCount)
}
workload.Parameters, err = json.Marshal(bundle.Parameters)
if err != nil {
return fmt.Errorf("failed to marshal workload parameters: %w", err)
}
workload.Result = []byte("{}")
workload.Type = JobType(api.WorkloadTypeBundle)
return nil
}

View File

@@ -111,7 +111,11 @@ type Route struct {
// EventMeta returns activity event meta related to the route // EventMeta returns activity event meta related to the route
func (r *Route) EventMeta() map[string]any { func (r *Route) EventMeta() map[string]any {
return map[string]any{"name": r.NetID, "network_range": r.Network.String(), "domains": r.Domains.SafeString(), "peer_id": r.Peer, "peer_groups": r.PeerGroups} domains := ""
if r.Domains != nil {
domains = r.Domains.SafeString()
}
return map[string]any{"name": r.NetID, "network_range": r.Network.String(), "domains": domains, "peer_id": r.Peer, "peer_groups": r.PeerGroups}
} }
// Copy copies a route object // Copy copies a route object
@@ -181,7 +185,7 @@ func (r *Route) GetResourceID() ResID {
// If the route is dynamic, it returns the domains as comma-separated punycode-encoded string. // If the route is dynamic, it returns the domains as comma-separated punycode-encoded string.
// If the route is not dynamic, it returns the network (prefix) string. // If the route is not dynamic, it returns the network (prefix) string.
func (r *Route) NetString() string { func (r *Route) NetString() string {
if r.IsDynamic() { if r.IsDynamic() && r.Domains != nil {
return r.Domains.SafeString() return r.Domains.SafeString()
} }
return r.Network.String() return r.Network.String()

View File

@@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -27,9 +28,9 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
mgmt "github.com/netbirdio/netbird/management/server" mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
@@ -52,7 +53,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
level, _ := log.ParseLevel("debug") level, _ := log.ParseLevel("debug")
log.SetLevel(level) log.SetLevel(level)
config := &types.Config{} config := &config.Config{}
_, err := util.ReadJson("../../../management/server/testdata/management.json", config) _, err := util.ReadJson("../../../management/server/testdata/management.json", config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -11,6 +11,6 @@ fi
old_pwd=$(pwd) old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0")) script_path=$(dirname $(realpath "$0"))
cd "$script_path" cd "$script_path"
go install github.com/deepmap/oapi-codegen/cmd/oapi-codegen@4a1477f6a8ba6ca8115cc23bb2fb67f0b9fca18e go install github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen@latest
oapi-codegen --config cfg.yaml openapi.yml oapi-codegen --config cfg.yaml openapi.yml
cd "$old_pwd" cd "$old_pwd"

View File

@@ -34,6 +34,111 @@ tags:
x-cloud-only: true x-cloud-only: true
components: components:
schemas: schemas:
WorkloadType:
type: string
enum:
- bundle
BundleParameters:
type: object
properties:
bundle_for:
type: boolean
example: true
bundle_for_time:
type: integer
minimum: 0
example: 2
log_file_count:
type: integer
minimum: 0
example: 100
anonymize:
type: boolean
example: false
required:
- bundle_for
- bundle_for_time
- log_file_count
- anonymize
BundleResult:
type: object
properties:
upload_key:
type: string
example: "upload_key_123"
nullable: true
BundleWorkloadRequest:
type: object
properties:
type:
$ref: '#/components/schemas/WorkloadType'
parameters:
$ref: '#/components/schemas/BundleParameters'
required:
- type
- parameters
BundleWorkloadResponse:
type: object
properties:
type:
$ref: '#/components/schemas/WorkloadType'
parameters:
$ref: '#/components/schemas/BundleParameters'
result:
$ref: '#/components/schemas/BundleResult'
required:
- type
- parameters
- result
WorkloadRequest:
oneOf:
- $ref: '#/components/schemas/BundleWorkloadRequest'
discriminator:
propertyName: type
mapping:
bundle: '#/components/schemas/BundleWorkloadRequest'
WorkloadResponse:
oneOf:
- $ref: '#/components/schemas/BundleWorkloadResponse'
discriminator:
propertyName: type
mapping:
bundle: '#/components/schemas/BundleWorkloadResponse'
JobRequest:
type: object
properties:
workload:
$ref: '#/components/schemas/WorkloadRequest'
required:
- workload
JobResponse:
type: object
properties:
id:
type: string
created_at:
type: string
format: date-time
completed_at:
type: string
format: date-time
nullable: true
triggered_by:
type: string
status:
type: string
enum: [pending, succeeded, failed]
failed_reason:
type: string
nullable: true
workload:
$ref: '#/components/schemas/WorkloadResponse'
required:
- id
- created_at
- status
- triggered_by
- workload
Account: Account:
type: object type: object
properties: properties:
@@ -369,6 +474,11 @@ components:
- $ref: '#/components/schemas/PeerMinimum' - $ref: '#/components/schemas/PeerMinimum'
- type: object - type: object
properties: properties:
created_at:
description: Peer creation date (UTC)
type: string
format: date-time
example: "2023-05-05T09:00:35.477782Z"
ip: ip:
description: Peer's IP address description: Peer's IP address
type: string type: string
@@ -471,6 +581,7 @@ components:
- connected - connected
- connection_ip - connection_ip
- country_code - country_code
- created_at
- dns_label - dns_label
- geoname_id - geoname_id
- groups - groups
@@ -544,11 +655,17 @@ components:
- $ref: '#/components/schemas/Peer' - $ref: '#/components/schemas/Peer'
- type: object - type: object
properties: properties:
created_at:
description: Peer creation date (UTC)
type: string
format: date-time
example: "2023-05-05T09:00:35.477782Z"
accessible_peers_count: accessible_peers_count:
description: Number of accessible peers description: Number of accessible peers
type: integer type: integer
example: 5 example: 5
required: required:
- created_at
- accessible_peers_count - accessible_peers_count
SetupKeyBase: SetupKeyBase:
type: object type: object
@@ -2158,6 +2275,108 @@ security:
- BearerAuth: [ ] - BearerAuth: [ ]
- TokenAuth: [ ] - TokenAuth: [ ]
paths: paths:
/api/peers/{peerId}/jobs:
get:
summary: List Jobs
description: Retrieve all jobs for a given peer
tags: [ Jobs ]
security:
- BearerAuth: []
- TokenAuth: []
parameters:
- in: path
name: peerId
required: true
schema:
type: string
description: The unique identifier of a peer
responses:
'200':
description: List of jobs
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/JobResponse'
'400':
$ref: '#/components/responses/bad_request'
'401':
$ref: '#/components/responses/requires_authentication'
'403':
$ref: '#/components/responses/forbidden'
'500':
$ref: '#/components/responses/internal_error'
post:
summary: Create Job
description: Create a new job for a given peer
tags: [ Jobs ]
security:
- BearerAuth: []
- TokenAuth: []
parameters:
- in: path
name: peerId
required: true
schema:
type: string
description: The unique identifier of a peer
requestBody:
description: Create job request
content:
application/json:
schema:
$ref: '#/components/schemas/JobRequest'
required: true
responses:
'201':
description: Job created
content:
application/json:
schema:
$ref: '#/components/schemas/JobResponse'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers/{peerId}/jobs/{jobId}:
get:
summary: Get Job
description: Retrieve details of a specific job
tags: [ Jobs ]
security:
- BearerAuth: []
- TokenAuth: []
parameters:
- in: path
name: peerId
required: true
schema:
type: string
- in: path
name: jobId
required: true
schema:
type: string
responses:
'200':
description: A Job object
content:
application/json:
schema:
$ref: '#/components/schemas/JobResponse'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/accounts: /api/accounts:
get: get:
summary: List all Accounts summary: List all Accounts

View File

@@ -1,10 +1,14 @@
// Package api provides primitives to interact with the openapi HTTP API. // Package api provides primitives to interact with the openapi HTTP API.
// //
// Code generated by github.com/deepmap/oapi-codegen version v1.11.1-0.20220912230023-4a1477f6a8ba DO NOT EDIT. // Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.0 DO NOT EDIT.
package api package api
import ( import (
"encoding/json"
"errors"
"time" "time"
"github.com/oapi-codegen/runtime"
) )
const ( const (
@@ -104,6 +108,13 @@ const (
IngressPortAllocationRequestPortRangeProtocolUdp IngressPortAllocationRequestPortRangeProtocol = "udp" IngressPortAllocationRequestPortRangeProtocolUdp IngressPortAllocationRequestPortRangeProtocol = "udp"
) )
// Defines values for JobResponseStatus.
const (
JobResponseStatusFailed JobResponseStatus = "failed"
JobResponseStatusPending JobResponseStatus = "pending"
JobResponseStatusSucceeded JobResponseStatus = "succeeded"
)
// Defines values for NameserverNsType. // Defines values for NameserverNsType.
const ( const (
NameserverNsTypeUdp NameserverNsType = "udp" NameserverNsTypeUdp NameserverNsType = "udp"
@@ -178,6 +189,11 @@ const (
UserStatusInvited UserStatus = "invited" UserStatusInvited UserStatus = "invited"
) )
// Defines values for WorkloadType.
const (
WorkloadTypeBundle WorkloadType = "bundle"
)
// Defines values for GetApiEventsNetworkTrafficParamsType. // Defines values for GetApiEventsNetworkTrafficParamsType.
const ( const (
GetApiEventsNetworkTrafficParamsTypeTYPEDROP GetApiEventsNetworkTrafficParamsType = "TYPE_DROP" GetApiEventsNetworkTrafficParamsTypeTYPEDROP GetApiEventsNetworkTrafficParamsType = "TYPE_DROP"
@@ -337,6 +353,32 @@ type AvailablePorts struct {
Udp int `json:"udp"` Udp int `json:"udp"`
} }
// BundleParameters defines model for BundleParameters.
type BundleParameters struct {
Anonymize bool `json:"anonymize"`
BundleFor bool `json:"bundle_for"`
BundleForTime int `json:"bundle_for_time"`
LogFileCount int `json:"log_file_count"`
}
// BundleResult defines model for BundleResult.
type BundleResult struct {
UploadKey *string `json:"upload_key"`
}
// BundleWorkloadRequest defines model for BundleWorkloadRequest.
type BundleWorkloadRequest struct {
Parameters BundleParameters `json:"parameters"`
Type WorkloadType `json:"type"`
}
// BundleWorkloadResponse defines model for BundleWorkloadResponse.
type BundleWorkloadResponse struct {
Parameters BundleParameters `json:"parameters"`
Result BundleResult `json:"result"`
Type WorkloadType `json:"type"`
}
// Checks List of objects that perform the actual checks // Checks List of objects that perform the actual checks
type Checks struct { type Checks struct {
// GeoLocationCheck Posture check for geo location // GeoLocationCheck Posture check for geo location
@@ -643,6 +685,25 @@ type IngressPortAllocationRequestPortRange struct {
// IngressPortAllocationRequestPortRangeProtocol The protocol accepted by the port range // IngressPortAllocationRequestPortRangeProtocol The protocol accepted by the port range
type IngressPortAllocationRequestPortRangeProtocol string type IngressPortAllocationRequestPortRangeProtocol string
// JobRequest defines model for JobRequest.
type JobRequest struct {
Workload WorkloadRequest `json:"workload"`
}
// JobResponse defines model for JobResponse.
type JobResponse struct {
CompletedAt *time.Time `json:"completed_at"`
CreatedAt time.Time `json:"created_at"`
FailedReason *string `json:"failed_reason"`
Id string `json:"id"`
Status JobResponseStatus `json:"status"`
TriggeredBy string `json:"triggered_by"`
Workload WorkloadResponse `json:"workload"`
}
// JobResponseStatus defines model for JobResponse.Status.
type JobResponseStatus string
// Location Describe geographical location information // Location Describe geographical location information
type Location struct { type Location struct {
// CityName Commonly used English name of the city // CityName Commonly used English name of the city
@@ -1030,6 +1091,9 @@ type Peer struct {
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"` CountryCode CountryCode `json:"country_code"`
// CreatedAt Peer creation date (UTC)
CreatedAt time.Time `json:"created_at"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"` DnsLabel string `json:"dns_label"`
@@ -1114,6 +1178,9 @@ type PeerBatch struct {
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"` CountryCode CountryCode `json:"country_code"`
// CreatedAt Peer creation date (UTC)
CreatedAt time.Time `json:"created_at"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"` DnsLabel string `json:"dns_label"`
@@ -1814,6 +1881,19 @@ type UserRequest struct {
Role string `json:"role"` Role string `json:"role"`
} }
// WorkloadRequest defines model for WorkloadRequest.
type WorkloadRequest struct {
union json.RawMessage
}
// WorkloadResponse defines model for WorkloadResponse.
type WorkloadResponse struct {
union json.RawMessage
}
// WorkloadType defines model for WorkloadType.
type WorkloadType string
// GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic. // GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic.
type GetApiEventsNetworkTrafficParams struct { type GetApiEventsNetworkTrafficParams struct {
// Page Page number // Page Page number
@@ -1931,6 +2011,9 @@ type PostApiPeersPeerIdIngressPortsJSONRequestBody = IngressPortAllocationReques
// PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody defines body for PutApiPeersPeerIdIngressPortsAllocationId for application/json ContentType. // PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody defines body for PutApiPeersPeerIdIngressPortsAllocationId for application/json ContentType.
type PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody = IngressPortAllocationRequest type PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody = IngressPortAllocationRequest
// PostApiPeersPeerIdJobsJSONRequestBody defines body for PostApiPeersPeerIdJobs for application/json ContentType.
type PostApiPeersPeerIdJobsJSONRequestBody = JobRequest
// PostApiPoliciesJSONRequestBody defines body for PostApiPolicies for application/json ContentType. // PostApiPoliciesJSONRequestBody defines body for PostApiPolicies for application/json ContentType.
type PostApiPoliciesJSONRequestBody = PolicyUpdate type PostApiPoliciesJSONRequestBody = PolicyUpdate
@@ -1963,3 +2046,121 @@ type PutApiUsersUserIdJSONRequestBody = UserRequest
// PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType. // PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType.
type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest
// AsBundleWorkloadRequest returns the union data inside the WorkloadRequest as a BundleWorkloadRequest
func (t WorkloadRequest) AsBundleWorkloadRequest() (BundleWorkloadRequest, error) {
var body BundleWorkloadRequest
err := json.Unmarshal(t.union, &body)
return body, err
}
// FromBundleWorkloadRequest overwrites any union data inside the WorkloadRequest as the provided BundleWorkloadRequest
func (t *WorkloadRequest) FromBundleWorkloadRequest(v BundleWorkloadRequest) error {
v.Type = "bundle"
b, err := json.Marshal(v)
t.union = b
return err
}
// MergeBundleWorkloadRequest performs a merge with any union data inside the WorkloadRequest, using the provided BundleWorkloadRequest
func (t *WorkloadRequest) MergeBundleWorkloadRequest(v BundleWorkloadRequest) error {
v.Type = "bundle"
b, err := json.Marshal(v)
if err != nil {
return err
}
merged, err := runtime.JSONMerge(t.union, b)
t.union = merged
return err
}
func (t WorkloadRequest) Discriminator() (string, error) {
var discriminator struct {
Discriminator string `json:"type"`
}
err := json.Unmarshal(t.union, &discriminator)
return discriminator.Discriminator, err
}
func (t WorkloadRequest) ValueByDiscriminator() (interface{}, error) {
discriminator, err := t.Discriminator()
if err != nil {
return nil, err
}
switch discriminator {
case "bundle":
return t.AsBundleWorkloadRequest()
default:
return nil, errors.New("unknown discriminator value: " + discriminator)
}
}
func (t WorkloadRequest) MarshalJSON() ([]byte, error) {
b, err := t.union.MarshalJSON()
return b, err
}
func (t *WorkloadRequest) UnmarshalJSON(b []byte) error {
err := t.union.UnmarshalJSON(b)
return err
}
// AsBundleWorkloadResponse returns the union data inside the WorkloadResponse as a BundleWorkloadResponse
func (t WorkloadResponse) AsBundleWorkloadResponse() (BundleWorkloadResponse, error) {
var body BundleWorkloadResponse
err := json.Unmarshal(t.union, &body)
return body, err
}
// FromBundleWorkloadResponse overwrites any union data inside the WorkloadResponse as the provided BundleWorkloadResponse
func (t *WorkloadResponse) FromBundleWorkloadResponse(v BundleWorkloadResponse) error {
v.Type = "bundle"
b, err := json.Marshal(v)
t.union = b
return err
}
// MergeBundleWorkloadResponse performs a merge with any union data inside the WorkloadResponse, using the provided BundleWorkloadResponse
func (t *WorkloadResponse) MergeBundleWorkloadResponse(v BundleWorkloadResponse) error {
v.Type = "bundle"
b, err := json.Marshal(v)
if err != nil {
return err
}
merged, err := runtime.JSONMerge(t.union, b)
t.union = merged
return err
}
func (t WorkloadResponse) Discriminator() (string, error) {
var discriminator struct {
Discriminator string `json:"type"`
}
err := json.Unmarshal(t.union, &discriminator)
return discriminator.Discriminator, err
}
func (t WorkloadResponse) ValueByDiscriminator() (interface{}, error) {
discriminator, err := t.Discriminator()
if err != nil {
return nil, err
}
switch discriminator {
case "bundle":
return t.AsBundleWorkloadResponse()
default:
return nil, errors.New("unknown discriminator value: " + discriminator)
}
}
func (t WorkloadResponse) MarshalJSON() ([]byte, error) {
b, err := t.union.MarshalJSON()
return b, err
}
func (t *WorkloadResponse) UnmarshalJSON(b []byte) error {
err := t.union.UnmarshalJSON(b)
return err
}

View File

@@ -52,7 +52,7 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
} }
// MarshalCredential marshal a Credential instance and returns a Message object // MarshalCredential marshal a Credential instance and returns a Message object
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string) (*proto.Message, error) { func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string, sessionID []byte) (*proto.Message, error) {
return &proto.Message{ return &proto.Message{
Key: myKey.PublicKey().String(), Key: myKey.PublicKey().String(),
RemoteKey: remoteKey, RemoteKey: remoteKey,
@@ -66,6 +66,7 @@ func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credenti
RosenpassServerAddr: rosenpassAddr, RosenpassServerAddr: rosenpassAddr,
}, },
RelayServerAddress: relaySrvAddress, RelayServerAddress: relaySrvAddress,
SessionId: sessionID,
}, },
}, nil }, nil
} }

View File

@@ -45,19 +45,10 @@ type GrpcClient struct {
connStateCallbackLock sync.RWMutex connStateCallbackLock sync.RWMutex
onReconnectedListenerFn func() onReconnectedListenerFn func()
}
func (c *GrpcClient) StreamConnected() bool { decryptionWorker *Worker
return c.status == StreamConnected decryptionWorkerCancel context.CancelFunc
} decryptionWg sync.WaitGroup
func (c *GrpcClient) GetStatus() Status {
return c.status
}
// Close Closes underlying connections to the Signal Exchange
func (c *GrpcClient) Close() error {
return c.signalConn.Close()
} }
// NewClient creates a new Signal client // NewClient creates a new Signal client
@@ -93,6 +84,25 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
}, nil }, nil
} }
func (c *GrpcClient) StreamConnected() bool {
return c.status == StreamConnected
}
func (c *GrpcClient) GetStatus() Status {
return c.status
}
// Close Closes underlying connections to the Signal Exchange
func (c *GrpcClient) Close() error {
if c.decryptionWorkerCancel != nil {
c.decryptionWorkerCancel()
}
c.decryptionWg.Wait()
c.decryptionWorker = nil
return c.signalConn.Close()
}
// SetConnStateListener set the ConnStateNotifier // SetConnStateListener set the ConnStateNotifier
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) { func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallbackLock.Lock() c.connStateCallbackLock.Lock()
@@ -148,8 +158,12 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes
log.Infof("connected to the Signal Service stream") log.Infof("connected to the Signal Service stream")
c.notifyConnected() c.notifyConnected()
// Start worker pool if not already started
c.startEncryptionWorker(msgHandler)
// start receiving messages from the Signal stream (from other peers through signal) // start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream, msgHandler) err = c.receive(stream)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled { if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
log.Debugf("signal connection context has been canceled, this usually indicates shutdown") log.Debugf("signal connection context has been canceled, this usually indicates shutdown")
@@ -174,6 +188,7 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes
return nil return nil
} }
func (c *GrpcClient) notifyStreamDisconnected() { func (c *GrpcClient) notifyStreamDisconnected() {
c.mux.Lock() c.mux.Lock()
defer c.mux.Unlock() defer c.mux.Unlock()
@@ -382,11 +397,11 @@ func (c *GrpcClient) Send(msg *proto.Message) error {
} }
// receive receives messages from other peers coming through the Signal Exchange // receive receives messages from other peers coming through the Signal Exchange
func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient, // and distributes them to worker threads for processing
msgHandler func(msg *proto.Message) error) error { func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient) error {
for { for {
msg, err := stream.Recv() msg, err := stream.Recv()
// Handle errors immediately
switch s, ok := status.FromError(err); { switch s, ok := status.FromError(err); {
case ok && s.Code() == codes.Canceled: case ok && s.Code() == codes.Canceled:
log.Debugf("stream canceled (usually indicates shutdown)") log.Debugf("stream canceled (usually indicates shutdown)")
@@ -398,24 +413,37 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
log.Debugf("Signal Service stream closed by server") log.Debugf("Signal Service stream closed by server")
return err return err
case err != nil: case err != nil:
log.Errorf("Stream receive error: %v", err)
return err return err
} }
log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
decryptedMessage, err := c.decryptMessage(msg) if msg == nil {
if err != nil { continue
log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
} }
err = msgHandler(decryptedMessage) if err := c.decryptionWorker.AddMsg(c.ctx, msg); err != nil {
log.Errorf("failed to add message to decryption worker: %v", err)
if err != nil {
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
// todo send something??
} }
} }
} }
func (c *GrpcClient) startEncryptionWorker(handler func(msg *proto.Message) error) {
if c.decryptionWorker != nil {
return
}
c.decryptionWorker = NewWorker(c.decryptMessage, handler)
workerCtx, workerCancel := context.WithCancel(context.Background())
c.decryptionWorkerCancel = workerCancel
c.decryptionWg.Add(1)
go func() {
defer workerCancel()
c.decryptionWorker.Work(workerCtx)
c.decryptionWg.Done()
}()
}
func (c *GrpcClient) notifyDisconnected(err error) { func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock() c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock() defer c.connStateCallbackLock.RUnlock()

View File

@@ -0,0 +1,55 @@
package client
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/signal/proto"
)
type Worker struct {
decryptMessage func(msg *proto.EncryptedMessage) (*proto.Message, error)
handler func(msg *proto.Message) error
encryptedMsgPool chan *proto.EncryptedMessage
}
func NewWorker(decryptFn func(msg *proto.EncryptedMessage) (*proto.Message, error), handlerFn func(msg *proto.Message) error) *Worker {
return &Worker{
decryptMessage: decryptFn,
handler: handlerFn,
encryptedMsgPool: make(chan *proto.EncryptedMessage, 1),
}
}
func (w *Worker) AddMsg(ctx context.Context, msg *proto.EncryptedMessage) error {
// this is blocker because do not want to drop messages here
select {
case w.encryptedMsgPool <- msg:
case <-ctx.Done():
}
return nil
}
func (w *Worker) Work(ctx context.Context) {
for {
select {
case msg := <-w.encryptedMsgPool:
decryptedMessage, err := w.decryptMessage(msg)
if err != nil {
log.Errorf("failed to decrypt message: %v", err)
continue
}
if err := w.handler(decryptedMessage); err != nil {
log.Errorf("failed to handle message: %v", err)
continue
}
case <-ctx.Done():
log.Infof("Message worker stopping due to context cancellation")
return
}
}
}

View File

@@ -230,6 +230,7 @@ type Body struct {
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"` RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
// relayServerAddress is url of the relay server // relayServerAddress is url of the relay server
RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"` RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"`
SessionId []byte `protobuf:"bytes,10,opt,name=sessionId,proto3,oneof" json:"sessionId,omitempty"`
} }
func (x *Body) Reset() { func (x *Body) Reset() {
@@ -320,6 +321,13 @@ func (x *Body) GetRelayServerAddress() string {
return "" return ""
} }
func (x *Body) GetSessionId() []byte {
if x != nil {
return x.SessionId
}
return nil
}
// Mode indicates a connection mode // Mode indicates a connection mode
type Mode struct { type Mode struct {
state protoimpl.MessageState state protoimpl.MessageState
@@ -443,7 +451,7 @@ var file_signalexchange_proto_rawDesc = []byte{
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62, 0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xb3, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xe4, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a, 0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
@@ -466,34 +474,37 @@ var file_signalexchange_proto_rawDesc = []byte{
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01,
0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41,
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f,
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x6e, 0x49, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x09, 0x73, 0x65, 0x73,
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70,
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x12, 0x0b, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06,
0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x22, 0x2e, 0x0a, 0x04, 0x4d, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44,
0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10,
0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01, 0x04, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x42, 0x0c,
0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, 0x52, 0x0a, 0x0a, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x22, 0x2e, 0x0a, 0x04,
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01,
0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01,
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f,
0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18, 0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b,
0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72,
0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e,
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c,
0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65,
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d,
0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e,
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45,
0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a,
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (
@@ -601,6 +612,7 @@ func file_signalexchange_proto_init() {
} }
} }
} }
file_signalexchange_proto_msgTypes[2].OneofWrappers = []interface{}{}
file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{} file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{}
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{

View File

@@ -64,6 +64,8 @@ message Body {
// relayServerAddress is url of the relay server // relayServerAddress is url of the relay server
string relayServerAddress = 8; string relayServerAddress = 8;
optional bytes sessionId = 10;
} }
// Mode indicates a connection mode // Mode indicates a connection mode