mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
Compare commits
20 Commits
v0.59.12
...
fix/relay-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca9985d2e3 | ||
|
|
0781908df5 | ||
|
|
72f65c63d3 | ||
|
|
605d44c4f9 | ||
|
|
e4b41d0ad7 | ||
|
|
9cc9462cd5 | ||
|
|
3176b53968 | ||
|
|
27957036c9 | ||
|
|
6fb568728f | ||
|
|
cc97cffff1 | ||
|
|
c28275611b | ||
|
|
56f169eede | ||
|
|
07cf9d5895 | ||
|
|
7df49e249d | ||
|
|
dbfc8a52c9 | ||
|
|
98ddac07bf | ||
|
|
48475ddc05 | ||
|
|
6aa4ba7af4 | ||
|
|
2e16c9914a | ||
|
|
5c29d395b2 |
73
.github/workflows/check-license-dependencies.yml
vendored
73
.github/workflows/check-license-dependencies.yml
vendored
@@ -3,10 +3,19 @@ name: Check License Dependencies
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
|
||||
jobs:
|
||||
check-dependencies:
|
||||
check-internal-dependencies:
|
||||
name: Check Internal AGPL Dependencies
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@@ -33,9 +42,67 @@ jobs:
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo ""
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "These packages will change license and should not be imported by client or shared code"
|
||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||
exit 1
|
||||
else
|
||||
echo ""
|
||||
echo "✅ All license dependencies are clean"
|
||||
echo "✅ All internal license dependencies are clean"
|
||||
fi
|
||||
|
||||
check-external-licenses:
|
||||
name: Check External GPL/AGPL Licenses
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
cache: true
|
||||
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
echo ""
|
||||
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
|
||||
3
.github/workflows/golang-test-linux.yml
vendored
3
.github/workflows/golang-test-linux.yml
vendored
@@ -217,7 +217,7 @@ jobs:
|
||||
- arch: "386"
|
||||
raceFlag: ""
|
||||
- arch: "amd64"
|
||||
raceFlag: "-race"
|
||||
raceFlag: "-race -v"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
@@ -258,6 +258,7 @@ jobs:
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test ${{ matrix.raceFlag }} \
|
||||
-tags devcert \
|
||||
-exec 'sudo' \
|
||||
-timeout 10m ./relay/... ./shared/relay/...
|
||||
|
||||
|
||||
@@ -259,6 +259,7 @@ func isServiceRunning() (bool, error) {
|
||||
}
|
||||
|
||||
const (
|
||||
networkdConf = "/etc/systemd/networkd.conf"
|
||||
networkdConfDir = "/etc/systemd/networkd.conf.d"
|
||||
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
|
||||
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
|
||||
@@ -273,12 +274,16 @@ ManageForeignRoutingPolicyRules=no
|
||||
// configureSystemdNetworkd creates a drop-in configuration file to prevent
|
||||
// systemd-networkd from removing NetBird's routes and policy rules.
|
||||
func configureSystemdNetworkd() error {
|
||||
parentDir := filepath.Dir(networkdConfDir)
|
||||
if _, err := os.Stat(parentDir); os.IsNotExist(err) {
|
||||
log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration")
|
||||
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
|
||||
log.Debug("systemd-networkd not in use, skipping configuration")
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
|
||||
return fmt.Errorf("create networkd.conf.d directory: %w", err)
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
|
||||
return fmt.Errorf("write networkd configuration: %w", err)
|
||||
|
||||
@@ -12,6 +12,9 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
@@ -84,7 +87,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
@@ -110,13 +112,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nadoo/ipset"
|
||||
ipset "github.com/lrh3321/ipset-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -40,19 +41,13 @@ type aclManager struct {
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||
m := &aclManager{
|
||||
return &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
}
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
@@ -98,8 +93,8 @@ func (m *aclManager) AddPeerFiltering(
|
||||
specs = append(specs, "-j", actionToStr(action))
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
}
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
@@ -113,14 +108,18 @@ func (m *aclManager) AddPeerFiltering(
|
||||
}}, nil
|
||||
}
|
||||
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
if err := ipset.Create(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||
if err := m.createIPSet(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("create ipset: %w", err)
|
||||
}
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
ipList := newIpList(ip.String())
|
||||
@@ -172,11 +171,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
shouldDestroyIpset := false
|
||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||
ip := net.ParseIP(r.ip)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("parse IP %s", r.ip)
|
||||
}
|
||||
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
|
||||
return fmt.Errorf("delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(ipsetList.ips, r.ip)
|
||||
}
|
||||
@@ -190,10 +194,7 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||
|
||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||
log.Errorf("delete empty ipset: %v", err)
|
||||
}
|
||||
shouldDestroyIpset = true
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||
@@ -206,6 +207,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDestroyIpset {
|
||||
if err := m.destroyIPSet(r.ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy empty ipset: %v", err)
|
||||
} else {
|
||||
log.Errorf("destroy empty ipset: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
@@ -264,11 +275,19 @@ func (m *aclManager) cleanChains() error {
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
if err := ipset.Destroy(ipsetName); err != nil {
|
||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||
if err := m.destroyIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
m.ipsetStore.deleteIpset(ipsetName)
|
||||
}
|
||||
@@ -368,8 +387,8 @@ func (m *aclManager) updateState() {
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if ip.String() == "0.0.0.0" {
|
||||
// don't use IP matching if IP is 0.0.0.0
|
||||
if ip.IsUnspecified() {
|
||||
matchByIP = false
|
||||
}
|
||||
|
||||
@@ -416,3 +435,61 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
|
||||
return ipsetName + actionSuffix
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add IP to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
}
|
||||
|
||||
if err := ipset.Del(name, entry); err != nil {
|
||||
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) flushIPSet(name string) error {
|
||||
return ipset.Flush(name)
|
||||
}
|
||||
|
||||
func (m *aclManager) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/nadoo/ipset"
|
||||
ipset "github.com/lrh3321/ipset-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
@@ -107,10 +107,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
|
||||
},
|
||||
)
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
@@ -232,12 +228,12 @@ func (r *router) findSets(rule []string) []string {
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||
if err := r.createIPSet(setName); err != nil {
|
||||
return fmt.Errorf("create set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
for _, prefix := range sources {
|
||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
|
||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
}
|
||||
}
|
||||
@@ -246,7 +242,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string) error {
|
||||
if err := ipset.Destroy(setName); err != nil {
|
||||
if err := r.destroyIPSet(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
@@ -915,8 +911,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
||||
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
@@ -993,3 +989,37 @@ func applyPort(flag string, port *firewall.Port) []string {
|
||||
|
||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||
}
|
||||
|
||||
func (r *router) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
||||
addr := prefix.Addr()
|
||||
ip := addr.AsSlice()
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: uint8(prefix.Bits()),
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -9,13 +10,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// keep darwin compatibility
|
||||
@@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
addr := "100.64.0.1/8"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
|
||||
func Test_CreateInterface(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
||||
wgIP := "10.99.99.1/32"
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -166,7 +167,7 @@ func Test_Close(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||
wgIP := "10.99.99.2/32"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||
wgIP := "10.99.99.2/32"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
||||
wgIP := "10.99.99.5/30"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
func Test_UpdatePeer(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
wgIP := "10.99.99.9/30"
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) {
|
||||
func Test_RemovePeer(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
wgIP := "10.99.99.13/30"
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
peer2wgPort := 33200
|
||||
|
||||
keepAlive := 1 * time.Second
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
guid = fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
|
||||
newNet, err = stdnet.NewNet()
|
||||
newNet, err = stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package udpmux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -12,8 +13,9 @@ import (
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/stun/v3"
|
||||
"github.com/pion/transport/v3"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
/*
|
||||
@@ -199,7 +201,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
|
||||
if len(networks) > 0 {
|
||||
if m.params.Net == nil {
|
||||
var err error
|
||||
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
||||
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil {
|
||||
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -253,7 +253,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), &relayClient.ManagerOpts{
|
||||
MTU: engineConfig.MTU,
|
||||
})
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
if len(relayURLs) > 0 {
|
||||
if token != nil {
|
||||
|
||||
@@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
privKey, _ := wgtypes.GenerateKey()
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
@@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Fatalf("create stdnet: %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -7,5 +7,5 @@ import (
|
||||
)
|
||||
|
||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(e.config.IFaceBlackList)
|
||||
return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList)
|
||||
}
|
||||
|
||||
@@ -3,5 +3,5 @@ package internal
|
||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
||||
return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -25,8 +24,14 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
@@ -224,7 +229,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU})
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&signal.MockClient{},
|
||||
@@ -370,7 +375,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU})
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&signal.MockClient{},
|
||||
@@ -597,7 +602,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU})
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
@@ -762,7 +767,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU})
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
@@ -771,7 +776,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
engine.ctx = ctx
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -964,7 +969,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU})
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
@@ -974,7 +979,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1496,7 +1501,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
MTU: iface.DefaultMTU,
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), &relayClient.ManagerOpts{MTU: iface.DefaultMTU})
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
@@ -1556,7 +1561,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -1584,13 +1588,16 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
|
||||
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
|
||||
log.Debugf("Gathering ICE candidates")
|
||||
|
||||
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
||||
agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("create ICE agent: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -22,6 +23,8 @@ const (
|
||||
iceFailedTimeoutDefault = 6 * time.Second
|
||||
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
|
||||
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
|
||||
// iceAgentCloseTimeout is the maximum time to wait for ICE agent close to complete
|
||||
iceAgentCloseTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
type ThreadSafeAgent struct {
|
||||
@@ -32,18 +35,28 @@ type ThreadSafeAgent struct {
|
||||
func (a *ThreadSafeAgent) Close() error {
|
||||
var err error
|
||||
a.once.Do(func() {
|
||||
err = a.Agent.Close()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- a.Agent.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-done:
|
||||
case <-time.After(iceAgentCloseTimeout):
|
||||
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
|
||||
err = nil
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||
iceKeepAlive := iceKeepAlive()
|
||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||
iceFailedTimeout := iceFailedTimeout()
|
||||
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
||||
|
||||
transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList)
|
||||
transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
package ice
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(ifaceBlacklist)
|
||||
func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(ctx, ifaceBlacklist)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package ice
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
import (
|
||||
"context"
|
||||
|
||||
func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||
return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||
return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist)
|
||||
}
|
||||
|
||||
@@ -209,7 +209,7 @@ func (w *WorkerICE) Close() {
|
||||
}
|
||||
|
||||
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
||||
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create agent: %w", err)
|
||||
}
|
||||
@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
|
||||
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
||||
if isController(w.config) {
|
||||
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
} else {
|
||||
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
}
|
||||
|
||||
@@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri
|
||||
}
|
||||
}()
|
||||
|
||||
net, err := stdnet.NewNet(nil)
|
||||
net, err := stdnet.NewNet(ctx, nil)
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("new net: %w", err)
|
||||
return
|
||||
@@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
|
||||
}
|
||||
}()
|
||||
|
||||
net, err := stdnet.NewNet(nil)
|
||||
net, err := stdnet.NewNet(ctx, nil)
|
||||
if err != nil {
|
||||
probeErr = fmt.Errorf("new net: %w", err)
|
||||
return
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
|
||||
@@ -4,17 +4,28 @@
|
||||
package stdnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/pion/transport/v3"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
const updateInterval = 30 * time.Second
|
||||
const (
|
||||
updateInterval = 30 * time.Second
|
||||
dnsResolveTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
var errNoSuitableAddress = errors.New("no suitable address found")
|
||||
|
||||
// Net is an implementation of the net.Net interface
|
||||
// based on functions of the standard net package.
|
||||
@@ -28,12 +39,19 @@ type Net struct {
|
||||
|
||||
// mu is shared between interfaces and lastUpdate
|
||||
mu sync.Mutex
|
||||
|
||||
// ctx is the context for network operations that supports cancellation
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewNetWithDiscover creates a new StdNet instance.
|
||||
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
||||
func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
n := &Net{
|
||||
interfaceFilter: InterfaceFilter(disallowList),
|
||||
ctx: ctx,
|
||||
}
|
||||
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
|
||||
// so in android cli use pionDiscover
|
||||
@@ -46,14 +64,64 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
|
||||
}
|
||||
|
||||
// NewNet creates a new StdNet instance.
|
||||
func NewNet(disallowList []string) (*Net, error) {
|
||||
func NewNet(ctx context.Context, disallowList []string) (*Net, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
n := &Net{
|
||||
iFaceDiscover: pionDiscover{},
|
||||
interfaceFilter: InterfaceFilter(disallowList),
|
||||
ctx: ctx,
|
||||
}
|
||||
return n, n.UpdateInterfaces()
|
||||
}
|
||||
|
||||
// resolveAddr performs DNS resolution with context support and timeout.
|
||||
func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) {
|
||||
host, portStr, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
if port < 0 || port > 65535 {
|
||||
return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port)
|
||||
}
|
||||
|
||||
ipNet := "ip"
|
||||
switch network {
|
||||
case "tcp4", "udp4":
|
||||
ipNet = "ip4"
|
||||
case "tcp6", "udp6":
|
||||
ipNet = "ip6"
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
addr := netip.IPv4Unspecified()
|
||||
if ipNet == "ip6" {
|
||||
addr = netip.IPv6Unspecified()
|
||||
}
|
||||
return netip.AddrPortFrom(addr, uint16(port)), nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout)
|
||||
defer cancel()
|
||||
|
||||
addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host)
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, err
|
||||
}
|
||||
|
||||
if len(addrs) == 0 {
|
||||
return netip.AddrPort{}, errNoSuitableAddress
|
||||
}
|
||||
|
||||
return netip.AddrPortFrom(addrs[0], uint16(port)), nil
|
||||
}
|
||||
|
||||
// UpdateInterfaces updates the internal list of network interfaces
|
||||
// and associated addresses filtering them by name.
|
||||
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
|
||||
@@ -137,3 +205,39 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ResolveUDPAddr resolves UDP addresses with context support and timeout.
|
||||
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
case "":
|
||||
network = "udp"
|
||||
default:
|
||||
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
|
||||
}
|
||||
|
||||
addrPort, err := n.resolveAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err}
|
||||
}
|
||||
|
||||
return net.UDPAddrFromAddrPort(addrPort), nil
|
||||
}
|
||||
|
||||
// ResolveTCPAddr resolves TCP addresses with context support and timeout.
|
||||
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
case "":
|
||||
network = "tcp"
|
||||
default:
|
||||
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
|
||||
}
|
||||
|
||||
addrPort, err := n.resolveAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err}
|
||||
}
|
||||
|
||||
return net.TCPAddrFromAddrPort(addrPort), nil
|
||||
}
|
||||
|
||||
@@ -14,6 +14,9 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -290,7 +293,6 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -311,13 +313,16 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
BIN
client/ui/assets/netbird-disconnected.ico
Normal file
BIN
client/ui/assets/netbird-disconnected.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.9 KiB |
BIN
client/ui/assets/netbird-disconnected.png
Normal file
BIN
client/ui/assets/netbird-disconnected.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.4 KiB |
@@ -85,21 +85,22 @@ func main() {
|
||||
|
||||
// Create the service client (this also builds the settings or networks UI if requested).
|
||||
client := newServiceClient(&newServiceClientArgs{
|
||||
addr: flags.daemonAddr,
|
||||
logFile: logFile,
|
||||
app: a,
|
||||
showSettings: flags.showSettings,
|
||||
showNetworks: flags.showNetworks,
|
||||
showLoginURL: flags.showLoginURL,
|
||||
showDebug: flags.showDebug,
|
||||
showProfiles: flags.showProfiles,
|
||||
addr: flags.daemonAddr,
|
||||
logFile: logFile,
|
||||
app: a,
|
||||
showSettings: flags.showSettings,
|
||||
showNetworks: flags.showNetworks,
|
||||
showLoginURL: flags.showLoginURL,
|
||||
showDebug: flags.showDebug,
|
||||
showProfiles: flags.showProfiles,
|
||||
showQuickActions: flags.showQuickActions,
|
||||
})
|
||||
|
||||
// Watch for theme/settings changes to update the icon.
|
||||
go watchSettingsChanges(a, client)
|
||||
|
||||
// Run in window mode if any UI flag was set.
|
||||
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles {
|
||||
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions {
|
||||
a.Run()
|
||||
return
|
||||
}
|
||||
@@ -111,23 +112,29 @@ func main() {
|
||||
return
|
||||
}
|
||||
if running {
|
||||
log.Warnf("another process is running with pid %d, exiting", pid)
|
||||
log.Infof("another process is running with pid %d, sending signal to show window", pid)
|
||||
if err := sendShowWindowSignal(pid); err != nil {
|
||||
log.Errorf("send signal to running instance: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
client.setupSignalHandler(client.ctx)
|
||||
|
||||
client.setDefaultFonts()
|
||||
systray.Run(client.onTrayReady, client.onTrayExit)
|
||||
}
|
||||
|
||||
type cliFlags struct {
|
||||
daemonAddr string
|
||||
showSettings bool
|
||||
showNetworks bool
|
||||
showProfiles bool
|
||||
showDebug bool
|
||||
showLoginURL bool
|
||||
errorMsg string
|
||||
saveLogsInFile bool
|
||||
daemonAddr string
|
||||
showSettings bool
|
||||
showNetworks bool
|
||||
showProfiles bool
|
||||
showDebug bool
|
||||
showLoginURL bool
|
||||
showQuickActions bool
|
||||
errorMsg string
|
||||
saveLogsInFile bool
|
||||
}
|
||||
|
||||
// parseFlags reads and returns all needed command-line flags.
|
||||
@@ -143,6 +150,7 @@ func parseFlags() *cliFlags {
|
||||
flag.BoolVar(&flags.showNetworks, "networks", false, "run networks window")
|
||||
flag.BoolVar(&flags.showProfiles, "profiles", false, "run profiles window")
|
||||
flag.BoolVar(&flags.showDebug, "debug", false, "run debug window")
|
||||
flag.BoolVar(&flags.showQuickActions, "quick-actions", false, "run quick actions window")
|
||||
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
|
||||
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
|
||||
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
|
||||
@@ -158,11 +166,9 @@ func initLogFile() (string, error) {
|
||||
|
||||
// watchSettingsChanges listens for Fyne theme/settings changes and updates the client icon.
|
||||
func watchSettingsChanges(a fyne.App, client *serviceClient) {
|
||||
settingsChangeChan := make(chan fyne.Settings)
|
||||
a.Settings().AddChangeListener(settingsChangeChan)
|
||||
for range settingsChangeChan {
|
||||
a.Settings().AddListener(func(settings fyne.Settings) {
|
||||
client.updateIcon()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// showErrorMessage displays an error message in a simple window.
|
||||
@@ -287,6 +293,7 @@ type serviceClient struct {
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
wProfiles fyne.Window
|
||||
wQuickActions fyne.Window
|
||||
|
||||
eventManager *event.Manager
|
||||
|
||||
@@ -306,14 +313,15 @@ type menuHandler struct {
|
||||
}
|
||||
|
||||
type newServiceClientArgs struct {
|
||||
addr string
|
||||
logFile string
|
||||
app fyne.App
|
||||
showSettings bool
|
||||
showNetworks bool
|
||||
showDebug bool
|
||||
showLoginURL bool
|
||||
showProfiles bool
|
||||
addr string
|
||||
logFile string
|
||||
app fyne.App
|
||||
showSettings bool
|
||||
showNetworks bool
|
||||
showDebug bool
|
||||
showLoginURL bool
|
||||
showProfiles bool
|
||||
showQuickActions bool
|
||||
}
|
||||
|
||||
// newServiceClient instance constructor
|
||||
@@ -349,6 +357,8 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
s.showDebugUI()
|
||||
case args.showProfiles:
|
||||
s.showProfilesUI()
|
||||
case args.showQuickActions:
|
||||
s.showQuickActionsUI()
|
||||
}
|
||||
|
||||
return s
|
||||
|
||||
@@ -500,7 +500,7 @@ func (s *serviceClient) createDebugBundleFromCollection(
|
||||
if uploadFailureReason != "" {
|
||||
showUploadFailedDialog(progress.window, localPath, uploadFailureReason)
|
||||
} else {
|
||||
showUploadSuccessDialog(progress.window, localPath, uploadedKey)
|
||||
showUploadSuccessDialog(s.app, progress.window, localPath, uploadedKey)
|
||||
}
|
||||
} else {
|
||||
showBundleCreatedDialog(progress.window, localPath)
|
||||
@@ -565,7 +565,7 @@ func (s *serviceClient) handleDebugCreation(
|
||||
if uploadFailureReason != "" {
|
||||
showUploadFailedDialog(w, localPath, uploadFailureReason)
|
||||
} else {
|
||||
showUploadSuccessDialog(w, localPath, uploadedKey)
|
||||
showUploadSuccessDialog(s.app, w, localPath, uploadedKey)
|
||||
}
|
||||
} else {
|
||||
showBundleCreatedDialog(w, localPath)
|
||||
@@ -665,7 +665,7 @@ func showUploadFailedDialog(w fyne.Window, localPath, failureReason string) {
|
||||
}
|
||||
|
||||
// showUploadSuccessDialog displays a dialog when upload succeeds
|
||||
func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) {
|
||||
func showUploadSuccessDialog(a fyne.App, w fyne.Window, localPath, uploadedKey string) {
|
||||
log.Infof("Upload key: %s", uploadedKey)
|
||||
keyEntry := widget.NewEntry()
|
||||
keyEntry.SetText(uploadedKey)
|
||||
@@ -683,7 +683,7 @@ func showUploadSuccessDialog(w fyne.Window, localPath, uploadedKey string) {
|
||||
customDialog := dialog.NewCustom("Upload Successful", "OK", content, w)
|
||||
|
||||
copyBtn := createButtonWithAction("Copy key", func() {
|
||||
w.Clipboard().SetContent(uploadedKey)
|
||||
a.Clipboard().SetContent(uploadedKey)
|
||||
log.Info("Upload key copied to clipboard")
|
||||
})
|
||||
|
||||
|
||||
@@ -9,6 +9,9 @@ import (
|
||||
//go:embed assets/netbird.png
|
||||
var iconAbout []byte
|
||||
|
||||
//go:embed assets/netbird-disconnected.png
|
||||
var iconAboutDisconnected []byte
|
||||
|
||||
//go:embed assets/netbird-systemtray-connected.png
|
||||
var iconConnected []byte
|
||||
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
//go:embed assets/netbird.ico
|
||||
var iconAbout []byte
|
||||
|
||||
//go:embed assets/netbird-disconnected.ico
|
||||
var iconAboutDisconnected []byte
|
||||
|
||||
//go:embed assets/netbird-systemtray-connected.ico
|
||||
var iconConnected []byte
|
||||
|
||||
|
||||
349
client/ui/quickactions.go
Normal file
349
client/ui/quickactions.go
Normal file
@@ -0,0 +1,349 @@
|
||||
//go:build !(linux && 386)
|
||||
|
||||
//go:generate fyne bundle -o quickactions_assets.go assets/connected.png
|
||||
//go:generate fyne bundle -o quickactions_assets.go -append assets/disconnected.png
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/fyne/v2/canvas"
|
||||
"fyne.io/fyne/v2/container"
|
||||
"fyne.io/fyne/v2/layout"
|
||||
"fyne.io/fyne/v2/widget"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
type quickActionsUiState struct {
|
||||
connectionStatus string
|
||||
isToggleButtonEnabled bool
|
||||
isConnectionChanged bool
|
||||
toggleAction func()
|
||||
}
|
||||
|
||||
func newQuickActionsUiState() quickActionsUiState {
|
||||
return quickActionsUiState{
|
||||
connectionStatus: string(internal.StatusIdle),
|
||||
isToggleButtonEnabled: false,
|
||||
isConnectionChanged: false,
|
||||
}
|
||||
}
|
||||
|
||||
type clientConnectionStatusProvider interface {
|
||||
connectionStatus(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
type daemonClientConnectionStatusProvider struct {
|
||||
client proto.DaemonServiceClient
|
||||
}
|
||||
|
||||
func (d daemonClientConnectionStatusProvider) connectionStatus(ctx context.Context) (string, error) {
|
||||
childCtx, cancel := context.WithTimeout(ctx, 400*time.Millisecond)
|
||||
defer cancel()
|
||||
status, err := d.client.Status(childCtx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return status.Status, nil
|
||||
}
|
||||
|
||||
type clientCommand interface {
|
||||
execute() error
|
||||
}
|
||||
|
||||
type connectCommand struct {
|
||||
connectClient func() error
|
||||
}
|
||||
|
||||
func (c connectCommand) execute() error {
|
||||
return c.connectClient()
|
||||
}
|
||||
|
||||
type disconnectCommand struct {
|
||||
disconnectClient func() error
|
||||
}
|
||||
|
||||
func (c disconnectCommand) execute() error {
|
||||
return c.disconnectClient()
|
||||
}
|
||||
|
||||
type quickActionsViewModel struct {
|
||||
provider clientConnectionStatusProvider
|
||||
connect clientCommand
|
||||
disconnect clientCommand
|
||||
uiChan chan quickActionsUiState
|
||||
isWatchingConnectionStatus atomic.Bool
|
||||
}
|
||||
|
||||
func newQuickActionsViewModel(ctx context.Context, provider clientConnectionStatusProvider, connect, disconnect clientCommand, uiChan chan quickActionsUiState) {
|
||||
viewModel := quickActionsViewModel{
|
||||
provider: provider,
|
||||
connect: connect,
|
||||
disconnect: disconnect,
|
||||
uiChan: uiChan,
|
||||
}
|
||||
|
||||
viewModel.isWatchingConnectionStatus.Store(true)
|
||||
|
||||
// base UI status
|
||||
uiChan <- newQuickActionsUiState()
|
||||
|
||||
// this retrieves the current connection status
|
||||
// and pushes the UI state that reflects it via uiChan
|
||||
go viewModel.watchConnectionStatus(ctx)
|
||||
}
|
||||
|
||||
func (q *quickActionsViewModel) updateUiState(ctx context.Context) {
|
||||
uiState := newQuickActionsUiState()
|
||||
connectionStatus, err := q.provider.connectionStatus(ctx)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("Status: Error - %v", err)
|
||||
q.uiChan <- uiState
|
||||
return
|
||||
}
|
||||
|
||||
if connectionStatus == string(internal.StatusConnected) {
|
||||
uiState.toggleAction = func() {
|
||||
q.executeCommand(q.disconnect)
|
||||
}
|
||||
} else {
|
||||
uiState.toggleAction = func() {
|
||||
q.executeCommand(q.connect)
|
||||
}
|
||||
}
|
||||
|
||||
uiState.isToggleButtonEnabled = true
|
||||
uiState.connectionStatus = connectionStatus
|
||||
q.uiChan <- uiState
|
||||
}
|
||||
|
||||
func (q *quickActionsViewModel) watchConnectionStatus(ctx context.Context) {
|
||||
ticker := time.NewTicker(1000 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if q.isWatchingConnectionStatus.Load() {
|
||||
q.updateUiState(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (q *quickActionsViewModel) executeCommand(command clientCommand) {
|
||||
uiState := newQuickActionsUiState()
|
||||
// newQuickActionsUiState starts with Idle connection status,
|
||||
// and all that's necessary here is to just disable the toggle button.
|
||||
uiState.connectionStatus = ""
|
||||
|
||||
q.uiChan <- uiState
|
||||
|
||||
q.isWatchingConnectionStatus.Store(false)
|
||||
|
||||
err := command.execute()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("Status: Error - %v", err)
|
||||
q.isWatchingConnectionStatus.Store(true)
|
||||
} else {
|
||||
uiState = newQuickActionsUiState()
|
||||
uiState.isConnectionChanged = true
|
||||
q.uiChan <- uiState
|
||||
}
|
||||
}
|
||||
|
||||
func getSystemTrayName() string {
|
||||
os := runtime.GOOS
|
||||
switch os {
|
||||
case "darwin":
|
||||
return "menu bar"
|
||||
default:
|
||||
return "system tray"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) getNetBirdImage(name string, content []byte) *canvas.Image {
|
||||
imageSize := fyne.NewSize(64, 64)
|
||||
|
||||
resource := fyne.NewStaticResource(name, content)
|
||||
image := canvas.NewImageFromResource(resource)
|
||||
image.FillMode = canvas.ImageFillContain
|
||||
image.SetMinSize(imageSize)
|
||||
image.Resize(imageSize)
|
||||
|
||||
return image
|
||||
}
|
||||
|
||||
type quickActionsUiComponents struct {
|
||||
content *fyne.Container
|
||||
toggleConnectionButton *widget.Button
|
||||
connectedLabelText, disconnectedLabelText string
|
||||
connectedImage, disconnectedImage *canvas.Image
|
||||
connectedCircleRes, disconnectedCircleRes fyne.Resource
|
||||
}
|
||||
|
||||
// applyQuickActionsUiState applies a single UI state to the quick actions window.
|
||||
// It closes the window and returns true if the connection status has changed,
|
||||
// in which case the caller should stop processing further states.
|
||||
func (s *serviceClient) applyQuickActionsUiState(
|
||||
uiState quickActionsUiState,
|
||||
components quickActionsUiComponents,
|
||||
) bool {
|
||||
if uiState.isConnectionChanged {
|
||||
fyne.DoAndWait(func() {
|
||||
s.wQuickActions.Close()
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
var logo *canvas.Image
|
||||
var buttonText string
|
||||
var buttonIcon fyne.Resource
|
||||
|
||||
if uiState.connectionStatus == string(internal.StatusConnected) {
|
||||
buttonText = components.connectedLabelText
|
||||
buttonIcon = components.connectedCircleRes
|
||||
logo = components.connectedImage
|
||||
} else if uiState.connectionStatus == string(internal.StatusIdle) {
|
||||
buttonText = components.disconnectedLabelText
|
||||
buttonIcon = components.disconnectedCircleRes
|
||||
logo = components.disconnectedImage
|
||||
}
|
||||
|
||||
fyne.DoAndWait(func() {
|
||||
if buttonText != "" {
|
||||
components.toggleConnectionButton.SetText(buttonText)
|
||||
}
|
||||
|
||||
if buttonIcon != nil {
|
||||
components.toggleConnectionButton.SetIcon(buttonIcon)
|
||||
}
|
||||
|
||||
if uiState.isToggleButtonEnabled {
|
||||
components.toggleConnectionButton.Enable()
|
||||
} else {
|
||||
components.toggleConnectionButton.Disable()
|
||||
}
|
||||
|
||||
components.toggleConnectionButton.OnTapped = func() {
|
||||
if uiState.toggleAction != nil {
|
||||
go uiState.toggleAction()
|
||||
}
|
||||
}
|
||||
|
||||
components.toggleConnectionButton.Refresh()
|
||||
|
||||
// the second position in the content's object array is the NetBird logo.
|
||||
if logo != nil {
|
||||
components.content.Objects[1] = logo
|
||||
components.content.Refresh()
|
||||
}
|
||||
})
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// showQuickActionsUI displays a simple window with the NetBird logo and a connection toggle button.
|
||||
func (s *serviceClient) showQuickActionsUI() {
|
||||
s.wQuickActions = s.app.NewWindow("NetBird")
|
||||
vmCtx, vmCancel := context.WithCancel(s.ctx)
|
||||
s.wQuickActions.SetOnClosed(vmCancel)
|
||||
|
||||
client, err := s.getSrvClient(defaultFailTimeout)
|
||||
|
||||
connCmd := connectCommand{
|
||||
connectClient: func() error {
|
||||
return s.menuUpClick(s.ctx)
|
||||
},
|
||||
}
|
||||
|
||||
disConnCmd := disconnectCommand{
|
||||
disconnectClient: func() error {
|
||||
return s.menuDownClick()
|
||||
},
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("get service client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
uiChan := make(chan quickActionsUiState, 1)
|
||||
newQuickActionsViewModel(vmCtx, daemonClientConnectionStatusProvider{client: client}, connCmd, disConnCmd, uiChan)
|
||||
|
||||
connectedImage := s.getNetBirdImage("netbird.png", iconAbout)
|
||||
disconnectedImage := s.getNetBirdImage("netbird-disconnected.png", iconAboutDisconnected)
|
||||
|
||||
connectedCircle := canvas.NewImageFromResource(resourceConnectedPng)
|
||||
disconnectedCircle := canvas.NewImageFromResource(resourceDisconnectedPng)
|
||||
|
||||
connectedLabelText := "Disconnect"
|
||||
disconnectedLabelText := "Connect"
|
||||
|
||||
toggleConnectionButton := widget.NewButtonWithIcon(disconnectedLabelText, disconnectedCircle.Resource, func() {
|
||||
// This button's tap function will be set when an ui state arrives via the uiChan channel.
|
||||
})
|
||||
|
||||
// Button starts disabled until the first ui state arrives.
|
||||
toggleConnectionButton.Disable()
|
||||
|
||||
hintLabelText := fmt.Sprintf("You can always access NetBird from your %s.", getSystemTrayName())
|
||||
hintLabel := widget.NewLabel(hintLabelText)
|
||||
|
||||
content := container.NewVBox(
|
||||
layout.NewSpacer(),
|
||||
disconnectedImage,
|
||||
layout.NewSpacer(),
|
||||
container.NewCenter(toggleConnectionButton),
|
||||
layout.NewSpacer(),
|
||||
container.NewCenter(hintLabel),
|
||||
)
|
||||
|
||||
// this watches for ui state updates.
|
||||
go func() {
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-vmCtx.Done():
|
||||
return
|
||||
case uiState, ok := <-uiChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
closed := s.applyQuickActionsUiState(
|
||||
uiState,
|
||||
quickActionsUiComponents{
|
||||
content,
|
||||
toggleConnectionButton,
|
||||
connectedLabelText, disconnectedLabelText,
|
||||
connectedImage, disconnectedImage,
|
||||
connectedCircle.Resource, disconnectedCircle.Resource,
|
||||
},
|
||||
)
|
||||
if closed {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
s.wQuickActions.SetContent(content)
|
||||
s.wQuickActions.Resize(fyne.NewSize(400, 200))
|
||||
s.wQuickActions.SetFixedSize(true)
|
||||
s.wQuickActions.Show()
|
||||
}
|
||||
23
client/ui/quickactions_assets.go
Normal file
23
client/ui/quickactions_assets.go
Normal file
@@ -0,0 +1,23 @@
|
||||
// auto-generated
|
||||
// Code generated by '$ fyne bundle'. DO NOT EDIT.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fyne.io/fyne/v2"
|
||||
)
|
||||
|
||||
//go:embed assets/connected.png
|
||||
var resourceConnectedPngData []byte
|
||||
var resourceConnectedPng = &fyne.StaticResource{
|
||||
StaticName: "assets/connected.png",
|
||||
StaticContent: resourceConnectedPngData,
|
||||
}
|
||||
|
||||
//go:embed assets/disconnected.png
|
||||
var resourceDisconnectedPngData []byte
|
||||
var resourceDisconnectedPng = &fyne.StaticResource{
|
||||
StaticName: "assets/disconnected.png",
|
||||
StaticContent: resourceDisconnectedPngData,
|
||||
}
|
||||
76
client/ui/signal_unix.go
Normal file
76
client/ui/signal_unix.go
Normal file
@@ -0,0 +1,76 @@
|
||||
//go:build !windows && !(linux && 386)
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// setupSignalHandler sets up a signal handler to listen for SIGUSR1.
|
||||
// When received, it opens the quick actions window.
|
||||
func (s *serviceClient) setupSignalHandler(ctx context.Context) {
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGUSR1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-sigChan:
|
||||
log.Info("received SIGUSR1 signal, opening quick actions window")
|
||||
s.openQuickActions()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// openQuickActions opens the quick actions window by spawning a new process.
|
||||
func (s *serviceClient) openQuickActions() {
|
||||
proc, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Errorf("get executable path: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(s.ctx, proc,
|
||||
"--quick-actions=true",
|
||||
"--daemon-addr="+s.addr,
|
||||
)
|
||||
|
||||
if out := s.attachOutput(cmd); out != nil {
|
||||
defer func() {
|
||||
if err := out.Close(); err != nil {
|
||||
log.Errorf("close log file %s: %v", s.logFile, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
log.Infof("running command: %s --quick-actions=true --daemon-addr=%s", proc, s.addr)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
log.Errorf("start quick actions window: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := cmd.Wait(); err != nil {
|
||||
log.Debugf("quick actions window exited: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// sendShowWindowSignal sends SIGUSR1 to the specified PID.
|
||||
func sendShowWindowSignal(pid int32) error {
|
||||
process, err := os.FindProcess(int(pid))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return process.Signal(syscall.SIGUSR1)
|
||||
}
|
||||
171
client/ui/signal_windows.go
Normal file
171
client/ui/signal_windows.go
Normal file
@@ -0,0 +1,171 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
quickActionsTriggerEventName = `Global\NetBirdQuickActionsTriggerEvent`
|
||||
waitTimeout = 5 * time.Second
|
||||
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
|
||||
desiredAccesses = windows.SYNCHRONIZE | windows.EVENT_MODIFY_STATE
|
||||
)
|
||||
|
||||
func getEventNameUint16Pointer() (*uint16, error) {
|
||||
eventNamePtr, err := windows.UTF16PtrFromString(quickActionsTriggerEventName)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to convert event name '%s' to UTF16: %v", quickActionsTriggerEventName, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return eventNamePtr, nil
|
||||
}
|
||||
|
||||
// setupSignalHandler sets up signal handling for Windows.
|
||||
// Windows doesn't support SIGUSR1, so this uses a similar approach using windows.Events.
|
||||
func (s *serviceClient) setupSignalHandler(ctx context.Context) {
|
||||
eventNamePtr, err := getEventNameUint16Pointer()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
|
||||
log.Warnf("Quick actions trigger event '%s' already exists. Attempting to open.", quickActionsTriggerEventName)
|
||||
eventHandle, err = windows.OpenEvent(desiredAccesses, false, eventNamePtr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to open existing quick actions trigger event '%s': %v", quickActionsTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
log.Infof("Successfully opened existing quick actions trigger event '%s'.", quickActionsTriggerEventName)
|
||||
} else {
|
||||
log.Errorf("Failed to create quick actions trigger event '%s': %v", quickActionsTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if eventHandle == windows.InvalidHandle {
|
||||
log.Errorf("Obtained an invalid handle for quick actions trigger event '%s'", quickActionsTriggerEventName)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Quick actions handler waiting for signal on event: %s", quickActionsTriggerEventName)
|
||||
|
||||
go s.waitForEvent(ctx, eventHandle)
|
||||
}
|
||||
|
||||
func (s *serviceClient) waitForEvent(ctx context.Context, eventHandle windows.Handle) {
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(eventHandle); err != nil {
|
||||
log.Errorf("Failed to close quick actions event handle '%s': %v", quickActionsTriggerEventName, err)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
|
||||
|
||||
switch status {
|
||||
case windows.WAIT_OBJECT_0:
|
||||
log.Info("Received signal on quick actions event. Opening quick actions window.")
|
||||
|
||||
// reset the event so it can be triggered again later (manual reset == 1)
|
||||
if err := windows.ResetEvent(eventHandle); err != nil {
|
||||
log.Errorf("Failed to reset quick actions event '%s': %v", quickActionsTriggerEventName, err)
|
||||
}
|
||||
|
||||
s.openQuickActions()
|
||||
case uint32(windows.WAIT_TIMEOUT):
|
||||
|
||||
default:
|
||||
if isDone := logUnexpectedStatus(ctx, status, err); isDone {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func logUnexpectedStatus(ctx context.Context, status uint32, err error) bool {
|
||||
log.Errorf("Unexpected status %d from WaitForSingleObject for quick actions event '%s': %v",
|
||||
status, quickActionsTriggerEventName, err)
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
return false
|
||||
case <-ctx.Done():
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// openQuickActions opens the quick actions window by spawning a new process.
|
||||
func (s *serviceClient) openQuickActions() {
|
||||
proc, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Errorf("get executable path: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(s.ctx, proc,
|
||||
"--quick-actions=true",
|
||||
"--daemon-addr="+s.addr,
|
||||
)
|
||||
|
||||
if out := s.attachOutput(cmd); out != nil {
|
||||
defer func() {
|
||||
if err := out.Close(); err != nil {
|
||||
log.Errorf("close log file %s: %v", s.logFile, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
log.Infof("running command: %s --quick-actions=true --daemon-addr=%s", proc, s.addr)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
log.Errorf("error starting quick actions window: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := cmd.Wait(); err != nil {
|
||||
log.Debugf("quick actions window exited: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func sendShowWindowSignal(pid int32) error {
|
||||
_, err := os.FindProcess(int(pid))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eventNamePtr, err := getEventNameUint16Pointer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eventHandle, err := windows.OpenEvent(desiredAccesses, false, eventNamePtr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = windows.SetEvent(eventHandle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error setting event: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
57
go.mod
57
go.mod
@@ -16,7 +16,7 @@ require (
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/vishvananda/netlink v1.3.0
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/crypto v0.40.0
|
||||
golang.org/x/sys v0.34.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
@@ -28,8 +28,8 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
fyne.io/fyne/v2 v2.5.3
|
||||
fyne.io/systray v1.11.0
|
||||
fyne.io/fyne/v2 v2.7.0
|
||||
fyne.io/systray v1.11.1-0.20250603113521-ca66a66d8b58
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
||||
github.com/aws/aws-sdk-go-v2 v1.36.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.29.14
|
||||
@@ -43,7 +43,7 @@ require (
|
||||
github.com/eko/gocache/lib/v4 v4.2.0
|
||||
github.com/eko/gocache/store/go_cache/v4 v4.2.2
|
||||
github.com/eko/gocache/store/redis/v4 v4.2.2
|
||||
github.com/fsnotify/fsnotify v1.7.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gliderlabs/ssh v0.3.8
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
@@ -56,12 +56,13 @@ require (
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||
github.com/hashicorp/go-version v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.5.5
|
||||
github.com/libdns/route53 v1.5.0
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
@@ -82,7 +83,7 @@ require (
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/testcontainers/testcontainers-go v0.31.0
|
||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
|
||||
@@ -98,15 +99,17 @@ require (
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
|
||||
go.opentelemetry.io/otel/metric v1.35.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.35.0
|
||||
go.uber.org/mock v0.5.0
|
||||
go.uber.org/zap v1.27.0
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||
golang.org/x/mod v0.25.0
|
||||
golang.org/x/mod v0.26.0
|
||||
golang.org/x/net v0.42.0
|
||||
golang.org/x/oauth2 v0.28.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.16.0
|
||||
golang.org/x/term v0.33.0
|
||||
golang.org/x/time v0.12.0
|
||||
google.golang.org/api v0.177.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
@@ -123,7 +126,7 @@ require (
|
||||
dario.cat/mergo v1.0.0 // indirect
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
||||
github.com/BurntSushi/toml v1.4.0 // indirect
|
||||
github.com/BurntSushi/toml v1.5.0 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/Microsoft/hcsshim v0.12.3 // indirect
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
||||
@@ -146,7 +149,7 @@ require (
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/caddyserver/zerossl v0.1.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/containerd/containerd v1.7.27 // indirect
|
||||
github.com/containerd/containerd v1.7.29 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/containerd/platforms v0.2.1 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||
@@ -157,11 +160,12 @@ require (
|
||||
github.com/docker/go-connections v0.5.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fredbi/uri v1.1.0 // indirect
|
||||
github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe // indirect
|
||||
github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0 // indirect
|
||||
github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect
|
||||
github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect
|
||||
github.com/fredbi/uri v1.1.1 // indirect
|
||||
github.com/fyne-io/gl-js v0.2.0 // indirect
|
||||
github.com/fyne-io/glfw-js v0.3.0 // indirect
|
||||
github.com/fyne-io/image v0.1.1 // indirect
|
||||
github.com/fyne-io/oksvg v0.2.0 // indirect
|
||||
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
@@ -169,7 +173,7 @@ require (
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/go-text/render v0.2.0 // indirect
|
||||
github.com/go-text/typesetting v0.2.0 // indirect
|
||||
github.com/go-text/typesetting v0.2.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
@@ -177,19 +181,19 @@ require (
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
|
||||
github.com/gopherjs/gopherjs v1.17.2 // indirect
|
||||
github.com/hack-pad/go-indexeddb v0.3.2 // indirect
|
||||
github.com/hack-pad/safejs v0.1.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-uuid v1.0.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
|
||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect
|
||||
github.com/kelseyhightower/envconfig v1.4.0 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
@@ -208,7 +212,8 @@ require (
|
||||
github.com/moby/term v0.5.0 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
|
||||
github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect
|
||||
github.com/nxadm/tail v1.4.8 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
@@ -224,28 +229,26 @@ require (
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/rymdport/portal v0.3.0 // indirect
|
||||
github.com/rymdport/portal v0.4.2 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/wlynxg/anet v0.0.3 // indirect
|
||||
github.com/yuin/goldmark v1.7.1 // indirect
|
||||
github.com/yuin/goldmark v1.7.8 // indirect
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||
go.uber.org/mock v0.5.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/image v0.24.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.34.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
|
||||
|
||||
31
management/internals/controllers/network_map/controller/cache/dns_config_cache.go
vendored
Normal file
31
management/internals/controllers/network_map/controller/cache/dns_config_cache.go
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
type DNSConfigCache struct {
|
||||
NameServerGroups sync.Map
|
||||
}
|
||||
|
||||
// GetNameServerGroup retrieves a cached name server group
|
||||
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.NameServerGroups.Load(key); ok {
|
||||
return value.(*proto.NameServerGroup), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetNameServerGroup stores a name server group in the cache
|
||||
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.NameServerGroups.Store(key, value)
|
||||
}
|
||||
@@ -0,0 +1,784 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/mod/semver"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
type Controller struct {
|
||||
repo Repository
|
||||
metrics *metrics
|
||||
// This should not be here, but we need to maintain it for the time being
|
||||
accountManagerMetrics *telemetry.AccountManagerMetrics
|
||||
peersUpdateManager network_map.PeersUpdateManager
|
||||
settingsManager settings.Manager
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
sendAccountUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
||||
dnsDomain string
|
||||
|
||||
requestBuffer account.RequestBuffer
|
||||
|
||||
proxyController port_forwarding.Controller
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
holder *types.Holder
|
||||
|
||||
expNewNetworkMap bool
|
||||
expNewNetworkMapAIDs map[string]struct{}
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
mu sync.Mutex
|
||||
next *time.Timer
|
||||
update atomic.Bool
|
||||
}
|
||||
|
||||
var _ network_map.Controller = (*Controller)(nil)
|
||||
|
||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller) *Controller {
|
||||
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||
}
|
||||
|
||||
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
|
||||
newNetworkMapBuilder = false
|
||||
}
|
||||
|
||||
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
||||
expIDs := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
expIDs[id] = struct{}{}
|
||||
}
|
||||
|
||||
return &Controller{
|
||||
repo: newRepository(store),
|
||||
metrics: nMetrics,
|
||||
accountManagerMetrics: metrics.AccountManagerMetrics(),
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
requestBuffer: requestBuffer,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
settingsManager: settingsManager,
|
||||
dnsDomain: dnsDomain,
|
||||
|
||||
proxyController: proxyController,
|
||||
|
||||
holder: types.NewHolder(),
|
||||
expNewNetworkMap: newNetworkMapBuilder,
|
||||
expNewNetworkMapAIDs: expIDs,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
|
||||
hasPeersConnected := false
|
||||
for _, peer := range account.Peers {
|
||||
if c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
hasPeersConnected = true
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if !hasPeersConnected {
|
||||
return nil
|
||||
}
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get validate peers: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return fmt.Errorf("failed to get proxy network maps: %v", err)
|
||||
}
|
||||
|
||||
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get flow enabled status: %v", err)
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := grpc.ToSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
})
|
||||
return
|
||||
}
|
||||
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return fmt.Errorf("recalculate network map cache: %v", err)
|
||||
}
|
||||
|
||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
||||
if !c.peersUpdateManager.HasChannel(peerId) {
|
||||
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err)
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerId)
|
||||
if peer == nil {
|
||||
return fmt.Errorf("peer %s doesn't exists in account %s", peerId, accountId)
|
||||
}
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get validated peers: %v", err)
|
||||
}
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err)
|
||||
return fmt.Errorf("failed to get posture checks for peer %s: %v", peerId, err)
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get extra settings: %v", err)
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(peerId)
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
update := grpc.ToSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.UpdateAccountPeers(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.UpdateAccountPeers(ctx, accountID)
|
||||
})
|
||||
return
|
||||
}
|
||||
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peers, err := c.repo.GetAccountPeers(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
|
||||
c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: network.CurrentSerial(),
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
FirewallRules: []*proto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
DNSConfig: &proto.DNSConfig{
|
||||
ForwarderPort: dnsFwdPort,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
emptyMap := &types.NetworkMap{
|
||||
Network: network.Copy(),
|
||||
}
|
||||
return peer, emptyMap, nil, 0, nil
|
||||
}
|
||||
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
startPosture := time.Now()
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
networkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
c.enrichAccountFromHolder(account)
|
||||
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
||||
}
|
||||
|
||||
func (c *Controller) getPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
accountId string,
|
||||
peerId string,
|
||||
validatedPeers map[string]struct{},
|
||||
customZone nbdns.CustomZone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *types.NetworkMap {
|
||||
account := c.getAccountFromHolderOrInit(accountId)
|
||||
if account == nil {
|
||||
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
||||
return &types.NetworkMap{
|
||||
Network: &types.Network{},
|
||||
}
|
||||
}
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
c.enrichAccountFromHolder(account)
|
||||
return account.OnPeerAddedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
c.enrichAccountFromHolder(account)
|
||||
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
||||
account := c.getAccountFromHolder(accountId)
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
account.UpdatePeerInNetworkMapCache(peer)
|
||||
}
|
||||
|
||||
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
account.RecalculateNetworkMapCache(validatedPeers)
|
||||
c.updateAccountInHolder(account)
|
||||
}
|
||||
|
||||
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
||||
return err
|
||||
}
|
||||
c.recalculateNetworkMapCache(account, validatedPeers)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) experimentalNetworkMap(accountId string) bool {
|
||||
_, ok := c.expNewNetworkMapAIDs[accountId]
|
||||
return c.expNewNetworkMap || ok
|
||||
}
|
||||
|
||||
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
|
||||
a := c.holder.GetAccount(account.Id)
|
||||
if a == nil {
|
||||
c.holder.AddAccount(account)
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache = a.NetworkMapCache
|
||||
if account.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache.UpdateAccountPointer(account)
|
||||
c.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
|
||||
return c.holder.GetAccount(accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account {
|
||||
a := c.holder.GetAccount(accountID)
|
||||
if a != nil {
|
||||
return a
|
||||
}
|
||||
account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return account
|
||||
}
|
||||
|
||||
func (c *Controller) updateAccountInHolder(account *types.Account) {
|
||||
c.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
// GetDNSDomain returns the configured dnsDomain
|
||||
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
||||
if settings == nil {
|
||||
return c.dnsDomain
|
||||
}
|
||||
if settings.DNSDomain == "" {
|
||||
return c.dnsDomain
|
||||
}
|
||||
|
||||
return settings.DNSDomain
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||
func (c *Controller) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
|
||||
peerPostureChecks := make(map[string]*posture.Checks)
|
||||
|
||||
if len(account.PostureChecks) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return maps.Values(peerPostureChecks), nil
|
||||
}
|
||||
|
||||
func (c *Controller) StartWarmup(ctx context.Context) {
|
||||
var initialInterval int64
|
||||
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
|
||||
interval, err := strconv.Atoi(intervalStr)
|
||||
if err != nil {
|
||||
initialInterval = 1
|
||||
log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
|
||||
} else {
|
||||
initialInterval = int64(interval) * 10
|
||||
go func() {
|
||||
startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
|
||||
startupPeriod, err := strconv.Atoi(startupPeriodStr)
|
||||
if err != nil {
|
||||
startupPeriod = 1
|
||||
log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
|
||||
}
|
||||
time.Sleep(time.Duration(startupPeriod) * time.Second)
|
||||
c.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
|
||||
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
|
||||
}()
|
||||
}
|
||||
c.updateAccountPeersBufferInterval.Store(int64(time.Duration(initialInterval) * time.Millisecond))
|
||||
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
|
||||
|
||||
}
|
||||
|
||||
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
|
||||
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
|
||||
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
|
||||
if len(peers) == 0 {
|
||||
return int64(network_map.OldForwarderPort)
|
||||
}
|
||||
|
||||
reqVer := semver.Canonical(requiredVersion)
|
||||
|
||||
// Check if all peers have the required version or newer
|
||||
for _, peer := range peers {
|
||||
|
||||
// Development version is always supported
|
||||
if peer.Meta.WtVersion == "development" {
|
||||
continue
|
||||
}
|
||||
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
|
||||
if peerVersion == "" {
|
||||
// If any peer doesn't have version info, return 0
|
||||
return int64(network_map.OldForwarderPort)
|
||||
}
|
||||
|
||||
// Compare versions
|
||||
if semver.Compare(peerVersion, reqVer) < 0 {
|
||||
return int64(network_map.OldForwarderPort)
|
||||
}
|
||||
}
|
||||
|
||||
// All peers have the required version or newer
|
||||
return int64(network_map.DnsForwarderPort)
|
||||
}
|
||||
|
||||
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
|
||||
func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
|
||||
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !isInGroup {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||
postureCheck := account.GetPostureChecks(sourcePostureCheckID)
|
||||
if postureCheck == nil {
|
||||
return errors.New("failed to add policy posture checks: posture checks not found")
|
||||
}
|
||||
peerPostureChecks[sourcePostureCheckID] = postureCheck
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
||||
func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group := account.GetGroup(sourceGroup)
|
||||
if group == nil {
|
||||
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
|
||||
}
|
||||
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) {
|
||||
c.UpdatePeerInNetworkMapCache(accountId, peer)
|
||||
_ = c.bufferSendUpdateAccountPeers(context.Background(), accountId)
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error {
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error {
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
|
||||
account, err := c.repo.GetAccountByPeerID(ctx, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
groups := make(map[string][]string)
|
||||
for groupID, group := range account.Groups {
|
||||
groups[groupID] = group.Peers
|
||||
}
|
||||
|
||||
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
networkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
return networkMap, nil
|
||||
}
|
||||
|
||||
func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) {
|
||||
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||
}
|
||||
|
||||
func (c *Controller) IsConnected(peerID string) bool {
|
||||
return c.peersUpdateManager.HasChannel(peerID)
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func TestComputeForwarderPort(t *testing.T) {
|
||||
// Test with empty peers list
|
||||
peers := []*nbpeer.Peer{}
|
||||
result := computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(network_map.OldForwarderPort) {
|
||||
t.Errorf("Expected %d for empty peers list, got %d", network_map.OldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have old versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.57.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.26.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(network_map.OldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with old versions, got %d", network_map.OldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have new versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(network_map.DnsForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with new versions, got %d", network_map.DnsForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have mixed versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.57.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(network_map.OldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with mixed versions, got %d", network_map.OldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have empty version
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(network_map.OldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with empty version, got %d", network_map.OldForwarderPort, result)
|
||||
}
|
||||
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "development",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result == int64(network_map.OldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with dev version, got %d", network_map.DnsForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have unknown version string
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "unknown",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(network_map.OldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBufferUpdateAccountPeers(t *testing.T) {
|
||||
const (
|
||||
peersCount = 1000
|
||||
updateAccountInterval = 50 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
|
||||
uapLastRun, dpLastRun atomic.Int64
|
||||
|
||||
totalNewRuns, totalOldRuns int
|
||||
)
|
||||
|
||||
uap := func(ctx context.Context, accountID string) {
|
||||
updatePeersDeleted.Store(deletedPeers.Load())
|
||||
updatePeersRuns.Add(1)
|
||||
uapLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Run("new approach", func(t *testing.T) {
|
||||
updatePeersRuns.Store(0)
|
||||
updatePeersDeleted.Store(0)
|
||||
deletedPeers.Store(0)
|
||||
|
||||
var mustore sync.Map
|
||||
bufupd := func(ctx context.Context, accountID string) {
|
||||
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := mu.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
uap(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
b.next = time.AfterFunc(updateAccountInterval, func() {
|
||||
uap(ctx, accountID)
|
||||
})
|
||||
}()
|
||||
}
|
||||
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
||||
deletedPeers.Add(1)
|
||||
dpLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
bufupd(ctx, accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
am := mock_server.MockAccountManager{
|
||||
UpdateAccountPeersFunc: uap,
|
||||
BufferUpdateAccountPeersFunc: bufupd,
|
||||
DeletePeerFunc: dp,
|
||||
}
|
||||
empty := ""
|
||||
for range peersCount {
|
||||
//nolint
|
||||
am.DeletePeer(context.Background(), empty, empty, empty)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
||||
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
||||
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
||||
|
||||
totalNewRuns = int(updatePeersRuns.Load())
|
||||
})
|
||||
|
||||
t.Run("old approach", func(t *testing.T) {
|
||||
updatePeersRuns.Store(0)
|
||||
updatePeersDeleted.Store(0)
|
||||
deletedPeers.Store(0)
|
||||
|
||||
var mustore sync.Map
|
||||
bufupd := func(ctx context.Context, accountID string) {
|
||||
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
|
||||
b := mu.(*sync.Mutex)
|
||||
|
||||
if !b.TryLock() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(updateAccountInterval)
|
||||
b.Unlock()
|
||||
uap(ctx, accountID)
|
||||
}()
|
||||
}
|
||||
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
||||
deletedPeers.Add(1)
|
||||
dpLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
bufupd(ctx, accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
am := mock_server.MockAccountManager{
|
||||
UpdateAccountPeersFunc: uap,
|
||||
BufferUpdateAccountPeersFunc: bufupd,
|
||||
DeletePeerFunc: dp,
|
||||
}
|
||||
empty := ""
|
||||
for range peersCount {
|
||||
//nolint
|
||||
am.DeletePeer(context.Background(), empty, empty, empty)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
||||
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
||||
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
||||
|
||||
totalOldRuns = int(updatePeersRuns.Load())
|
||||
})
|
||||
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
type metrics struct {
|
||||
*telemetry.UpdateChannelMetrics
|
||||
}
|
||||
|
||||
func newMetrics(updateChannelMetrics *telemetry.UpdateChannelMetrics) (*metrics, error) {
|
||||
return &metrics{
|
||||
updateChannelMetrics,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error)
|
||||
GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error)
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
var _ Repository = (*repository)(nil)
|
||||
|
||||
func newRepository(s store.Store) Repository {
|
||||
return &repository{
|
||||
store: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *repository) GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) {
|
||||
return r.store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) {
|
||||
return r.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
}
|
||||
|
||||
func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
|
||||
return r.store.GetAccountByPeerID(ctx, peerID)
|
||||
}
|
||||
39
management/internals/controllers/network_map/interface.go
Normal file
39
management/internals/controllers/network_map/interface.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package network_map
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
|
||||
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
|
||||
|
||||
DnsForwarderPort = nbdns.ForwarderServerPort
|
||||
OldForwarderPort = nbdns.ForwarderClientPort
|
||||
DnsForwarderPortMinVersion = "v0.59.0"
|
||||
)
|
||||
|
||||
type Controller interface {
|
||||
UpdateAccountPeers(ctx context.Context, accountID string) error
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
|
||||
DeletePeer(ctx context.Context, accountId string, peerId string) error
|
||||
|
||||
OnPeerUpdated(accountId string, peer *nbpeer.Peer)
|
||||
OnPeerAdded(ctx context.Context, accountID string, peerID string) error
|
||||
OnPeerDeleted(ctx context.Context, accountID string, peerID string) error
|
||||
DisconnectPeers(ctx context.Context, peerIDs []string)
|
||||
IsConnected(peerID string) bool
|
||||
}
|
||||
225
management/internals/controllers/network_map/interface_mock.go
Normal file
225
management/internals/controllers/network_map/interface_mock.go
Normal file
@@ -0,0 +1,225 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./interface.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
//
|
||||
|
||||
// Package network_map is a generated GoMock package.
|
||||
package network_map
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||
posture "github.com/netbirdio/netbird/management/server/posture"
|
||||
types "github.com/netbirdio/netbird/management/server/types"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
type MockController struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockControllerMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockControllerMockRecorder is the mock recorder for MockController.
|
||||
type MockControllerMockRecorder struct {
|
||||
mock *MockController
|
||||
}
|
||||
|
||||
// NewMockController creates a new mock instance.
|
||||
func NewMockController(ctrl *gomock.Controller) *MockController {
|
||||
mock := &MockController{ctrl: ctrl}
|
||||
mock.recorder = &MockControllerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockController) EXPECT() *MockControllerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// BufferUpdateAccountPeers mocks base method.
|
||||
func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers.
|
||||
func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
|
||||
}
|
||||
|
||||
// DeletePeer mocks base method.
|
||||
func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeletePeer indicates an expected call of DeletePeer.
|
||||
func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId)
|
||||
}
|
||||
|
||||
// DisconnectPeers mocks base method.
|
||||
func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs)
|
||||
}
|
||||
|
||||
// DisconnectPeers indicates an expected call of DisconnectPeers.
|
||||
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs)
|
||||
}
|
||||
|
||||
// GetDNSDomain mocks base method.
|
||||
func (m *MockController) GetDNSDomain(settings *types.Settings) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetDNSDomain", settings)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetDNSDomain indicates an expected call of GetDNSDomain.
|
||||
func (mr *MockControllerMockRecorder) GetDNSDomain(settings any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSDomain", reflect.TypeOf((*MockController)(nil).GetDNSDomain), settings)
|
||||
}
|
||||
|
||||
// GetNetworkMap mocks base method.
|
||||
func (m *MockController) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetNetworkMap", ctx, peerID)
|
||||
ret0, _ := ret[0].(*types.NetworkMap)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetNetworkMap indicates an expected call of GetNetworkMap.
|
||||
func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkMap", reflect.TypeOf((*MockController)(nil).GetNetworkMap), ctx, peerID)
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithMap mocks base method.
|
||||
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMap)
|
||||
ret2, _ := ret[2].([]*posture.Checks)
|
||||
ret3, _ := ret[3].(int64)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
|
||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
||||
}
|
||||
|
||||
// IsConnected mocks base method.
|
||||
func (m *MockController) IsConnected(peerID string) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsConnected", peerID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IsConnected indicates an expected call of IsConnected.
|
||||
func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID)
|
||||
}
|
||||
|
||||
// OnPeerAdded mocks base method.
|
||||
func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeerAdded indicates an expected call of OnPeerAdded.
|
||||
func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// OnPeerDeleted mocks base method.
|
||||
func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeerDeleted indicates an expected call of OnPeerDeleted.
|
||||
func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// OnPeerUpdated mocks base method.
|
||||
func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnPeerUpdated", accountId, peer)
|
||||
}
|
||||
|
||||
// OnPeerUpdated indicates an expected call of OnPeerUpdated.
|
||||
func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer)
|
||||
}
|
||||
|
||||
// StartWarmup mocks base method.
|
||||
func (m *MockController) StartWarmup(arg0 context.Context) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "StartWarmup", arg0)
|
||||
}
|
||||
|
||||
// StartWarmup indicates an expected call of StartWarmup.
|
||||
func (mr *MockControllerMockRecorder) StartWarmup(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWarmup", reflect.TypeOf((*MockController)(nil).StartWarmup), arg0)
|
||||
}
|
||||
|
||||
// UpdateAccountPeer mocks base method.
|
||||
func (m *MockController) UpdateAccountPeer(ctx context.Context, accountId, peerId string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAccountPeer", ctx, accountId, peerId)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateAccountPeer indicates an expected call of UpdateAccountPeer.
|
||||
func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeer", reflect.TypeOf((*MockController)(nil).UpdateAccountPeer), ctx, accountId, peerId)
|
||||
}
|
||||
|
||||
// UpdateAccountPeers mocks base method.
|
||||
func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateAccountPeers indicates an expected call of UpdateAccountPeers.
|
||||
func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID)
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
package network_map
|
||||
@@ -0,0 +1,13 @@
|
||||
package network_map
|
||||
|
||||
import "context"
|
||||
|
||||
type PeersUpdateManager interface {
|
||||
SendUpdate(ctx context.Context, peerID string, update *UpdateMessage)
|
||||
CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage
|
||||
CloseChannel(ctx context.Context, peerID string)
|
||||
CountStreams() int
|
||||
HasChannel(peerID string) bool
|
||||
CloseChannels(ctx context.Context, peerIDs []string)
|
||||
GetAllConnectedPeers() map[string]struct{}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package update_channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,38 +7,34 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const channelBufferSize = 100
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
NetworkMap *types.NetworkMap
|
||||
}
|
||||
|
||||
type PeersUpdateManager struct {
|
||||
// peerChannels is an update channel indexed by Peer.ID
|
||||
peerChannels map[string]chan *UpdateMessage
|
||||
peerChannels map[string]chan *network_map.UpdateMessage
|
||||
// channelsMux keeps the mutex to access peerChannels
|
||||
channelsMux *sync.RWMutex
|
||||
// metrics provides method to collect application metrics
|
||||
metrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
var _ network_map.PeersUpdateManager = (*PeersUpdateManager)(nil)
|
||||
|
||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||
return &PeersUpdateManager{
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
peerChannels: make(map[string]chan *network_map.UpdateMessage),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
// SendUpdate sends update message to the peer's channel
|
||||
func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) {
|
||||
func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *network_map.UpdateMessage) {
|
||||
start := time.Now()
|
||||
var found, dropped bool
|
||||
|
||||
@@ -66,7 +62,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
}
|
||||
|
||||
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
||||
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage {
|
||||
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *network_map.UpdateMessage {
|
||||
start := time.Now()
|
||||
|
||||
closed := false
|
||||
@@ -85,7 +81,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
|
||||
close(channel)
|
||||
}
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||
channel := make(chan *network_map.UpdateMessage, channelBufferSize)
|
||||
p.peerChannels[peerID] = channel
|
||||
|
||||
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
|
||||
@@ -176,3 +172,9 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
func (p *PeersUpdateManager) CountStreams() int {
|
||||
p.channelsMux.RLock()
|
||||
defer p.channelsMux.RUnlock()
|
||||
return len(p.peerChannels)
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
package server
|
||||
package update_channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -24,7 +25,7 @@ func TestCreateChannel(t *testing.T) {
|
||||
func TestSendUpdate(t *testing.T) {
|
||||
peer := "test-sendupdate"
|
||||
peersUpdater := NewPeersUpdateManager(nil)
|
||||
update1 := &UpdateMessage{Update: &proto.SyncResponse{
|
||||
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: 0,
|
||||
},
|
||||
@@ -44,7 +45,7 @@ func TestSendUpdate(t *testing.T) {
|
||||
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
||||
}
|
||||
|
||||
update2 := &UpdateMessage{Update: &proto.SyncResponse{
|
||||
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: 10,
|
||||
},
|
||||
@@ -0,0 +1,9 @@
|
||||
package network_map
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
@@ -93,7 +93,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager())
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
}
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator())
|
||||
srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,10 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
@@ -14,9 +18,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
)
|
||||
|
||||
func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
|
||||
return Create(s, func() *server.PeersUpdateManager {
|
||||
return server.NewPeersUpdateManager(s.Metrics())
|
||||
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
|
||||
return Create(s, func() *update_channel.PeersUpdateManager {
|
||||
return update_channel.NewPeersUpdateManager(s.Metrics())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -40,9 +44,9 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager {
|
||||
return Create(s, func() *server.TimeBasedAuthSecretsManager {
|
||||
return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
|
||||
func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager {
|
||||
return Create(s, func() *grpc.TimeBasedAuthSecretsManager {
|
||||
return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -63,3 +67,15 @@ func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
||||
return manager.NewEphemeralManager(s.Store(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) NetworkMapController() network_map.Controller {
|
||||
return Create(s, func() *nmapcontroller.Controller {
|
||||
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
|
||||
return Create(s, func() *server.AccountRequestBuffer {
|
||||
return server.NewAccountRequestBuffer(context.Background(), s.Store())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -66,8 +66,7 @@ func (s *BaseServer) PeersManager() peers.Manager {
|
||||
|
||||
func (s *BaseServer) AccountManager() account.Manager {
|
||||
return Create(s, func() account.Manager {
|
||||
accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain,
|
||||
s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
|
||||
accountManager, err := server.BuildManager(context.Background(), s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create account manager: %v", err)
|
||||
}
|
||||
|
||||
352
management/internals/shared/grpc/conversion.go
Normal file
352
management/internals/shared/grpc/conversion.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var stuns []*proto.HostConfig
|
||||
for _, stun := range config.Stuns {
|
||||
stuns = append(stuns, &proto.HostConfig{
|
||||
Uri: stun.URI,
|
||||
Protocol: ToResponseProto(stun.Proto),
|
||||
})
|
||||
}
|
||||
|
||||
var turns []*proto.ProtectedHostConfig
|
||||
if config.TURNConfig != nil {
|
||||
for _, turn := range config.TURNConfig.Turns {
|
||||
var username string
|
||||
var password string
|
||||
if turnCredentials != nil {
|
||||
username = turnCredentials.Payload
|
||||
password = turnCredentials.Signature
|
||||
} else {
|
||||
username = turn.Username
|
||||
password = turn.Password
|
||||
}
|
||||
turns = append(turns, &proto.ProtectedHostConfig{
|
||||
HostConfig: &proto.HostConfig{
|
||||
Uri: turn.URI,
|
||||
Protocol: ToResponseProto(turn.Proto),
|
||||
},
|
||||
User: username,
|
||||
Password: password,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var relayCfg *proto.RelayConfig
|
||||
if config.Relay != nil && len(config.Relay.Addresses) > 0 {
|
||||
relayCfg = &proto.RelayConfig{
|
||||
Urls: config.Relay.Addresses,
|
||||
}
|
||||
|
||||
if relayToken != nil {
|
||||
relayCfg.TokenPayload = relayToken.Payload
|
||||
relayCfg.TokenSignature = relayToken.Signature
|
||||
}
|
||||
}
|
||||
|
||||
var signalCfg *proto.HostConfig
|
||||
if config.Signal != nil {
|
||||
signalCfg = &proto.HostConfig{
|
||||
Uri: config.Signal.URI,
|
||||
Protocol: ToResponseProto(config.Signal.Proto),
|
||||
}
|
||||
}
|
||||
|
||||
nbConfig := &proto.NetbirdConfig{
|
||||
Stuns: stuns,
|
||||
Turns: turns,
|
||||
Signal: signalCfg,
|
||||
Relay: relayCfg,
|
||||
}
|
||||
|
||||
return nbConfig
|
||||
}
|
||||
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig {
|
||||
netmask, _ := network.Net.Mask.Size()
|
||||
fqdn := peer.FQDN(dnsName)
|
||||
return &proto.PeerConfig{
|
||||
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
|
||||
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
|
||||
Fqdn: fqdn,
|
||||
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
|
||||
LazyConnectionEnabled: settings.LazyConnectionEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
response.NetbirdConfig = extendedConfig
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
|
||||
response.RemotePeers = remotePeers
|
||||
response.NetworkMap.RemotePeers = remotePeers
|
||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||
response.NetworkMap.FirewallRules = firewallRules
|
||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
||||
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||
|
||||
if networkMap.ForwardingRules != nil {
|
||||
forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules))
|
||||
for _, rule := range networkMap.ForwardingRules {
|
||||
forwardingRules = append(forwardingRules, rule.ToProto())
|
||||
}
|
||||
response.NetworkMap.ForwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
AgentVersion: rPeer.Meta.WtVersion,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
ForwarderPort: forwardPort,
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
cacheKey := nsGroup.ID
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
} else {
|
||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
switch configProto {
|
||||
case nbconfig.UDP:
|
||||
return proto.HostConfig_UDP
|
||||
case nbconfig.DTLS:
|
||||
return proto.HostConfig_DTLS
|
||||
case nbconfig.HTTP:
|
||||
return proto.HostConfig_HTTP
|
||||
case nbconfig.HTTPS:
|
||||
return proto.HostConfig_HTTPS
|
||||
case nbconfig.TCP:
|
||||
return proto.HostConfig_TCP
|
||||
default:
|
||||
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
||||
}
|
||||
}
|
||||
|
||||
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
|
||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
||||
}
|
||||
return protoRoutes
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
return &proto.Route{
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
SkipAutoApply: route.SkipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||
func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
fwRule := &proto.FirewallRule{
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
PeerIP: rule.PeerIP,
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
|
||||
if shouldUsePortRange(fwRule) {
|
||||
fwRule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
|
||||
result[i] = fwRule
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getProtoDirection converts the direction to proto.RuleDirection.
|
||||
func getProtoDirection(direction int) proto.RuleDirection {
|
||||
if direction == types.FirewallRuleDirectionOUT {
|
||||
return proto.RuleDirection_OUT
|
||||
}
|
||||
return proto.RuleDirection_IN
|
||||
}
|
||||
|
||||
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
result[i] = &proto.RouteFirewallRule{
|
||||
SourceRanges: rule.SourceRanges,
|
||||
Action: getProtoAction(rule.Action),
|
||||
Destination: rule.Destination,
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
PortInfo: getProtoPortInfo(rule),
|
||||
IsDynamic: rule.IsDynamic,
|
||||
Domains: rule.Domains.ToPunycodeList(),
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
RouteID: string(rule.RouteID),
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// getProtoAction converts the action to proto.RuleAction.
|
||||
func getProtoAction(action string) proto.RuleAction {
|
||||
if action == string(types.PolicyTrafficActionDrop) {
|
||||
return proto.RuleAction_DROP
|
||||
}
|
||||
return proto.RuleAction_ACCEPT
|
||||
}
|
||||
|
||||
// getProtoProtocol converts the protocol to proto.RuleProtocol.
|
||||
func getProtoProtocol(protocol string) proto.RuleProtocol {
|
||||
switch types.PolicyRuleProtocolType(protocol) {
|
||||
case types.PolicyRuleProtocolALL:
|
||||
return proto.RuleProtocol_ALL
|
||||
case types.PolicyRuleProtocolTCP:
|
||||
return proto.RuleProtocol_TCP
|
||||
case types.PolicyRuleProtocolUDP:
|
||||
return proto.RuleProtocol_UDP
|
||||
case types.PolicyRuleProtocolICMP:
|
||||
return proto.RuleProtocol_ICMP
|
||||
default:
|
||||
return proto.RuleProtocol_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
// getProtoPortInfo converts the port info to proto.PortInfo.
|
||||
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
var portInfo proto.PortInfo
|
||||
if rule.Port != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
||||
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Range_{
|
||||
Range: &proto.PortInfo_Range{
|
||||
Start: uint32(portRange.Start),
|
||||
End: uint32(portRange.End),
|
||||
},
|
||||
}
|
||||
}
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
150
management/internals/shared/grpc/conversion_test.go
Normal file
150
management/internals/shared/grpc/conversion_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
)
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
var cache cache.DNSConfigCache
|
||||
|
||||
// Create two different configs
|
||||
config1 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.com",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group1",
|
||||
Name: "Group 1",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config2 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.org",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group2",
|
||||
Name: "Group 2",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
|
||||
}
|
||||
|
||||
// Verify that result2 is different from result1 and result3
|
||||
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
|
||||
t.Errorf("Results should be different for different inputs")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group1'")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group2"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group2'")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
sizes := []int{10, 100, 1000}
|
||||
|
||||
for _, size := range sizes {
|
||||
testData := generateTestData(size)
|
||||
|
||||
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
|
||||
cache := &cache.DNSConfigCache{}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &cache.DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateTestData(size int) nbdns.Config {
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: make([]nbdns.CustomZone, size),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, size),
|
||||
}
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
config.CustomZones[i] = nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("domain%d.com", i),
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: fmt.Sprintf("record%d", i),
|
||||
Type: 1,
|
||||
Class: "IN",
|
||||
TTL: 3600,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
||||
ID: fmt.Sprintf("group%d", i),
|
||||
Primary: i == 0,
|
||||
Domains: []string{fmt.Sprintf("domain%d.com", i)},
|
||||
SearchDomainsEnabled: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
Port: 53,
|
||||
NSType: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
@@ -20,7 +22,7 @@ import (
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||
|
||||
@@ -44,15 +46,18 @@ import (
|
||||
const (
|
||||
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
|
||||
envBlockPeers = "NB_BLOCK_SAME_PEERS"
|
||||
envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS"
|
||||
|
||||
defaultSyncLim = 1000
|
||||
)
|
||||
|
||||
// GRPCServer an instance of a Management gRPC API server
|
||||
type GRPCServer struct {
|
||||
// Server an instance of a Management gRPC API server
|
||||
type Server struct {
|
||||
accountManager account.Manager
|
||||
settingsManager settings.Manager
|
||||
wgKey wgtypes.Key
|
||||
proto.UnimplementedManagementServiceServer
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
peersUpdateManager network_map.PeersUpdateManager
|
||||
config *nbconfig.Config
|
||||
secretsManager SecretsManager
|
||||
appMetrics telemetry.AppMetrics
|
||||
@@ -63,21 +68,28 @@ type GRPCServer struct {
|
||||
logBlockedPeers bool
|
||||
blockPeersWithSameConfig bool
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
loginFilter *loginFilter
|
||||
|
||||
networkMapController network_map.Controller
|
||||
|
||||
syncSem atomic.Int32
|
||||
syncLim int32
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
func NewServer(
|
||||
ctx context.Context,
|
||||
config *nbconfig.Config,
|
||||
accountManager account.Manager,
|
||||
settingsManager settings.Manager,
|
||||
peersUpdateManager *PeersUpdateManager,
|
||||
peersUpdateManager network_map.PeersUpdateManager,
|
||||
secretsManager SecretsManager,
|
||||
appMetrics telemetry.AppMetrics,
|
||||
ephemeralManager ephemeral.Manager,
|
||||
authManager auth.Manager,
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
) (*GRPCServer, error) {
|
||||
networkMapController network_map.Controller,
|
||||
) (*Server, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -86,7 +98,7 @@ func NewServer(
|
||||
if appMetrics != nil {
|
||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
||||
return int64(len(peersUpdateManager.peerChannels))
|
||||
return int64(peersUpdateManager.CountStreams())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -96,7 +108,18 @@ func NewServer(
|
||||
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
|
||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||
|
||||
return &GRPCServer{
|
||||
syncLim := int32(defaultSyncLim)
|
||||
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
||||
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
||||
if err != nil {
|
||||
log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim)
|
||||
} else {
|
||||
//nolint:gosec
|
||||
syncLim = int32(syncLimParsed)
|
||||
}
|
||||
}
|
||||
|
||||
return &Server{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
@@ -110,10 +133,15 @@ func NewServer(
|
||||
logBlockedPeers: logBlockedPeers,
|
||||
blockPeersWithSameConfig: blockPeersWithSameConfig,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
networkMapController: networkMapController,
|
||||
|
||||
loginFilter: newLoginFilter(),
|
||||
|
||||
syncLim: syncLim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
|
||||
func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
|
||||
ip := ""
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if ok {
|
||||
@@ -150,7 +178,12 @@ func getRealIP(ctx context.Context) net.IP {
|
||||
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
if s.syncSem.Load() >= s.syncLim {
|
||||
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
||||
}
|
||||
s.syncSem.Add(1)
|
||||
|
||||
reqStart := time.Now()
|
||||
|
||||
ctx := srv.Context()
|
||||
@@ -158,13 +191,14 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
syncReq := &proto.SyncRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, syncReq)
|
||||
if err != nil {
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
realIP := getRealIP(ctx)
|
||||
sRealIP := realIP.String()
|
||||
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
|
||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||
}
|
||||
@@ -172,6 +206,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
|
||||
}
|
||||
if s.blockPeersWithSameConfig {
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
|
||||
}
|
||||
}
|
||||
@@ -183,42 +218,54 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
|
||||
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
|
||||
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
|
||||
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||
s.syncSem.Add(-1)
|
||||
return status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||
}
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
start := time.Now()
|
||||
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
|
||||
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||
|
||||
if syncReq.GetMeta() == nil {
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
|
||||
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -235,13 +282,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||
for {
|
||||
select {
|
||||
@@ -275,7 +322,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
@@ -293,7 +340,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
@@ -308,7 +355,7 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p
|
||||
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
|
||||
}
|
||||
|
||||
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
if s.authManager == nil {
|
||||
return "", status.Errorf(codes.Internal, "missing auth manager")
|
||||
}
|
||||
@@ -342,7 +389,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
||||
return userAuth.UserId, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
func (s *Server) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
@@ -450,7 +497,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
|
||||
func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
|
||||
@@ -469,7 +516,7 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
|
||||
// In case it is, the login is successful
|
||||
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
||||
// In case of the successful registration login is also successful
|
||||
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
reqStart := time.Now()
|
||||
realIP := getRealIP(ctx)
|
||||
sRealIP := realIP.String()
|
||||
@@ -483,7 +530,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
|
||||
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
|
||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.logBlockedPeers {
|
||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
||||
}
|
||||
@@ -509,10 +556,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
|
||||
|
||||
defer func() {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
||||
}
|
||||
took := time.Since(reqStart)
|
||||
if took > 7*time.Second {
|
||||
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
|
||||
}
|
||||
}()
|
||||
|
||||
if loginReq.GetMeta() == nil {
|
||||
@@ -546,9 +599,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
|
||||
|
||||
// if the login request contains setup key then it is a registration request
|
||||
if loginReq.GetSetupKey() != "" {
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart))
|
||||
}
|
||||
|
||||
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
||||
@@ -557,6 +613,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
|
||||
@@ -569,7 +627,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
||||
var relayToken *Token
|
||||
var err error
|
||||
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
|
||||
@@ -588,7 +646,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings),
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
|
||||
@@ -600,7 +658,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
|
||||
//
|
||||
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already
|
||||
// registered or if it uses a setup key to register.
|
||||
func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
|
||||
func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
|
||||
userID := ""
|
||||
if loginReq.GetJwtToken() != "" {
|
||||
var err error
|
||||
@@ -620,166 +678,13 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
switch configProto {
|
||||
case nbconfig.UDP:
|
||||
return proto.HostConfig_UDP
|
||||
case nbconfig.DTLS:
|
||||
return proto.HostConfig_DTLS
|
||||
case nbconfig.HTTP:
|
||||
return proto.HostConfig_HTTP
|
||||
case nbconfig.HTTPS:
|
||||
return proto.HostConfig_HTTPS
|
||||
case nbconfig.TCP:
|
||||
return proto.HostConfig_TCP
|
||||
default:
|
||||
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
||||
}
|
||||
}
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var stuns []*proto.HostConfig
|
||||
for _, stun := range config.Stuns {
|
||||
stuns = append(stuns, &proto.HostConfig{
|
||||
Uri: stun.URI,
|
||||
Protocol: ToResponseProto(stun.Proto),
|
||||
})
|
||||
}
|
||||
|
||||
var turns []*proto.ProtectedHostConfig
|
||||
if config.TURNConfig != nil {
|
||||
for _, turn := range config.TURNConfig.Turns {
|
||||
var username string
|
||||
var password string
|
||||
if turnCredentials != nil {
|
||||
username = turnCredentials.Payload
|
||||
password = turnCredentials.Signature
|
||||
} else {
|
||||
username = turn.Username
|
||||
password = turn.Password
|
||||
}
|
||||
turns = append(turns, &proto.ProtectedHostConfig{
|
||||
HostConfig: &proto.HostConfig{
|
||||
Uri: turn.URI,
|
||||
Protocol: ToResponseProto(turn.Proto),
|
||||
},
|
||||
User: username,
|
||||
Password: password,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var relayCfg *proto.RelayConfig
|
||||
if config.Relay != nil && len(config.Relay.Addresses) > 0 {
|
||||
relayCfg = &proto.RelayConfig{
|
||||
Urls: config.Relay.Addresses,
|
||||
}
|
||||
|
||||
if relayToken != nil {
|
||||
relayCfg.TokenPayload = relayToken.Payload
|
||||
relayCfg.TokenSignature = relayToken.Signature
|
||||
}
|
||||
}
|
||||
|
||||
var signalCfg *proto.HostConfig
|
||||
if config.Signal != nil {
|
||||
signalCfg = &proto.HostConfig{
|
||||
Uri: config.Signal.URI,
|
||||
Protocol: ToResponseProto(config.Signal.Proto),
|
||||
}
|
||||
}
|
||||
|
||||
nbConfig := &proto.NetbirdConfig{
|
||||
Stuns: stuns,
|
||||
Turns: turns,
|
||||
Signal: signalCfg,
|
||||
Relay: relayCfg,
|
||||
}
|
||||
|
||||
return nbConfig
|
||||
}
|
||||
|
||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig {
|
||||
netmask, _ := network.Net.Mask.Size()
|
||||
fqdn := peer.FQDN(dnsName)
|
||||
return &proto.PeerConfig{
|
||||
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
|
||||
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
|
||||
Fqdn: fqdn,
|
||||
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
|
||||
LazyConnectionEnabled: settings.LazyConnectionEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
response.NetbirdConfig = extendedConfig
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
|
||||
response.RemotePeers = remotePeers
|
||||
response.NetworkMap.RemotePeers = remotePeers
|
||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||
response.NetworkMap.FirewallRules = firewallRules
|
||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
||||
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||
|
||||
if networkMap.ForwardingRules != nil {
|
||||
forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules))
|
||||
for _, rule := range networkMap.ForwardingRules {
|
||||
forwardingRules = append(forwardingRules, rule.ToProto())
|
||||
}
|
||||
response.NetworkMap.ForwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
AgentVersion: rPeer.Meta.WtVersion,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// IsHealthy indicates whether the service is healthy
|
||||
func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
|
||||
func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error {
|
||||
var err error
|
||||
|
||||
var turnToken *Token
|
||||
@@ -803,29 +708,24 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID)
|
||||
peerGroups, err := s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, peer.AccountID, peer.ID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||
}
|
||||
|
||||
// Get all peers in the account for forwarder port computation
|
||||
allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account peers: %w", err)
|
||||
}
|
||||
dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion)
|
||||
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
plainResp := ToSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
sendStart := time.Now()
|
||||
err = srv.Send(&proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
|
||||
@@ -838,7 +738,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
||||
// GetDeviceAuthorizationFlow returns a device authorization flow information
|
||||
// This is used for initiating an Oauth 2 device authorization grant flow
|
||||
// which will be used by our clients to Login
|
||||
func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -896,7 +796,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
|
||||
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
|
||||
// This is used for initiating an Oauth 2 pkce authorization grant flow
|
||||
// which will be used by our clients to Login
|
||||
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -951,7 +851,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
|
||||
|
||||
// SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected,
|
||||
// peer's under the same account of any updates.
|
||||
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
realIP := getRealIP(ctx)
|
||||
log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
@@ -976,7 +876,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
|
||||
start := time.Now()
|
||||
|
||||
106
management/internals/shared/grpc/server_test.go
Normal file
106
management/internals/shared/grpc/server_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
testingServerKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
}
|
||||
|
||||
testingClientKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputFlow *config.DeviceAuthorizationFlow
|
||||
expectedFlow *mgmtProto.DeviceAuthorizationFlow
|
||||
expectedErrFunc require.ErrorAssertionFunc
|
||||
expectedErrMSG string
|
||||
expectedComparisonFunc require.ComparisonAssertionFunc
|
||||
expectedComparisonMSG string
|
||||
}{
|
||||
{
|
||||
name: "Testing No Device Flow Config",
|
||||
inputFlow: nil,
|
||||
expectedErrFunc: require.Error,
|
||||
expectedErrMSG: "should return error",
|
||||
},
|
||||
{
|
||||
name: "Testing Invalid Device Flow Provider Config",
|
||||
inputFlow: &config.DeviceAuthorizationFlow{
|
||||
Provider: "NoNe",
|
||||
ProviderConfig: config.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedErrFunc: require.Error,
|
||||
expectedErrMSG: "should return error",
|
||||
},
|
||||
{
|
||||
name: "Testing Full Device Flow Config",
|
||||
inputFlow: &config.DeviceAuthorizationFlow{
|
||||
Provider: "hosted",
|
||||
ProviderConfig: config.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
|
||||
Provider: 0,
|
||||
ProviderConfig: &mgmtProto.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedErrFunc: require.NoError,
|
||||
expectedErrMSG: "should not return error",
|
||||
expectedComparisonFunc: require.Equal,
|
||||
expectedComparisonMSG: "should match",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
mgmtServer := &Server{
|
||||
wgKey: testingServerKey,
|
||||
config: &config.Config{
|
||||
DeviceAuthorizationFlow: testCase.inputFlow,
|
||||
},
|
||||
}
|
||||
|
||||
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
|
||||
|
||||
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
|
||||
require.NoError(t, err, "should be able to encrypt message")
|
||||
|
||||
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
|
||||
context.TODO(),
|
||||
&mgmtProto.EncryptedMessage{
|
||||
WgPubKey: testingClientKey.PublicKey().String(),
|
||||
Body: encryptedMSG,
|
||||
},
|
||||
)
|
||||
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
|
||||
if testCase.expectedComparisonFunc != nil {
|
||||
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
|
||||
|
||||
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
|
||||
require.NoError(t, err, "should be able to decrypt")
|
||||
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -37,7 +38,7 @@ type TimeBasedAuthSecretsManager struct {
|
||||
relayCfg *nbconfig.Relay
|
||||
turnHmacToken *auth.TimedHMAC
|
||||
relayHmacToken *authv2.Generator
|
||||
updateManager *PeersUpdateManager
|
||||
updateManager network_map.PeersUpdateManager
|
||||
settingsManager settings.Manager
|
||||
groupsManager groups.Manager
|
||||
turnCancelMap map[string]chan struct{}
|
||||
@@ -46,7 +47,7 @@ type TimeBasedAuthSecretsManager struct {
|
||||
|
||||
type Token auth.Token
|
||||
|
||||
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
|
||||
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
|
||||
mgr := &TimeBasedAuthSecretsManager{
|
||||
updateManager: updateManager,
|
||||
turnCfg: turnCfg,
|
||||
@@ -227,7 +228,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
|
||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
|
||||
@@ -251,7 +252,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
|
||||
m.extendNetbirdConfig(ctx, peerID, accountID, update)
|
||||
|
||||
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
|
||||
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
|
||||
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -31,7 +33,7 @@ var TurnTestHost = &config.Host{
|
||||
func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
ttl := util.Duration{Duration: time.Hour}
|
||||
secret := "some_secret"
|
||||
peersManager := NewPeersUpdateManager(nil)
|
||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
||||
|
||||
rc := &config.Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
@@ -80,7 +82,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
ttl := util.Duration{Duration: 2 * time.Second}
|
||||
secret := "some_secret"
|
||||
peersManager := NewPeersUpdateManager(nil)
|
||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
||||
peer := "some_peer"
|
||||
updateChannel := peersManager.CreateChannel(context.Background(), peer)
|
||||
|
||||
@@ -116,7 +118,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
t.Errorf("expecting peer to be present in the relay cancel map, got not present")
|
||||
}
|
||||
|
||||
var updates []*UpdateMessage
|
||||
var updates []*network_map.UpdateMessage
|
||||
|
||||
loop:
|
||||
for timeout := time.After(5 * time.Second); ; {
|
||||
@@ -185,7 +187,7 @@ loop:
|
||||
func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
||||
ttl := util.Duration{Duration: time.Hour}
|
||||
secret := "some_secret"
|
||||
peersManager := NewPeersUpdateManager(nil)
|
||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
||||
peer := "some_peer"
|
||||
|
||||
rc := &config.Relay{
|
||||
@@ -1,11 +1,19 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/cmd"
|
||||
"log"
|
||||
"net/http"
|
||||
// nolint:gosec
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/management/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
if err := cmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -11,10 +11,8 @@ import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
cacheStore "github.com/eko/gocache/lib/v4/store"
|
||||
@@ -26,6 +24,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
@@ -68,7 +67,7 @@ type DefaultAccountManager struct {
|
||||
cacheMux sync.Mutex
|
||||
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
|
||||
cacheLoading map[string]chan struct{}
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
networkMapController network_map.Controller
|
||||
idpManager idp.Manager
|
||||
cacheManager *nbcache.AccountUserDataCache
|
||||
externalCacheManager nbcache.UserDataCache
|
||||
@@ -88,8 +87,7 @@ type DefaultAccountManager struct {
|
||||
singleAccountMode bool
|
||||
// singleAccountModeDomain is a domain to use in singleAccountMode setup
|
||||
singleAccountModeDomain string
|
||||
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
||||
dnsDomain string
|
||||
|
||||
peerLoginExpiry Scheduler
|
||||
|
||||
peerInactivityExpiry Scheduler
|
||||
@@ -103,14 +101,11 @@ type DefaultAccountManager struct {
|
||||
|
||||
permissionsManager permissions.Manager
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
|
||||
loginFilter *loginFilter
|
||||
|
||||
disableDefaultPolicy bool
|
||||
}
|
||||
|
||||
var _ account.Manager = (*DefaultAccountManager)(nil)
|
||||
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
switch {
|
||||
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
|
||||
@@ -177,10 +172,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []
|
||||
func BuildManager(
|
||||
ctx context.Context,
|
||||
store store.Store,
|
||||
peersUpdateManager *PeersUpdateManager,
|
||||
networkMapController network_map.Controller,
|
||||
idpManager idp.Manager,
|
||||
singleAccountModeDomain string,
|
||||
dnsDomain string,
|
||||
eventStore activity.Store,
|
||||
geo geolocation.Geolocation,
|
||||
userDeleteFromIDPEnabled bool,
|
||||
@@ -199,12 +193,11 @@ func BuildManager(
|
||||
am := &DefaultAccountManager{
|
||||
Store: store,
|
||||
geo: geo,
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
networkMapController: networkMapController,
|
||||
idpManager: idpManager,
|
||||
ctx: context.Background(),
|
||||
cacheMux: sync.Mutex{},
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
dnsDomain: dnsDomain,
|
||||
eventStore: eventStore,
|
||||
peerLoginExpiry: NewDefaultScheduler(),
|
||||
peerInactivityExpiry: NewDefaultScheduler(),
|
||||
@@ -215,11 +208,10 @@ func BuildManager(
|
||||
proxyController: proxyController,
|
||||
settingsManager: settingsManager,
|
||||
permissionsManager: permissionsManager,
|
||||
loginFilter: newLoginFilter(),
|
||||
disableDefaultPolicy: disableDefaultPolicy,
|
||||
}
|
||||
|
||||
am.startWarmup(ctx)
|
||||
am.networkMapController.StartWarmup(ctx)
|
||||
|
||||
accountsCounter, err := store.GetAccountsCounter(ctx)
|
||||
if err != nil {
|
||||
@@ -267,32 +259,6 @@ func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) {
|
||||
am.ephemeralManager = em
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) startWarmup(ctx context.Context) {
|
||||
var initialInterval int64
|
||||
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
|
||||
interval, err := strconv.Atoi(intervalStr)
|
||||
if err != nil {
|
||||
initialInterval = 1
|
||||
log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
|
||||
} else {
|
||||
initialInterval = int64(interval) * 10
|
||||
go func() {
|
||||
startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
|
||||
startupPeriod, err := strconv.Atoi(startupPeriodStr)
|
||||
if err != nil {
|
||||
startupPeriod = 1
|
||||
log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
|
||||
}
|
||||
time.Sleep(time.Duration(startupPeriod) * time.Second)
|
||||
am.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
|
||||
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
|
||||
}()
|
||||
}
|
||||
am.updateAccountPeersBufferInterval.Store(initialInterval)
|
||||
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
|
||||
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager {
|
||||
return am.externalCacheManager
|
||||
}
|
||||
@@ -1636,19 +1602,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U
|
||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool {
|
||||
return am.loginFilter.allowLogin(wgPubKey, metahash)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start))
|
||||
}()
|
||||
|
||||
peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
|
||||
@@ -1656,10 +1613,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
metahash := metaHash(meta, realIP.String())
|
||||
am.loginFilter.addLogin(peerPubKey, metahash)
|
||||
|
||||
return peer, netMap, postureChecks, nil
|
||||
return peer, netMap, postureChecks, dnsfwdPort, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
|
||||
@@ -1676,41 +1630,19 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
|
||||
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
|
||||
if err != nil {
|
||||
return mapError(ctx, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
|
||||
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
|
||||
return am.peersUpdateManager.GetAllConnectedPeers(), nil
|
||||
}
|
||||
|
||||
// HasConnectedChannel returns true if peers has channel in update manager, otherwise false
|
||||
func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool {
|
||||
return am.peersUpdateManager.HasChannel(peerID)
|
||||
}
|
||||
|
||||
var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
|
||||
|
||||
func isDomainValid(domain string) bool {
|
||||
return invalidDomainRegexp.MatchString(domain)
|
||||
}
|
||||
|
||||
// GetDNSDomain returns the configured dnsDomain
|
||||
func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string {
|
||||
if settings == nil {
|
||||
return am.dnsDomain
|
||||
}
|
||||
if settings.DNSDomain == "" {
|
||||
return am.dnsDomain
|
||||
}
|
||||
|
||||
return settings.DNSDomain
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) {
|
||||
peers := []*nbpeer.Peer{}
|
||||
log.WithContext(ctx).Debugf("invalidating peers %v for account %s", peerIDs, accountID)
|
||||
@@ -2129,7 +2061,11 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateNetworkMap {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
am.networkMapController.OnPeerUpdated(peer.AccountID, peer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -2177,7 +2113,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account settings: %w", err)
|
||||
}
|
||||
dnsDomain := am.GetDNSDomain(settings)
|
||||
dnsDomain := am.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
eventMeta := peer.EventMeta(dnsDomain)
|
||||
oldIP := peer.IP.String()
|
||||
|
||||
@@ -89,7 +89,6 @@ type Manager interface {
|
||||
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)
|
||||
@@ -97,10 +96,8 @@ type Manager interface {
|
||||
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||
HasConnectedChannel(peerID string) bool
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
|
||||
@@ -110,7 +107,7 @@ type Manager interface {
|
||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
@@ -127,5 +124,4 @@ type Manager interface {
|
||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
SetEphemeralManager(em ephemeral.Manager)
|
||||
AllowSync(string, uint64) bool
|
||||
}
|
||||
|
||||
11
management/server/account/request_buffer.go
Normal file
11
management/server/account/request_buffer.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type RequestBuffer interface {
|
||||
GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error)
|
||||
}
|
||||
@@ -22,6 +22,9 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
@@ -406,7 +409,7 @@ func TestNewAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -603,7 +606,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
|
||||
@@ -644,7 +647,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain, false)
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
@@ -705,7 +708,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -731,7 +734,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -768,7 +771,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -805,7 +808,7 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string)
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -843,7 +846,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -924,7 +927,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||
DomainCategory: types.PublicCategory,
|
||||
}
|
||||
|
||||
am, err := createManager(b)
|
||||
am, _, err := createManager(b)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
return
|
||||
@@ -1016,7 +1019,7 @@ func genUsers(p string, n int) map[string]*types.User {
|
||||
}
|
||||
|
||||
func TestAccountManager_AddPeer(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1086,7 +1089,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1154,8 +1157,17 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
ID: "groupA",
|
||||
@@ -1181,8 +1193,8 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
@@ -1205,11 +1217,20 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
manager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
manager, updateManager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
// Ensure that we do not receive an update message before the policy is deleted
|
||||
time.Sleep(time.Second)
|
||||
@@ -1239,8 +1260,17 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, updateManager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
AccountID: account.Id,
|
||||
@@ -1253,8 +1283,8 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
@@ -1288,8 +1318,17 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
manager, updateManager, account, peer1, _, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
ID: "groupA",
|
||||
@@ -1318,8 +1357,11 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
// We need to sleep to wait for the buffer peer update
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
@@ -1341,11 +1383,20 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
@@ -1377,6 +1428,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
for drained := false; !drained; {
|
||||
select {
|
||||
case <-updMsg:
|
||||
default:
|
||||
drained = true
|
||||
}
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1404,7 +1463,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1485,7 +1544,7 @@ func getEvent(t *testing.T, accountID string, manager nbAccount.Manager, eventTy
|
||||
}
|
||||
|
||||
func TestGetUsersFromAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1736,7 +1795,9 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Address: "172.12.6.1/24",
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
}
|
||||
account.InitOnce()
|
||||
err := hasNilField(account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1782,7 +1843,7 @@ func hasNilField(x interface{}) error {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -1797,7 +1858,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -1853,7 +1914,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -1896,7 +1957,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -1958,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -2622,7 +2683,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||
|
||||
func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
// create a new account
|
||||
@@ -2864,18 +2925,18 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
|
||||
// Fatalf(format string, args ...interface{})
|
||||
// }
|
||||
|
||||
func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
||||
func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
|
||||
t.Helper()
|
||||
|
||||
store, err := createStore(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
@@ -2893,12 +2954,17 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
ctx := context.Background()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
||||
manager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
return manager, updateManager, nil
|
||||
}
|
||||
|
||||
func createStore(t testing.TB) (store.Store, error) {
|
||||
@@ -2927,10 +2993,10 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
|
||||
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
|
||||
t.Helper()
|
||||
|
||||
manager, err := createManager(t)
|
||||
manager, updateManager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -2971,10 +3037,10 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account,
|
||||
peer2 := getPeer(manager, setupKey)
|
||||
peer3 := getPeer(manager, setupKey)
|
||||
|
||||
return manager, account, peer1, peer2, peer3
|
||||
return manager, updateManager, account, peer1, peer2, peer3
|
||||
}
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
case msg := <-updateMessage:
|
||||
@@ -2984,7 +3050,7 @@ func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessag
|
||||
}
|
||||
}
|
||||
|
||||
func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
|
||||
func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
@@ -3022,7 +3088,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
@@ -3031,16 +3097,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
updateManager.CreateChannel(ctx, peerID)
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
||||
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
||||
assert.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -3085,7 +3149,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
@@ -3094,11 +3158,10 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
updateManager.CreateChannel(ctx, peerID)
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
@@ -3155,7 +3218,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
@@ -3164,11 +3227,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
updateManager.CreateChannel(ctx, peerID)
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
@@ -3227,7 +3289,7 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -3273,7 +3335,7 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -3303,7 +3365,7 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("memory cache", func(t *testing.T) {
|
||||
@@ -3353,7 +3415,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -3470,7 +3532,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
@@ -3502,7 +3564,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
@@ -3541,7 +3603,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -3608,7 +3670,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -3654,7 +3716,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -3,54 +3,23 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/mod/semver"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsForwarderPort = nbdns.ForwarderServerPort
|
||||
oldForwarderPort = nbdns.ForwarderClientPort
|
||||
)
|
||||
|
||||
const dnsForwarderPortMinVersion = "v0.59.0"
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
type DNSConfigCache struct {
|
||||
NameServerGroups sync.Map
|
||||
}
|
||||
|
||||
// GetNameServerGroup retrieves a cached name server group
|
||||
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.NameServerGroups.Load(key); ok {
|
||||
return value.(*proto.NameServerGroup), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetNameServerGroup stores a name server group in the cache
|
||||
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.NameServerGroups.Store(key, value)
|
||||
}
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
|
||||
@@ -191,99 +160,3 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
|
||||
|
||||
return validateGroups(settings.DisabledManagementGroups, groups)
|
||||
}
|
||||
|
||||
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
|
||||
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
|
||||
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
|
||||
if len(peers) == 0 {
|
||||
return int64(oldForwarderPort)
|
||||
}
|
||||
|
||||
reqVer := semver.Canonical(requiredVersion)
|
||||
|
||||
// Check if all peers have the required version or newer
|
||||
for _, peer := range peers {
|
||||
|
||||
// Development version is always supported
|
||||
if peer.Meta.WtVersion == "development" {
|
||||
continue
|
||||
}
|
||||
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
|
||||
if peerVersion == "" {
|
||||
// If any peer doesn't have version info, return 0
|
||||
return int64(oldForwarderPort)
|
||||
}
|
||||
|
||||
// Compare versions
|
||||
if semver.Compare(peerVersion, reqVer) < 0 {
|
||||
return int64(oldForwarderPort)
|
||||
}
|
||||
}
|
||||
|
||||
// All peers have the required version or newer
|
||||
return int64(dnsForwarderPort)
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
ForwarderPort: forwardPort,
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
cacheKey := nsGroup.ID
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
} else {
|
||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
|
||||
@@ -2,9 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -12,6 +10,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -218,7 +218,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
// return empty extra settings for expected calls to UpdateAccountPeers
|
||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock())
|
||||
|
||||
return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -344,247 +350,8 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
}
|
||||
|
||||
func generateTestData(size int) nbdns.Config {
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: make([]nbdns.CustomZone, size),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, size),
|
||||
}
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
config.CustomZones[i] = nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("domain%d.com", i),
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: fmt.Sprintf("record%d", i),
|
||||
Type: 1,
|
||||
Class: "IN",
|
||||
TTL: 3600,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
||||
ID: fmt.Sprintf("group%d", i),
|
||||
Primary: i == 0,
|
||||
Domains: []string{fmt.Sprintf("domain%d.com", i)},
|
||||
SearchDomainsEnabled: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
Port: 53,
|
||||
NSType: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
sizes := []int{10, 100, 1000}
|
||||
|
||||
for _, size := range sizes {
|
||||
testData := generateTestData(size)
|
||||
|
||||
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
|
||||
cache := &DNSConfigCache{}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
var cache DNSConfigCache
|
||||
|
||||
// Create two different configs
|
||||
config1 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.com",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group1",
|
||||
Name: "Group 1",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config2 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.org",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group2",
|
||||
Name: "Group 2",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort))
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
|
||||
}
|
||||
|
||||
// Verify that result2 is different from result1 and result3
|
||||
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
|
||||
t.Errorf("Results should be different for different inputs")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group1'")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group2"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group2'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeForwarderPort(t *testing.T) {
|
||||
// Test with empty peers list
|
||||
peers := []*nbpeer.Peer{}
|
||||
result := computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have old versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.57.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.26.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have new versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(dnsForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have mixed versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.57.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have empty version
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "development",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result == int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have unknown version string
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "unknown",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != int64(oldForwarderPort) {
|
||||
t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
{
|
||||
@@ -600,9 +367,9 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates
|
||||
|
||||
@@ -28,7 +28,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetEvents(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -138,6 +138,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
|
||||
@@ -157,11 +162,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -335,6 +335,16 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
if err == nil && oldGroup != nil {
|
||||
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||
|
||||
if oldGroup.Name != newGroup.Name {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{
|
||||
"old_name": oldGroup.Name,
|
||||
"new_name": newGroup.Name,
|
||||
}
|
||||
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupUpdated, meta)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
addedPeers = append(addedPeers, newGroup.Peers...)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
@@ -354,7 +364,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
|
||||
return nil
|
||||
}
|
||||
dnsDomain := am.GetDNSDomain(settings)
|
||||
dnsDomain := am.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
for _, peerID := range addedPeers {
|
||||
peer, ok := peers[peerID]
|
||||
|
||||
@@ -37,7 +37,7 @@ const (
|
||||
)
|
||||
|
||||
func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
am, err := createManager(t)
|
||||
am, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
@@ -74,7 +74,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
am, err := createManager(t)
|
||||
am, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create account manager: %s", err)
|
||||
}
|
||||
@@ -156,7 +156,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
||||
am, err := createManager(t)
|
||||
am, _, err := createManager(t)
|
||||
assert.NoError(t, err, "Failed to create account manager")
|
||||
|
||||
manager, account, err := initTestGroupAccount(am)
|
||||
@@ -408,7 +408,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
||||
}
|
||||
|
||||
func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
g := []*types.Group{
|
||||
{
|
||||
@@ -442,9 +442,9 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Saving a group that is not linked to any resource should not update account peers
|
||||
@@ -748,7 +748,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_AddPeerToGroup(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -805,7 +805,7 @@ func Test_AddPeerToGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_AddPeerToAll(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -862,7 +862,7 @@ func Test_AddPeerToAll(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_AddPeerAndAddToAll(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -942,7 +942,7 @@ func uint32ToIP(n uint32) net.IP {
|
||||
}
|
||||
|
||||
func Test_IncrementNetworkSerial(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
|
||||
@@ -4,11 +4,16 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/cors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -38,7 +43,12 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const apiPrefix = "/api"
|
||||
const (
|
||||
apiPrefix = "/api"
|
||||
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
|
||||
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
|
||||
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(
|
||||
@@ -56,13 +66,45 @@ func NewAPIHandler(
|
||||
permissionsManager permissions.Manager,
|
||||
peersManager nbpeers.Manager,
|
||||
settingsManager settings.Manager,
|
||||
networkMapController network_map.Controller,
|
||||
) (http.Handler, error) {
|
||||
|
||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||
rpm := 6
|
||||
if v := os.Getenv(rateLimitingRPMKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
|
||||
burst := 500
|
||||
if v := os.Getenv(rateLimitingBurstKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
|
||||
rateLimitingConfig = &middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
authManager,
|
||||
accountManager.GetAccountIDFromUserAuth,
|
||||
accountManager.SyncUserJWTGroups,
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimitingConfig,
|
||||
)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
@@ -80,7 +122,7 @@ func NewAPIHandler(
|
||||
}
|
||||
|
||||
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||
peers.AddEndpoints(accountManager, router)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController)
|
||||
users.AddEndpoints(accountManager, router)
|
||||
setup_keys.AddEndpoints(accountManager, router)
|
||||
policies.AddEndpoints(accountManager, LocationManager, router)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
@@ -23,11 +24,12 @@ import (
|
||||
|
||||
// Handler is a handler that returns peers of the account
|
||||
type Handler struct {
|
||||
accountManager account.Manager
|
||||
accountManager account.Manager
|
||||
networkMapController network_map.Controller
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
peersHandler := NewHandler(accountManager)
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) {
|
||||
peersHandler := NewHandler(accountManager, networkMapController)
|
||||
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
@@ -36,9 +38,10 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
}
|
||||
|
||||
// NewHandler creates a new peers Handler
|
||||
func NewHandler(accountManager account.Manager) *Handler {
|
||||
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler {
|
||||
return &Handler{
|
||||
accountManager: accountManager,
|
||||
accountManager: accountManager,
|
||||
networkMapController: networkMapController,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,7 +50,7 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
if peer.Status.Connected {
|
||||
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
|
||||
// This may happen after server restart when not all peers are yet connected
|
||||
if !h.accountManager.HasConnectedChannel(peer.ID) {
|
||||
if !h.networkMapController.IsConnected(peer.ID) {
|
||||
peerToReturn.Status.Connected = false
|
||||
}
|
||||
}
|
||||
@@ -73,7 +76,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
|
||||
return
|
||||
}
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain(settings)
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
|
||||
@@ -139,7 +142,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain(settings)
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
@@ -227,7 +230,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain(settings)
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
|
||||
@@ -317,7 +320,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain(account.Settings)
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
|
||||
@@ -14,12 +14,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -36,7 +38,7 @@ const (
|
||||
serviceUser = "service_user"
|
||||
)
|
||||
|
||||
func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
||||
func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers {
|
||||
@@ -99,6 +101,22 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
||||
},
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
networkMapController := network_map.NewMockController(ctrl)
|
||||
networkMapController.EXPECT().
|
||||
GetDNSDomain(gomock.Any()).
|
||||
Return("domain").
|
||||
AnyTimes()
|
||||
networkMapController.EXPECT().
|
||||
IsConnected(noUpdateChannelTestPeerID).
|
||||
Return(false).
|
||||
AnyTimes()
|
||||
networkMapController.EXPECT().
|
||||
IsConnected(gomock.Any()).
|
||||
Return(true).
|
||||
AnyTimes()
|
||||
|
||||
return &Handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
@@ -187,6 +205,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
||||
return account.Settings, nil
|
||||
},
|
||||
},
|
||||
networkMapController: networkMapController,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,7 +289,7 @@ func TestGetPeers(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
p := initTestMetaData(peer, peer1)
|
||||
p := initTestMetaData(t, peer, peer1)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -374,7 +393,7 @@ func TestGetAccessiblePeers(t *testing.T) {
|
||||
UserID: regularUser,
|
||||
}
|
||||
|
||||
p := initTestMetaData(peer1, peer2, peer3)
|
||||
p := initTestMetaData(t, peer1, peer2, peer3)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
@@ -477,7 +496,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := initTestMetaData(testPeer)
|
||||
p := initTestMetaData(t, testPeer)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
|
||||
@@ -29,6 +29,7 @@ type AuthMiddleware struct {
|
||||
ensureAccount EnsureAccountFunc
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
rateLimiter *APIRateLimiter
|
||||
}
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
@@ -37,12 +38,19 @@ func NewAuthMiddleware(
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiterConfig *RateLimiterConfig,
|
||||
) *AuthMiddleware {
|
||||
var rateLimiter *APIRateLimiter
|
||||
if rateLimiterConfig != nil {
|
||||
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
authManager: authManager,
|
||||
ensureAccount: ensureAccount,
|
||||
syncUserJWTGroups: syncUserJWTGroups,
|
||||
getUserFromUserAuth: getUserFromUserAuth,
|
||||
rateLimiter: rateLimiter,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +84,11 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
request, err := m.checkPATFromRequest(r, auth)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
if _, ok := status.FromError(err); !ok {
|
||||
err = status.Errorf(status.Unauthorized, "token invalid")
|
||||
}
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, request)
|
||||
@@ -145,6 +157,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,7 +27,9 @@ const (
|
||||
domainCategory = "domainCategory"
|
||||
userID = "userID"
|
||||
tokenID = "tokenID"
|
||||
tokenID2 = "tokenID2"
|
||||
PAT = "nbp_PAT"
|
||||
PAT2 = "nbp_PAT2"
|
||||
JWT = "JWT"
|
||||
wrongToken = "wrongToken"
|
||||
)
|
||||
@@ -49,6 +51,15 @@ var testAccount = &types.Account{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: util.ToPtr(time.Now().UTC()),
|
||||
},
|
||||
tokenID2: {
|
||||
ID: tokenID2,
|
||||
Name: "My second token",
|
||||
HashedToken: "someHash2",
|
||||
ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)),
|
||||
CreatedBy: userID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: util.ToPtr(time.Now().UTC()),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -58,6 +69,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
|
||||
if token == PAT {
|
||||
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
|
||||
}
|
||||
if token == PAT2 {
|
||||
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil
|
||||
}
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
@@ -81,7 +95,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
|
||||
}
|
||||
|
||||
func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
if token == tokenID {
|
||||
if token == tokenID || token == tokenID2 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("Should never get reached")
|
||||
@@ -192,6 +206,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
@@ -221,6 +236,273 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
mockAuth := &auth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: mockMarkPATUsed,
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) {
|
||||
// Configure rate limiter: 10 requests per minute with burst of 5
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 10,
|
||||
Burst: 5,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make burst requests - all should succeed
|
||||
successCount := 0
|
||||
for i := 0; i < 5; i++ {
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 5, successCount, "All burst requests should succeed")
|
||||
|
||||
// The 6th request should fail (exceeded burst)
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited")
|
||||
})
|
||||
|
||||
t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) {
|
||||
// Configure very low rate limit: 1 request per minute
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First request should succeed
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
|
||||
|
||||
// Second request should fail (rate limited)
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
|
||||
})
|
||||
|
||||
t.Run("Bearer Token Not Rate Limited", func(t *testing.T) {
|
||||
// Configure strict rate limit
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make multiple requests with Bearer token - all should succeed
|
||||
successCount := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+JWT)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)")
|
||||
})
|
||||
|
||||
t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) {
|
||||
// Configure rate limiter
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Use first PAT token
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed")
|
||||
|
||||
// Second request with same token should fail
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited")
|
||||
|
||||
// Use second PAT token - should succeed because it has independent rate limit
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT2)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)")
|
||||
|
||||
// Second request with PAT2 should also be rate limited
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT2)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited")
|
||||
|
||||
// JWT should still work (not rate limited)
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+JWT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)")
|
||||
})
|
||||
|
||||
t.Run("Rate Limiter Cleanup", func(t *testing.T) {
|
||||
// Configure rate limiter with short cleanup interval and TTL for testing
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
LimiterTTL: 200 * time.Millisecond,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First request - should succeed
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
|
||||
|
||||
// Second request immediately - should fail (burst exhausted)
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
|
||||
|
||||
// Wait for limiter to be cleaned up (TTL + cleanup interval + buffer)
|
||||
time.Sleep(400 * time.Millisecond)
|
||||
|
||||
// After cleanup, the limiter should be removed and recreated with full burst capacity
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "Request after cleanup should succeed (new limiter with full burst)")
|
||||
|
||||
// Verify it's a fresh limiter by checking burst is reset
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
@@ -297,6 +579,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
146
management/server/http/middleware/rate_limiter.go
Normal file
146
management/server/http/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
type RateLimiterConfig struct {
|
||||
// RequestsPerMinute defines the rate at which tokens are replenished
|
||||
RequestsPerMinute float64
|
||||
// Burst defines the maximum number of requests that can be made in a burst
|
||||
Burst int
|
||||
// CleanupInterval defines how often to clean up old limiters (how often garbage collection runs)
|
||||
CleanupInterval time.Duration
|
||||
// LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal)
|
||||
LimiterTTL time.Duration
|
||||
}
|
||||
|
||||
// DefaultRateLimiterConfig returns a default configuration
|
||||
func DefaultRateLimiterConfig() *RateLimiterConfig {
|
||||
return &RateLimiterConfig{
|
||||
RequestsPerMinute: 100,
|
||||
Burst: 120,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// limiterEntry holds a rate limiter and its last access time
|
||||
type limiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
lastAccess time.Time
|
||||
}
|
||||
|
||||
// APIRateLimiter manages rate limiting for API tokens
|
||||
type APIRateLimiter struct {
|
||||
config *RateLimiterConfig
|
||||
limiters map[string]*limiterEntry
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
||||
func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
|
||||
if config == nil {
|
||||
config = DefaultRateLimiterConfig()
|
||||
}
|
||||
|
||||
rl := &APIRateLimiter{
|
||||
config: config,
|
||||
limiters: make(map[string]*limiterEntry),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
go rl.cleanupLoop()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// Allow checks if a request for the given key (token) is allowed
|
||||
func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Allow()
|
||||
}
|
||||
|
||||
// Wait blocks until the rate limiter allows another request for the given key
|
||||
// Returns an error if the context is canceled
|
||||
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Wait(ctx)
|
||||
}
|
||||
|
||||
// getLimiter retrieves or creates a rate limiter for the given key
|
||||
func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter {
|
||||
rl.mu.RLock()
|
||||
entry, exists := rl.limiters[key]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
rl.mu.Lock()
|
||||
entry.lastAccess = time.Now()
|
||||
rl.mu.Unlock()
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
if entry, exists := rl.limiters[key]; exists {
|
||||
entry.lastAccess = time.Now()
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
requestsPerSecond := rl.config.RequestsPerMinute / 60.0
|
||||
limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst)
|
||||
rl.limiters[key] = &limiterEntry{
|
||||
limiter: limiter,
|
||||
lastAccess: time.Now(),
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanupLoop periodically removes old limiters that haven't been used recently
|
||||
func (rl *APIRateLimiter) cleanupLoop() {
|
||||
ticker := time.NewTicker(rl.config.CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
rl.cleanup()
|
||||
case <-rl.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes limiters that haven't been used within the TTL period
|
||||
func (rl *APIRateLimiter) cleanup() {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, entry := range rl.limiters {
|
||||
if now.Sub(entry.lastAccess) > rl.config.LimiterTTL {
|
||||
delete(rl.limiters, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (rl *APIRateLimiter) Stop() {
|
||||
close(rl.stopChan)
|
||||
}
|
||||
|
||||
// Reset removes the rate limiter for a specific key
|
||||
func (rl *APIRateLimiter) Reset(key string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
delete(rl.limiters, key)
|
||||
}
|
||||
@@ -10,6 +10,10 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -31,7 +35,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
)
|
||||
|
||||
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
|
||||
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test store: %v", err)
|
||||
@@ -43,7 +47,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
t.Fatalf("Failed to create metrics: %v", err)
|
||||
}
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(nil)
|
||||
updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId)
|
||||
done := make(chan struct{})
|
||||
if validateUpdate {
|
||||
@@ -63,7 +67,11 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
userManager := users.NewManager(store)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
|
||||
ctx := context.Background()
|
||||
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock())
|
||||
am, err := server.BuildManager(ctx, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
@@ -83,7 +91,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
groupsManagerMock := groups.NewManagerMock()
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -91,7 +99,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
return apiHandler, am, done
|
||||
}
|
||||
|
||||
func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) {
|
||||
func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
case msg := <-updateMessage:
|
||||
@@ -101,7 +109,7 @@ func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server
|
||||
}
|
||||
}
|
||||
|
||||
func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
|
||||
func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage, expected *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
|
||||
@@ -22,10 +22,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -321,99 +325,6 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
testingServerKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
}
|
||||
|
||||
testingClientKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputFlow *config.DeviceAuthorizationFlow
|
||||
expectedFlow *mgmtProto.DeviceAuthorizationFlow
|
||||
expectedErrFunc require.ErrorAssertionFunc
|
||||
expectedErrMSG string
|
||||
expectedComparisonFunc require.ComparisonAssertionFunc
|
||||
expectedComparisonMSG string
|
||||
}{
|
||||
{
|
||||
name: "Testing No Device Flow Config",
|
||||
inputFlow: nil,
|
||||
expectedErrFunc: require.Error,
|
||||
expectedErrMSG: "should return error",
|
||||
},
|
||||
{
|
||||
name: "Testing Invalid Device Flow Provider Config",
|
||||
inputFlow: &config.DeviceAuthorizationFlow{
|
||||
Provider: "NoNe",
|
||||
ProviderConfig: config.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedErrFunc: require.Error,
|
||||
expectedErrMSG: "should return error",
|
||||
},
|
||||
{
|
||||
name: "Testing Full Device Flow Config",
|
||||
inputFlow: &config.DeviceAuthorizationFlow{
|
||||
Provider: "hosted",
|
||||
ProviderConfig: config.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
|
||||
Provider: 0,
|
||||
ProviderConfig: &mgmtProto.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedErrFunc: require.NoError,
|
||||
expectedErrMSG: "should not return error",
|
||||
expectedComparisonFunc: require.Equal,
|
||||
expectedComparisonMSG: "should match",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
mgmtServer := &GRPCServer{
|
||||
wgKey: testingServerKey,
|
||||
config: &config.Config{
|
||||
DeviceAuthorizationFlow: testCase.inputFlow,
|
||||
},
|
||||
}
|
||||
|
||||
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
|
||||
|
||||
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
|
||||
require.NoError(t, err, "should be able to encrypt message")
|
||||
|
||||
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
|
||||
context.TODO(),
|
||||
&mgmtProto.EncryptedMessage{
|
||||
WgPubKey: testingClientKey.PublicKey().String(),
|
||||
Body: encryptedMSG,
|
||||
},
|
||||
)
|
||||
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
|
||||
if testCase.expectedComparisonFunc != nil {
|
||||
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
|
||||
|
||||
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
|
||||
require.NoError(t, err, "should be able to decrypt")
|
||||
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
@@ -427,7 +338,6 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
peersUpdateManager := NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
@@ -451,7 +361,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
accountManager, err := BuildManager(ctx, store, networkMapController, nil, "",
|
||||
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
|
||||
if err != nil {
|
||||
@@ -459,10 +372,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
|
||||
ephemeralMgr := manager.NewEphemeralManager(store, accountManager)
|
||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
@@ -764,9 +677,38 @@ func Test_LoginPerformance(t *testing.T) {
|
||||
peerLogin := types.PeerLogin{
|
||||
WireGuardPubKey: key.String(),
|
||||
SSHKey: "random",
|
||||
Meta: extractPeerMeta(context.Background(), meta),
|
||||
SetupKey: setupKey.Key,
|
||||
ConnectionIP: net.IP{1, 1, 1, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: meta.GetHostname(),
|
||||
GoOS: meta.GetGoOS(),
|
||||
Kernel: meta.GetKernel(),
|
||||
Platform: meta.GetPlatform(),
|
||||
OS: meta.GetOS(),
|
||||
OSVersion: meta.GetOSVersion(),
|
||||
WtVersion: meta.GetNetbirdVersion(),
|
||||
UIVersion: meta.GetUiVersion(),
|
||||
KernelVersion: meta.GetKernelVersion(),
|
||||
SystemSerialNumber: meta.GetSysSerialNumber(),
|
||||
SystemProductName: meta.GetSysProductName(),
|
||||
SystemManufacturer: meta.GetSysManufacturer(),
|
||||
Environment: nbpeer.Environment{
|
||||
Cloud: meta.GetEnvironment().GetCloud(),
|
||||
Platform: meta.GetEnvironment().GetPlatform(),
|
||||
},
|
||||
Flags: nbpeer.Flags{
|
||||
RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(),
|
||||
ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(),
|
||||
DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(),
|
||||
DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(),
|
||||
DisableDNS: meta.GetFlags().GetDisableDNS(),
|
||||
DisableFirewall: meta.GetFlags().GetDisableFirewall(),
|
||||
BlockLANAccess: meta.GetFlags().GetBlockLANAccess(),
|
||||
BlockInbound: meta.GetFlags().GetBlockInbound(),
|
||||
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
|
||||
},
|
||||
},
|
||||
SetupKey: setupKey.Key,
|
||||
ConnectionIP: net.IP{1, 1, 1, 1},
|
||||
}
|
||||
|
||||
login := func() error {
|
||||
|
||||
@@ -20,7 +20,10 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -176,7 +179,6 @@ func startServer(
|
||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
@@ -199,13 +201,18 @@ func startServer(
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(str)
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(ctx, str)
|
||||
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
|
||||
accountManager, err := server.BuildManager(
|
||||
context.Background(),
|
||||
str,
|
||||
peersUpdateManager,
|
||||
networkMapController,
|
||||
nil,
|
||||
"",
|
||||
"netbird.selfhosted",
|
||||
eventStore,
|
||||
nil,
|
||||
false,
|
||||
@@ -220,18 +227,18 @@ func startServer(
|
||||
}
|
||||
|
||||
groupsManager := groups.NewManager(str, permissionsManager, accountManager)
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(
|
||||
context.Background(),
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(
|
||||
config,
|
||||
accountManager,
|
||||
settingsMockManager,
|
||||
peersUpdateManager,
|
||||
updateManager,
|
||||
secretsManager,
|
||||
nil,
|
||||
&manager.EphemeralManager{},
|
||||
nil,
|
||||
server.MockIntegratedValidator{},
|
||||
networkMapController,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating management server: %v", err)
|
||||
|
||||
@@ -38,7 +38,7 @@ type MockAccountManager struct {
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
||||
@@ -94,7 +94,7 @@ type MockAccountManager struct {
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
|
||||
RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
|
||||
@@ -125,9 +125,10 @@ type MockAccountManager struct {
|
||||
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
|
||||
@@ -177,11 +178,11 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if am.SyncAndMarkPeerFunc != nil {
|
||||
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
|
||||
@@ -746,11 +747,11 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog
|
||||
}
|
||||
|
||||
// SyncPeer mocks SyncPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if am.SyncPeerFunc != nil {
|
||||
return am.SyncPeerFunc(ctx, sync, accountID)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
|
||||
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
|
||||
}
|
||||
|
||||
// GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface
|
||||
@@ -986,3 +987,10 @@ func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error {
|
||||
if am.RecalculateNetworkMapCacheFunc != nil {
|
||||
return am.RecalculateNetworkMapCacheFunc(ctx, accountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -785,7 +787,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
|
||||
return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -975,7 +983,7 @@ func TestValidateDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
var newNameServerGroupA *nbdns.NameServerGroup
|
||||
var newNameServerGroupB *nbdns.NameServerGroup
|
||||
@@ -994,9 +1002,9 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Creating a nameserver group with a distribution group no peers should not update account peers
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
@@ -23,7 +21,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -31,7 +28,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -106,11 +102,6 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Debugf("MarkPeerConnected: took %v", time.Since(start))
|
||||
}()
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
var expired bool
|
||||
@@ -145,9 +136,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
}
|
||||
|
||||
if expired {
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
am.networkMapController.OnPeerUpdated(accountID, peer)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -203,7 +192,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
var peer *nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
var peerGroupList []string
|
||||
var requiresPeerUpdates bool
|
||||
var peerLabelChanged bool
|
||||
var sshChanged bool
|
||||
var loginExpirationChanged bool
|
||||
@@ -226,9 +214,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
dnsDomain = am.GetDNSDomain(settings)
|
||||
dnsDomain = am.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra)
|
||||
update, _, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -321,11 +309,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
}
|
||||
|
||||
if peerLabelChanged || requiresPeerUpdates {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
} else if sshChanged {
|
||||
am.UpdateAccountPeer(ctx, accountID, peer.ID)
|
||||
}
|
||||
am.networkMapController.OnPeerUpdated(accountID, peer)
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
@@ -381,8 +365,13 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if userID != activity.SystemInitiator {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err)
|
||||
}
|
||||
|
||||
if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peerID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -390,41 +379,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
|
||||
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
groups := make(map[string][]string)
|
||||
for groupID, group := range account.Groups {
|
||||
groups[groupID] = group.Peers
|
||||
}
|
||||
|
||||
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
networkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
return networkMap, nil
|
||||
return am.networkMapController.GetNetworkMap(ctx, peerID)
|
||||
}
|
||||
|
||||
// GetPeerNetwork returns the Network for a given peer
|
||||
@@ -683,16 +638,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
}
|
||||
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
|
||||
opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings))
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
if err := am.networkMapController.OnPeerAdded(ctx, accountID, newPeer.ID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
|
||||
}
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
return p, nmap, pc, err
|
||||
}
|
||||
|
||||
func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
|
||||
@@ -707,12 +665,7 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
|
||||
}
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Debugf("SyncPeer: took %v", time.Since(start))
|
||||
}()
|
||||
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
var peer *nbpeer.Peer
|
||||
var peerNotValid bool
|
||||
var isStatusChanged bool
|
||||
@@ -722,7 +675,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -772,14 +725,14 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
am.networkMapController.OnPeerUpdated(accountID, peer)
|
||||
}
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
|
||||
return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
@@ -831,6 +784,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
startTransaction := time.Now()
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
@@ -900,11 +854,14 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
|
||||
|
||||
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
am.networkMapController.OnPeerUpdated(accountID, peer)
|
||||
}
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
|
||||
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
|
||||
return p, nmap, pc, err
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks for the peer.
|
||||
@@ -996,57 +953,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start))
|
||||
}()
|
||||
|
||||
if isRequiresApproval {
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
emptyMap := &types.NetworkMap{
|
||||
Network: network.Copy(),
|
||||
}
|
||||
return peer, emptyMap, nil, nil
|
||||
}
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
postureChecks, err := am.getPeerPostureChecks(account, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
networkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
return peer, networkMap, postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction store.Store, user *types.User, peer *nbpeer.Peer) error {
|
||||
err := checkAuth(ctx, user.Id, peer)
|
||||
if err != nil {
|
||||
@@ -1070,7 +976,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
|
||||
return fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings)))
|
||||
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.networkMapController.GetDNSDomain(settings)))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1166,209 +1072,17 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
||||
// UpdateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
|
||||
hasPeersConnected := false
|
||||
for _, peer := range account.Peers {
|
||||
if am.peersUpdateManager.HasChannel(peer.ID) {
|
||||
hasPeersConnected = true
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if !hasPeersConnected {
|
||||
return
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
dnsCache := &DNSConfigCache{}
|
||||
dnsDomain := am.GetDNSDomain(account.Settings)
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !am.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
wg.Wait()
|
||||
if am.metrics != nil {
|
||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||
}
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
mu sync.Mutex
|
||||
next *time.Timer
|
||||
update atomic.Bool
|
||||
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
})
|
||||
return
|
||||
}
|
||||
b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load()))
|
||||
}()
|
||||
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
// UpdateAccountPeer updates a single peer that belongs to an account.
|
||||
// Should be called when changes need to be synced to a specific peer only.
|
||||
func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) {
|
||||
if !am.peersUpdateManager.HasChannel(peerId) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err)
|
||||
return
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerId)
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't exists in account %s", peerId, accountId)
|
||||
return
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsCache := &DNSConfigCache{}
|
||||
dnsDomain := am.GetDNSDomain(account.Settings)
|
||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
postureChecks, err := am.getPeerPostureChecks(account, peerId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err)
|
||||
return
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(peerId)
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
|
||||
|
||||
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
_ = am.networkMapController.UpdateAccountPeer(ctx, accountId, peerId)
|
||||
}
|
||||
|
||||
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||
@@ -1523,14 +1237,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dnsDomain := am.GetDNSDomain(settings)
|
||||
|
||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(peers, dnsForwarderPortMinVersion)
|
||||
dnsDomain := am.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
for _, peer := range peers {
|
||||
if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
|
||||
@@ -1564,25 +1271,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: network.CurrentSerial(),
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
FirewallRules: []*proto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
DNSConfig: &proto.DNSConfig{
|
||||
ForwarderPort: dnsFwdPort,
|
||||
},
|
||||
},
|
||||
},
|
||||
NetworkMap: &types.NetworkMap{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
peerDeletedEvents = append(peerDeletedEvents, func() {
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
|
||||
})
|
||||
@@ -1591,14 +1279,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
return peerDeletedEvents, nil
|
||||
}
|
||||
|
||||
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||
labelMap := make(map[string]struct{}, len(existingLabels))
|
||||
for _, label := range existingLabels {
|
||||
labelMap[label] = struct{}{}
|
||||
}
|
||||
return labelMap
|
||||
}
|
||||
|
||||
// validatePeerDelete checks if the peer can be deleted.
|
||||
func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transaction store.Store, accountId, peerId string) error {
|
||||
linkedInIngressPorts, err := am.proxyController.IsPeerInIngressPorts(ctx, accountId, peerId)
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -25,10 +24,14 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -168,7 +171,16 @@ func TestPeer_SessionExpired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
testGetNetworkMapGeneral(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testGetNetworkMapGeneral(t)
|
||||
}
|
||||
|
||||
func testGetNetworkMapGeneral(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -240,7 +252,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
// TODO: disable until we start use policy again
|
||||
t.Skip()
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -417,7 +429,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_GetPeerNetwork(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -478,7 +490,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -665,7 +677,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -733,12 +745,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) {
|
||||
func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, *update_channel.PeersUpdateManager, string, string, error) {
|
||||
b.Helper()
|
||||
|
||||
manager, err := createManager(b)
|
||||
manager, updateManager, err := createManager(b)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
accountID := "test_account"
|
||||
@@ -789,7 +801,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
|
||||
ips := account.GetTakenIPs()
|
||||
peerIP, err := types.AllocatePeerIP(account.Network.Net, ips)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
peerKey, _ := wgtypes.GeneratePrivateKey()
|
||||
@@ -895,10 +907,10 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
|
||||
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
return manager, accountID, regularUser, nil
|
||||
return manager, updateManager, accountID, regularUser, nil
|
||||
}
|
||||
|
||||
func BenchmarkGetPeers(b *testing.B) {
|
||||
@@ -919,7 +931,7 @@ func BenchmarkGetPeers(b *testing.B) {
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
manager, _, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
@@ -959,7 +971,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
@@ -971,14 +983,10 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
updateManager.CreateChannel(ctx, peerID)
|
||||
}
|
||||
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
|
||||
@@ -1003,7 +1011,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testUpdateAccountPeers(t)
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers(t *testing.T) {
|
||||
testUpdateAccountPeers(t)
|
||||
}
|
||||
|
||||
func testUpdateAccountPeers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
@@ -1019,7 +1036,7 @@ func TestUpdateAccountPeers(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
|
||||
manager, updateManager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
@@ -1031,20 +1048,19 @@ func TestUpdateAccountPeers(t *testing.T) {
|
||||
t.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
peerChannels := make(map[string]chan *network_map.UpdateMessage)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID)
|
||||
}
|
||||
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
manager.UpdateAccountPeers(ctx, account.Id)
|
||||
|
||||
for _, channel := range peerChannels {
|
||||
update := <-channel
|
||||
assert.Nil(t, update.Update.NetbirdConfig)
|
||||
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
|
||||
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
|
||||
assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
|
||||
assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1079,7 +1095,7 @@ func TestToSyncResponse(t *testing.T) {
|
||||
DNSLabel: "peer1",
|
||||
SSHKey: "peer1-ssh-key",
|
||||
}
|
||||
turnRelayToken := &Token{
|
||||
turnRelayToken := &grpc.Token{
|
||||
Payload: "turn-user",
|
||||
Signature: "turn-pass",
|
||||
}
|
||||
@@ -1159,9 +1175,9 @@ func TestToSyncResponse(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
dnsCache := &DNSConfigCache{}
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
|
||||
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
|
||||
response := grpc.ToSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
|
||||
|
||||
assert.NotNil(t, response)
|
||||
// assert peer config
|
||||
@@ -1271,7 +1287,12 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, s)
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
||||
|
||||
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1351,7 +1372,12 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
AnyTimes()
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, s)
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
||||
|
||||
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1499,7 +1525,12 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, s)
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
||||
|
||||
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1548,6 +1579,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_LoginPeer(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
@@ -1573,7 +1605,12 @@ func Test_LoginPeer(t *testing.T) {
|
||||
AnyTimes()
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, s)
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
||||
|
||||
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1706,7 +1743,7 @@ func Test_LoginPeer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
|
||||
require.NoError(t, err)
|
||||
@@ -1763,13 +1800,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
var peer5 *nbpeer.Peer
|
||||
var peer6 *nbpeer.Peer
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
|
||||
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
|
||||
t.Skip("Currently all updates will trigger a network map")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
@@ -1871,6 +1909,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("validator requires no update", func(t *testing.T) {
|
||||
t.Skip("Currently all updates will trigger a network map")
|
||||
|
||||
requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||
return update, false, nil
|
||||
}
|
||||
@@ -2072,7 +2112,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_DeletePeer(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -2169,7 +2209,7 @@ func Test_IsUniqueConstraintError(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_AddPeer(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -2257,136 +2297,8 @@ func Test_AddPeer(t *testing.T) {
|
||||
assert.Equal(t, uint64(totalPeers), account.Network.Serial)
|
||||
}
|
||||
|
||||
func TestBufferUpdateAccountPeers(t *testing.T) {
|
||||
const (
|
||||
peersCount = 1000
|
||||
updateAccountInterval = 50 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
|
||||
uapLastRun, dpLastRun atomic.Int64
|
||||
|
||||
totalNewRuns, totalOldRuns int
|
||||
)
|
||||
|
||||
uap := func(ctx context.Context, accountID string) {
|
||||
updatePeersDeleted.Store(deletedPeers.Load())
|
||||
updatePeersRuns.Add(1)
|
||||
uapLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Run("new approach", func(t *testing.T) {
|
||||
updatePeersRuns.Store(0)
|
||||
updatePeersDeleted.Store(0)
|
||||
deletedPeers.Store(0)
|
||||
|
||||
var mustore sync.Map
|
||||
bufupd := func(ctx context.Context, accountID string) {
|
||||
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := mu.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
uap(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
b.next = time.AfterFunc(updateAccountInterval, func() {
|
||||
uap(ctx, accountID)
|
||||
})
|
||||
}()
|
||||
}
|
||||
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
||||
deletedPeers.Add(1)
|
||||
dpLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
bufupd(ctx, accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
am := mock_server.MockAccountManager{
|
||||
UpdateAccountPeersFunc: uap,
|
||||
BufferUpdateAccountPeersFunc: bufupd,
|
||||
DeletePeerFunc: dp,
|
||||
}
|
||||
empty := ""
|
||||
for range peersCount {
|
||||
//nolint
|
||||
am.DeletePeer(context.Background(), empty, empty, empty)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
||||
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
||||
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
||||
|
||||
totalNewRuns = int(updatePeersRuns.Load())
|
||||
})
|
||||
|
||||
t.Run("old approach", func(t *testing.T) {
|
||||
updatePeersRuns.Store(0)
|
||||
updatePeersDeleted.Store(0)
|
||||
deletedPeers.Store(0)
|
||||
|
||||
var mustore sync.Map
|
||||
bufupd := func(ctx context.Context, accountID string) {
|
||||
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
|
||||
b := mu.(*sync.Mutex)
|
||||
|
||||
if !b.TryLock() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(updateAccountInterval)
|
||||
b.Unlock()
|
||||
uap(ctx, accountID)
|
||||
}()
|
||||
}
|
||||
dp := func(ctx context.Context, accountID, peerID, userID string) error {
|
||||
deletedPeers.Add(1)
|
||||
dpLastRun.Store(time.Now().UnixMilli())
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
bufupd(ctx, accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
am := mock_server.MockAccountManager{
|
||||
UpdateAccountPeersFunc: uap,
|
||||
BufferUpdateAccountPeersFunc: bufupd,
|
||||
DeletePeerFunc: dp,
|
||||
}
|
||||
empty := ""
|
||||
for range peersCount {
|
||||
//nolint
|
||||
am.DeletePeer(context.Background(), empty, empty, empty)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
|
||||
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
|
||||
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
|
||||
|
||||
totalOldRuns = int(updatePeersRuns.Load())
|
||||
})
|
||||
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||
}
|
||||
|
||||
func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -2423,7 +2335,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -2457,7 +2369,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -2522,7 +2434,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -252,31 +251,3 @@ func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []strin
|
||||
|
||||
return validIDs
|
||||
}
|
||||
|
||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||
func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
fwRule := &proto.FirewallRule{
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
PeerIP: rule.PeerIP,
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
|
||||
if shouldUsePortRange(fwRule) {
|
||||
fwRule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
|
||||
result[i] = fwRule
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
|
||||
}
|
||||
|
||||
@@ -266,7 +266,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "0.0.0.0",
|
||||
PeerIP: "100.65.14.88",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
@@ -274,7 +274,103 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "0.0.0.0",
|
||||
PeerIP: "100.65.14.88",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.32.206",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.32.206",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.250.202",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.250.202",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.13.186",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.13.186",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.29.55",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.29.55",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
@@ -833,10 +929,58 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, 1)
|
||||
assert.Len(t, firewallRules, 7)
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "0.0.0.0",
|
||||
PeerIP: "100.65.80.39",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.14.88",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.32.206",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.13.186",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.29.55",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.21.56",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
@@ -991,7 +1135,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int {
|
||||
}
|
||||
|
||||
func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
g := []*types.Group{
|
||||
{
|
||||
@@ -1020,9 +1164,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
var policyWithGroupRulesNoPeers *types.Policy
|
||||
|
||||
@@ -2,19 +2,15 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -136,27 +132,6 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
||||
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||
func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
|
||||
peerPostureChecks := make(map[string]*posture.Checks)
|
||||
|
||||
if len(account.PostureChecks) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return maps.Values(peerPostureChecks), nil
|
||||
}
|
||||
|
||||
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
@@ -211,50 +186,6 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
|
||||
return nil
|
||||
}
|
||||
|
||||
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
|
||||
func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
|
||||
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !isInGroup {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||
postureCheck := account.GetPostureChecks(sourcePostureCheckID)
|
||||
if postureCheck == nil {
|
||||
return errors.New("failed to add policy posture checks: posture checks not found")
|
||||
}
|
||||
peerPostureChecks[sourcePostureCheckID] = postureCheck
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
||||
func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group := account.GetGroup(sourceGroup)
|
||||
if group == nil {
|
||||
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
|
||||
}
|
||||
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
|
||||
@@ -21,7 +21,7 @@ const (
|
||||
)
|
||||
|
||||
func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
||||
am, err := createManager(t)
|
||||
am, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
|
||||
}
|
||||
|
||||
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
g := []*types.Group{
|
||||
{
|
||||
@@ -147,9 +147,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
postureCheckA := &posture.Checks{
|
||||
@@ -359,9 +359,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
// Updating linked posture check to policy where destination has peers but source does not
|
||||
// should trigger account peers update and send peer update
|
||||
t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) {
|
||||
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
|
||||
updMsg1 := updateManager.CreateChannel(context.Background(), peer2.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
updateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
@@ -445,7 +445,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "failed to create account manager")
|
||||
|
||||
account, err := initTestPostureChecksAccount(manager)
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -372,103 +371,12 @@ func validateRouteGroups(ctx context.Context, transaction store.Store, accountID
|
||||
return groupsMap, nil
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
return &proto.Route{
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
SkipAutoApply: route.SkipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
|
||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
||||
}
|
||||
return protoRoutes
|
||||
}
|
||||
|
||||
// getPlaceholderIP returns a placeholder IP address for the route if domains are used
|
||||
func getPlaceholderIP() netip.Prefix {
|
||||
// Using an IP from the documentation range to minimize impact in case older clients try to set a route
|
||||
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
|
||||
}
|
||||
|
||||
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
result[i] = &proto.RouteFirewallRule{
|
||||
SourceRanges: rule.SourceRanges,
|
||||
Action: getProtoAction(rule.Action),
|
||||
Destination: rule.Destination,
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
PortInfo: getProtoPortInfo(rule),
|
||||
IsDynamic: rule.IsDynamic,
|
||||
Domains: rule.Domains.ToPunycodeList(),
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
RouteID: string(rule.RouteID),
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// getProtoDirection converts the direction to proto.RuleDirection.
|
||||
func getProtoDirection(direction int) proto.RuleDirection {
|
||||
if direction == types.FirewallRuleDirectionOUT {
|
||||
return proto.RuleDirection_OUT
|
||||
}
|
||||
return proto.RuleDirection_IN
|
||||
}
|
||||
|
||||
// getProtoAction converts the action to proto.RuleAction.
|
||||
func getProtoAction(action string) proto.RuleAction {
|
||||
if action == string(types.PolicyTrafficActionDrop) {
|
||||
return proto.RuleAction_DROP
|
||||
}
|
||||
return proto.RuleAction_ACCEPT
|
||||
}
|
||||
|
||||
// getProtoProtocol converts the protocol to proto.RuleProtocol.
|
||||
func getProtoProtocol(protocol string) proto.RuleProtocol {
|
||||
switch types.PolicyRuleProtocolType(protocol) {
|
||||
case types.PolicyRuleProtocolALL:
|
||||
return proto.RuleProtocol_ALL
|
||||
case types.PolicyRuleProtocolTCP:
|
||||
return proto.RuleProtocol_TCP
|
||||
case types.PolicyRuleProtocolUDP:
|
||||
return proto.RuleProtocol_UDP
|
||||
case types.PolicyRuleProtocolICMP:
|
||||
return proto.RuleProtocol_ICMP
|
||||
default:
|
||||
return proto.RuleProtocol_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
// getProtoPortInfo converts the port info to proto.PortInfo.
|
||||
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
var portInfo proto.PortInfo
|
||||
if rule.Port != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
||||
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Range_{
|
||||
Range: &proto.PortInfo_Range{
|
||||
Start: uint32(portRange.Start),
|
||||
End: uint32(portRange.End),
|
||||
},
|
||||
}
|
||||
}
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
// areRouteChangesAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers.
|
||||
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -432,7 +434,7 @@ func TestCreateRoute(t *testing.T) {
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
am, err := createRouterManager(t)
|
||||
am, _, err := createRouterManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
@@ -922,7 +924,7 @@ func TestSaveRoute(t *testing.T) {
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
am, err := createRouterManager(t)
|
||||
am, _, err := createRouterManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
@@ -1024,7 +1026,7 @@ func TestDeleteRoute(t *testing.T) {
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
am, err := createRouterManager(t)
|
||||
am, _, err := createRouterManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
@@ -1071,7 +1073,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
AccessControlGroups: []string{routeGroup1},
|
||||
}
|
||||
|
||||
am, err := createRouterManager(t)
|
||||
am, _, err := createRouterManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
@@ -1163,7 +1165,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
AccessControlGroups: []string{routeGroup1},
|
||||
}
|
||||
|
||||
am, err := createRouterManager(t)
|
||||
am, _, err := createRouterManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
@@ -1250,11 +1252,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1")
|
||||
}
|
||||
|
||||
func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
|
||||
t.Helper()
|
||||
store, err := createRouterStore(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
@@ -1285,7 +1287,16 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
|
||||
am, err := BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return am, updateManager, nil
|
||||
}
|
||||
|
||||
func createRouterStore(t *testing.T) (store.Store, error) {
|
||||
@@ -1948,7 +1959,7 @@ func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFi
|
||||
}
|
||||
|
||||
func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
manager, err := createRouterManager(t)
|
||||
manager, updateManager, err := createRouterManager(t)
|
||||
require.NoError(t, err, "failed to create account manager")
|
||||
|
||||
account, err := initTestRouteAccount(t, manager)
|
||||
@@ -1976,9 +1987,9 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
require.NoError(t, err, "failed to create group %s", group.Name)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID)
|
||||
updateManager.CloseChannel(context.Background(), peer1ID)
|
||||
})
|
||||
|
||||
// Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update
|
||||
|
||||
@@ -5,6 +5,9 @@ package settings
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
|
||||
@@ -45,6 +48,11 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start))
|
||||
}()
|
||||
|
||||
if userID != activity.SystemInitiator {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
||||
if err != nil {
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
)
|
||||
|
||||
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -93,7 +93,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -198,7 +198,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetSetupKeys(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -396,7 +396,7 @@ func TestSetupKey_Copy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
@@ -420,9 +420,9 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
var setupKey *types.SetupKey
|
||||
@@ -465,7 +465,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1089
management/server/store/sql_store_get_account_test.go
Normal file
1089
management/server/store/sql_store_get_account_test.go
Normal file
File diff suppressed because it is too large
Load Diff
951
management/server/store/sqlstore_bench_test.go
Normal file
951
management/server/store/sqlstore_bench_test.go
Normal file
@@ -0,0 +1,951 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/testutil"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
elapsed := time.Since(start)
|
||||
if elapsed > 1*time.Second {
|
||||
log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
|
||||
}
|
||||
}()
|
||||
|
||||
var account types.Account
|
||||
result := s.db.Model(&account).
|
||||
Omit("GroupsG").
|
||||
Preload("UsersG.PATsG"). // have to be specified as this is nested reference
|
||||
Preload(clause.Associations).
|
||||
Take(&account, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
|
||||
for i, policy := range account.Policies {
|
||||
var rules []*types.PolicyRule
|
||||
err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "rule not found")
|
||||
}
|
||||
account.Policies[i].Rules = rules
|
||||
}
|
||||
|
||||
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
||||
for _, key := range account.SetupKeysG {
|
||||
account.SetupKeys[key.Key] = key.Copy()
|
||||
}
|
||||
account.SetupKeysG = nil
|
||||
|
||||
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
|
||||
for _, peer := range account.PeersG {
|
||||
account.Peers[peer.ID] = peer.Copy()
|
||||
}
|
||||
account.PeersG = nil
|
||||
|
||||
account.Users = make(map[string]*types.User, len(account.UsersG))
|
||||
for _, user := range account.UsersG {
|
||||
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs))
|
||||
for _, pat := range user.PATsG {
|
||||
user.PATs[pat.ID] = pat.Copy()
|
||||
}
|
||||
account.Users[user.Id] = user.Copy()
|
||||
}
|
||||
account.UsersG = nil
|
||||
|
||||
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
|
||||
for _, group := range account.GroupsG {
|
||||
account.Groups[group.ID] = group.Copy()
|
||||
}
|
||||
account.GroupsG = nil
|
||||
|
||||
var groupPeers []types.GroupPeer
|
||||
s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).
|
||||
Find(&groupPeers)
|
||||
for _, groupPeer := range groupPeers {
|
||||
if group, ok := account.Groups[groupPeer.GroupID]; ok {
|
||||
group.Peers = append(group.Peers, groupPeer.PeerID)
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
||||
for _, route := range account.RoutesG {
|
||||
account.Routes[route.ID] = route.Copy()
|
||||
}
|
||||
account.RoutesG = nil
|
||||
|
||||
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
|
||||
for _, ns := range account.NameServerGroupsG {
|
||||
account.NameServerGroups[ns.ID] = ns.Copy()
|
||||
}
|
||||
account.NameServerGroupsG = nil
|
||||
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*types.Account, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
elapsed := time.Since(start)
|
||||
if elapsed > 1*time.Second {
|
||||
log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
|
||||
}
|
||||
}()
|
||||
|
||||
var account types.Account
|
||||
result := s.db.Model(&account).
|
||||
Preload("UsersG.PATsG"). // have to be specified as this is nested reference
|
||||
Preload("Policies.Rules").
|
||||
Preload("SetupKeysG").
|
||||
Preload("PeersG").
|
||||
Preload("UsersG").
|
||||
Preload("GroupsG.GroupPeers").
|
||||
Preload("RoutesG").
|
||||
Preload("NameServerGroupsG").
|
||||
Preload("PostureChecks").
|
||||
Preload("Networks").
|
||||
Preload("NetworkRouters").
|
||||
Preload("NetworkResources").
|
||||
Preload("Onboarding").
|
||||
Take(&account, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
||||
for _, key := range account.SetupKeysG {
|
||||
if key.UpdatedAt.IsZero() {
|
||||
key.UpdatedAt = key.CreatedAt
|
||||
}
|
||||
if key.AutoGroups == nil {
|
||||
key.AutoGroups = []string{}
|
||||
}
|
||||
account.SetupKeys[key.Key] = &key
|
||||
}
|
||||
account.SetupKeysG = nil
|
||||
|
||||
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
|
||||
for _, peer := range account.PeersG {
|
||||
account.Peers[peer.ID] = &peer
|
||||
}
|
||||
account.PeersG = nil
|
||||
account.Users = make(map[string]*types.User, len(account.UsersG))
|
||||
for _, user := range account.UsersG {
|
||||
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs))
|
||||
for _, pat := range user.PATsG {
|
||||
pat.UserID = ""
|
||||
user.PATs[pat.ID] = &pat
|
||||
}
|
||||
if user.AutoGroups == nil {
|
||||
user.AutoGroups = []string{}
|
||||
}
|
||||
account.Users[user.Id] = &user
|
||||
user.PATsG = nil
|
||||
}
|
||||
account.UsersG = nil
|
||||
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
|
||||
for _, group := range account.GroupsG {
|
||||
group.Peers = make([]string, len(group.GroupPeers))
|
||||
for i, gp := range group.GroupPeers {
|
||||
group.Peers[i] = gp.PeerID
|
||||
}
|
||||
if group.Resources == nil {
|
||||
group.Resources = []types.Resource{}
|
||||
}
|
||||
account.Groups[group.ID] = group
|
||||
}
|
||||
account.GroupsG = nil
|
||||
|
||||
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
||||
for _, route := range account.RoutesG {
|
||||
account.Routes[route.ID] = &route
|
||||
}
|
||||
account.RoutesG = nil
|
||||
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
|
||||
for _, ns := range account.NameServerGroupsG {
|
||||
ns.AccountID = ""
|
||||
if ns.NameServers == nil {
|
||||
ns.NameServers = []nbdns.NameServer{}
|
||||
}
|
||||
if ns.Groups == nil {
|
||||
ns.Groups = []string{}
|
||||
}
|
||||
if ns.Domains == nil {
|
||||
ns.Domains = []string{}
|
||||
}
|
||||
account.NameServerGroups[ns.ID] = &ns
|
||||
}
|
||||
account.NameServerGroupsG = nil
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
|
||||
config, err := pgxpool.ParseConfig(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse database config: %w", err)
|
||||
}
|
||||
|
||||
config.MaxConns = 12
|
||||
config.MinConns = 2
|
||||
config.MaxConnLifetime = time.Hour
|
||||
config.HealthCheckPeriod = time.Minute
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("unable to ping database: %w", err)
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
|
||||
cleanup, dsn, err := testutil.CreatePostgresTestContainer()
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create test container: %v", err)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
b.Fatalf("failed to connect database: %v", err)
|
||||
}
|
||||
|
||||
pool, err := connectDBforTest(context.Background(), dsn)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to connect database: %v", err)
|
||||
}
|
||||
|
||||
models := []interface{}{
|
||||
&types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{},
|
||||
&types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
|
||||
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
|
||||
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
|
||||
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||
&types.AccountOnboarding{},
|
||||
}
|
||||
|
||||
for i := len(models) - 1; i >= 0; i-- {
|
||||
err := db.Migrator().DropTable(models[i])
|
||||
if err != nil {
|
||||
b.Fatalf("failed to drop table: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(models...)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to migrate database: %v", err)
|
||||
}
|
||||
|
||||
store := &SqlStore{
|
||||
db: db,
|
||||
pool: pool,
|
||||
}
|
||||
|
||||
const (
|
||||
accountID = "benchmark-account-id"
|
||||
numUsers = 20
|
||||
numPatsPerUser = 3
|
||||
numSetupKeys = 25
|
||||
numPeers = 200
|
||||
numGroups = 30
|
||||
numPolicies = 50
|
||||
numRulesPerPolicy = 10
|
||||
numRoutes = 40
|
||||
numNSGroups = 10
|
||||
numPostureChecks = 15
|
||||
numNetworks = 5
|
||||
numNetworkRouters = 5
|
||||
numNetworkResources = 10
|
||||
)
|
||||
|
||||
_, ipNet, _ := net.ParseCIDR("100.64.0.0/10")
|
||||
acc := types.Account{
|
||||
Id: accountID,
|
||||
CreatedBy: "benchmark-user",
|
||||
CreatedAt: time.Now(),
|
||||
Domain: "benchmark.com",
|
||||
IsDomainPrimaryAccount: true,
|
||||
Network: &types.Network{
|
||||
Identifier: "benchmark-net",
|
||||
Net: *ipNet,
|
||||
Serial: 1,
|
||||
},
|
||||
DNSSettings: types.DNSSettings{
|
||||
DisabledManagementGroups: []string{"group-disabled-1"},
|
||||
},
|
||||
Settings: &types.Settings{},
|
||||
}
|
||||
if err := db.Create(&acc).Error; err != nil {
|
||||
b.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
var setupKeys []types.SetupKey
|
||||
for i := 0; i < numSetupKeys; i++ {
|
||||
setupKeys = append(setupKeys, types.SetupKey{
|
||||
Id: fmt.Sprintf("keyid-%d", i),
|
||||
AccountID: accountID,
|
||||
Key: fmt.Sprintf("key-%d", i),
|
||||
Name: fmt.Sprintf("Benchmark Key %d", i),
|
||||
ExpiresAt: &time.Time{},
|
||||
})
|
||||
}
|
||||
if err := db.Create(&setupKeys).Error; err != nil {
|
||||
b.Fatalf("create setup keys: %v", err)
|
||||
}
|
||||
|
||||
var peers []nbpeer.Peer
|
||||
for i := 0; i < numPeers; i++ {
|
||||
peers = append(peers, nbpeer.Peer{
|
||||
ID: fmt.Sprintf("peer-%d", i),
|
||||
AccountID: accountID,
|
||||
Key: fmt.Sprintf("peerkey-%d", i),
|
||||
IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)),
|
||||
Name: fmt.Sprintf("peer-name-%d", i),
|
||||
Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()},
|
||||
})
|
||||
}
|
||||
if err := db.Create(&peers).Error; err != nil {
|
||||
b.Fatalf("create peers: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < numUsers; i++ {
|
||||
userID := fmt.Sprintf("user-%d", i)
|
||||
user := types.User{Id: userID, AccountID: accountID}
|
||||
if err := db.Create(&user).Error; err != nil {
|
||||
b.Fatalf("create user %s: %v", userID, err)
|
||||
}
|
||||
|
||||
var pats []types.PersonalAccessToken
|
||||
for j := 0; j < numPatsPerUser; j++ {
|
||||
pats = append(pats, types.PersonalAccessToken{
|
||||
ID: fmt.Sprintf("pat-%d-%d", i, j),
|
||||
UserID: userID,
|
||||
Name: fmt.Sprintf("PAT %d for User %d", j, i),
|
||||
})
|
||||
}
|
||||
if err := db.Create(&pats).Error; err != nil {
|
||||
b.Fatalf("create pats for user %s: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
var groups []*types.Group
|
||||
for i := 0; i < numGroups; i++ {
|
||||
groups = append(groups, &types.Group{
|
||||
ID: fmt.Sprintf("group-%d", i),
|
||||
AccountID: accountID,
|
||||
Name: fmt.Sprintf("Group %d", i),
|
||||
})
|
||||
}
|
||||
if err := db.Create(&groups).Error; err != nil {
|
||||
b.Fatalf("create groups: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < numPolicies; i++ {
|
||||
policyID := fmt.Sprintf("policy-%d", i)
|
||||
policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true}
|
||||
if err := db.Create(&policy).Error; err != nil {
|
||||
b.Fatalf("create policy %s: %v", policyID, err)
|
||||
}
|
||||
|
||||
var rules []*types.PolicyRule
|
||||
for j := 0; j < numRulesPerPolicy; j++ {
|
||||
rules = append(rules, &types.PolicyRule{
|
||||
ID: fmt.Sprintf("rule-%d-%d", i, j),
|
||||
PolicyID: policyID,
|
||||
Name: fmt.Sprintf("Rule %d for Policy %d", j, i),
|
||||
Enabled: true,
|
||||
Protocol: "all",
|
||||
})
|
||||
}
|
||||
if err := db.Create(&rules).Error; err != nil {
|
||||
b.Fatalf("create rules for policy %s: %v", policyID, err)
|
||||
}
|
||||
}
|
||||
|
||||
var routes []route.Route
|
||||
for i := 0; i < numRoutes; i++ {
|
||||
routes = append(routes, route.Route{
|
||||
ID: route.ID(fmt.Sprintf("route-%d", i)),
|
||||
AccountID: accountID,
|
||||
Description: fmt.Sprintf("Route %d", i),
|
||||
Network: netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)),
|
||||
Enabled: true,
|
||||
})
|
||||
}
|
||||
if err := db.Create(&routes).Error; err != nil {
|
||||
b.Fatalf("create routes: %v", err)
|
||||
}
|
||||
|
||||
var nsGroups []nbdns.NameServerGroup
|
||||
for i := 0; i < numNSGroups; i++ {
|
||||
nsGroups = append(nsGroups, nbdns.NameServerGroup{
|
||||
ID: fmt.Sprintf("nsg-%d", i),
|
||||
AccountID: accountID,
|
||||
Name: fmt.Sprintf("NS Group %d", i),
|
||||
Description: "Benchmark NS Group",
|
||||
Enabled: true,
|
||||
})
|
||||
}
|
||||
if err := db.Create(&nsGroups).Error; err != nil {
|
||||
b.Fatalf("create nsgroups: %v", err)
|
||||
}
|
||||
|
||||
var postureChecks []*posture.Checks
|
||||
for i := 0; i < numPostureChecks; i++ {
|
||||
postureChecks = append(postureChecks, &posture.Checks{
|
||||
ID: fmt.Sprintf("pc-%d", i),
|
||||
AccountID: accountID,
|
||||
Name: fmt.Sprintf("Posture Check %d", i),
|
||||
})
|
||||
}
|
||||
if err := db.Create(&postureChecks).Error; err != nil {
|
||||
b.Fatalf("create posture checks: %v", err)
|
||||
}
|
||||
|
||||
var networks []*networkTypes.Network
|
||||
for i := 0; i < numNetworks; i++ {
|
||||
networks = append(networks, &networkTypes.Network{
|
||||
ID: fmt.Sprintf("nettype-%d", i),
|
||||
AccountID: accountID,
|
||||
Name: fmt.Sprintf("Network Type %d", i),
|
||||
})
|
||||
}
|
||||
if err := db.Create(&networks).Error; err != nil {
|
||||
b.Fatalf("create networks: %v", err)
|
||||
}
|
||||
|
||||
var networkRouters []*routerTypes.NetworkRouter
|
||||
for i := 0; i < numNetworkRouters; i++ {
|
||||
networkRouters = append(networkRouters, &routerTypes.NetworkRouter{
|
||||
ID: fmt.Sprintf("router-%d", i),
|
||||
AccountID: accountID,
|
||||
NetworkID: networks[i%numNetworks].ID,
|
||||
Peer: peers[i%numPeers].ID,
|
||||
})
|
||||
}
|
||||
if err := db.Create(&networkRouters).Error; err != nil {
|
||||
b.Fatalf("create network routers: %v", err)
|
||||
}
|
||||
|
||||
var networkResources []*resourceTypes.NetworkResource
|
||||
for i := 0; i < numNetworkResources; i++ {
|
||||
networkResources = append(networkResources, &resourceTypes.NetworkResource{
|
||||
ID: fmt.Sprintf("resource-%d", i),
|
||||
AccountID: accountID,
|
||||
NetworkID: networks[i%numNetworks].ID,
|
||||
Name: fmt.Sprintf("Resource %d", i),
|
||||
})
|
||||
}
|
||||
if err := db.Create(&networkResources).Error; err != nil {
|
||||
b.Fatalf("create network resources: %v", err)
|
||||
}
|
||||
|
||||
onboarding := types.AccountOnboarding{
|
||||
AccountID: accountID,
|
||||
OnboardingFlowPending: true,
|
||||
}
|
||||
if err := db.Create(&onboarding).Error; err != nil {
|
||||
b.Fatalf("create onboarding: %v", err)
|
||||
}
|
||||
|
||||
return store, cleanup, accountID
|
||||
}
|
||||
|
||||
func BenchmarkGetAccount(b *testing.B) {
|
||||
store, cleanup, accountID := setupBenchmarkDB(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
b.Run("old", func(b *testing.B) {
|
||||
for range b.N {
|
||||
_, err := store.GetAccountSlow(ctx, accountID)
|
||||
if err != nil {
|
||||
b.Fatalf("GetAccountSlow failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
b.Run("gorm opt", func(b *testing.B) {
|
||||
for range b.N {
|
||||
_, err := store.GetAccountGormOpt(ctx, accountID)
|
||||
if err != nil {
|
||||
b.Fatalf("GetAccountFast failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
b.Run("raw", func(b *testing.B) {
|
||||
for range b.N {
|
||||
_, err := store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
b.Fatalf("GetAccountPureSQL failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
store.pool.Close()
|
||||
}
|
||||
|
||||
func TestAccountEquivalence(t *testing.T) {
|
||||
store, cleanup, accountID := setupBenchmarkDB(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
type getAccountFunc func(context.Context, string) (*types.Account, error)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
expectedF getAccountFunc
|
||||
actualF getAccountFunc
|
||||
}{
|
||||
{"old vs new", store.GetAccountSlow, store.GetAccountGormOpt},
|
||||
{"old vs raw", store.GetAccountSlow, store.GetAccount},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
expected, errOld := tt.expectedF(ctx, accountID)
|
||||
assert.NoError(t, errOld, "expected function should not return an error")
|
||||
assert.NotNil(t, expected, "expected should not be nil")
|
||||
|
||||
actual, errNew := tt.actualF(ctx, accountID)
|
||||
assert.NoError(t, errNew, "actual function should not return an error")
|
||||
assert.NotNil(t, actual, "actual should not be nil")
|
||||
testAccountEquivalence(t, expected, actual)
|
||||
})
|
||||
}
|
||||
|
||||
expected, errOld := store.GetAccountSlow(ctx, accountID)
|
||||
assert.NoError(t, errOld, "GetAccountSlow should not return an error")
|
||||
assert.NotNil(t, expected, "expected should not be nil")
|
||||
|
||||
actual, errNew := store.GetAccount(ctx, accountID)
|
||||
assert.NoError(t, errNew, "GetAccount (new) should not return an error")
|
||||
assert.NotNil(t, actual, "actual should not be nil")
|
||||
}
|
||||
|
||||
func testAccountEquivalence(t *testing.T, expected, actual *types.Account) {
|
||||
assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal")
|
||||
assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal")
|
||||
assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second")
|
||||
assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal")
|
||||
assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal")
|
||||
assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal")
|
||||
assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal")
|
||||
assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal")
|
||||
assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal")
|
||||
|
||||
assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements")
|
||||
for key, oldVal := range expected.SetupKeys {
|
||||
newVal, ok := actual.SetupKeys[key]
|
||||
assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key)
|
||||
assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key)
|
||||
}
|
||||
|
||||
assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have the same number of elements")
|
||||
for key, oldVal := range expected.Peers {
|
||||
newVal, ok := actual.Peers[key]
|
||||
assert.True(t, ok, "Peer with ID '%s' should exist in new account", key)
|
||||
assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key)
|
||||
}
|
||||
|
||||
assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements")
|
||||
for key, oldUser := range expected.Users {
|
||||
newUser, ok := actual.Users[key]
|
||||
assert.True(t, ok, "User with ID '%s' should exist in new account", key)
|
||||
|
||||
assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key)
|
||||
for patKey, oldPAT := range oldUser.PATs {
|
||||
newPAT, patOk := newUser.PATs[patKey]
|
||||
assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key)
|
||||
assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key)
|
||||
}
|
||||
|
||||
oldUser.PATs = nil
|
||||
newUser.PATs = nil
|
||||
assert.Equal(t, *oldUser, *newUser, "User struct for ID '%s' (without PATs) should be equal", key)
|
||||
}
|
||||
|
||||
assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements")
|
||||
for key, oldVal := range expected.Groups {
|
||||
newVal, ok := actual.Groups[key]
|
||||
assert.True(t, ok, "Group with ID '%s' should exist in new account", key)
|
||||
sort.Strings(oldVal.Peers)
|
||||
sort.Strings(newVal.Peers)
|
||||
assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key)
|
||||
}
|
||||
|
||||
assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements")
|
||||
for key, oldVal := range expected.Routes {
|
||||
newVal, ok := actual.Routes[key]
|
||||
assert.True(t, ok, "Route with ID '%s' should exist in new account", key)
|
||||
assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key)
|
||||
}
|
||||
|
||||
assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements")
|
||||
for key, oldVal := range expected.NameServerGroups {
|
||||
newVal, ok := actual.NameServerGroups[key]
|
||||
assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key)
|
||||
assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key)
|
||||
}
|
||||
|
||||
assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements")
|
||||
sort.Slice(expected.Policies, func(i, j int) bool { return expected.Policies[i].ID < expected.Policies[j].ID })
|
||||
sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID })
|
||||
for i := range expected.Policies {
|
||||
sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID })
|
||||
sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID })
|
||||
assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID)
|
||||
}
|
||||
|
||||
assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements")
|
||||
sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID })
|
||||
sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID })
|
||||
for i := range expected.PostureChecks {
|
||||
assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID)
|
||||
}
|
||||
|
||||
assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements")
|
||||
sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID })
|
||||
sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID })
|
||||
for i := range expected.Networks {
|
||||
assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID)
|
||||
}
|
||||
|
||||
assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements")
|
||||
sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID })
|
||||
sort.Slice(actual.NetworkRouters, func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID })
|
||||
for i := range expected.NetworkRouters {
|
||||
assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID)
|
||||
}
|
||||
|
||||
assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements")
|
||||
sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID })
|
||||
sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID })
|
||||
for i := range expected.NetworkResources {
|
||||
assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) {
|
||||
account, err := s.getAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 12)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
keys, err := s.getSetupKeys(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.SetupKeysG = keys
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
peers, err := s.getPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.PeersG = peers
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
users, err := s.getUsers(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.UsersG = users
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
groups, err := s.getGroups(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.GroupsG = groups
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
policies, err := s.getPolicies(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.Policies = policies
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
routes, err := s.getRoutes(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.RoutesG = routes
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
nsgs, err := s.getNameServerGroups(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.NameServerGroupsG = nsgs
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
checks, err := s.getPostureChecks(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.PostureChecks = checks
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
networks, err := s.getNetworks(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.Networks = networks
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
routers, err := s.getNetworkRouters(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.NetworkRouters = routers
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
resources, err := s.getNetworkResources(ctx, accountID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
account.NetworkResources = resources
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := s.getAccountOnboarding(ctx, accountID, account)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
for e := range errChan {
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
}
|
||||
|
||||
var userIDs []string
|
||||
for _, u := range account.UsersG {
|
||||
userIDs = append(userIDs, u.Id)
|
||||
}
|
||||
var policyIDs []string
|
||||
for _, p := range account.Policies {
|
||||
policyIDs = append(policyIDs, p.ID)
|
||||
}
|
||||
var groupIDs []string
|
||||
for _, g := range account.GroupsG {
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
|
||||
wg.Add(3)
|
||||
errChan = make(chan error, 3)
|
||||
|
||||
var pats []types.PersonalAccessToken
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
var err error
|
||||
pats, err = s.getPersonalAccessTokens(ctx, userIDs)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
var rules []*types.PolicyRule
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
var err error
|
||||
rules, err = s.getPolicyRules(ctx, policyIDs)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
var groupPeers []types.GroupPeer
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
var err error
|
||||
groupPeers, err = s.getGroupPeers(ctx, groupIDs)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
for e := range errChan {
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
}
|
||||
|
||||
patsByUserID := make(map[string][]*types.PersonalAccessToken)
|
||||
for i := range pats {
|
||||
pat := &pats[i]
|
||||
patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat)
|
||||
pat.UserID = ""
|
||||
}
|
||||
|
||||
rulesByPolicyID := make(map[string][]*types.PolicyRule)
|
||||
for _, rule := range rules {
|
||||
rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule)
|
||||
}
|
||||
|
||||
peersByGroupID := make(map[string][]string)
|
||||
for _, gp := range groupPeers {
|
||||
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
|
||||
}
|
||||
|
||||
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
||||
for i := range account.SetupKeysG {
|
||||
key := &account.SetupKeysG[i]
|
||||
account.SetupKeys[key.Key] = key
|
||||
}
|
||||
|
||||
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
|
||||
for i := range account.PeersG {
|
||||
peer := &account.PeersG[i]
|
||||
account.Peers[peer.ID] = peer
|
||||
}
|
||||
|
||||
account.Users = make(map[string]*types.User, len(account.UsersG))
|
||||
for i := range account.UsersG {
|
||||
user := &account.UsersG[i]
|
||||
user.PATs = make(map[string]*types.PersonalAccessToken)
|
||||
if userPats, ok := patsByUserID[user.Id]; ok {
|
||||
for j := range userPats {
|
||||
pat := userPats[j]
|
||||
user.PATs[pat.ID] = pat
|
||||
}
|
||||
}
|
||||
account.Users[user.Id] = user
|
||||
}
|
||||
|
||||
for i := range account.Policies {
|
||||
policy := account.Policies[i]
|
||||
if policyRules, ok := rulesByPolicyID[policy.ID]; ok {
|
||||
policy.Rules = policyRules
|
||||
}
|
||||
}
|
||||
|
||||
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
|
||||
for i := range account.GroupsG {
|
||||
group := account.GroupsG[i]
|
||||
if peerIDs, ok := peersByGroupID[group.ID]; ok {
|
||||
group.Peers = peerIDs
|
||||
}
|
||||
account.Groups[group.ID] = group
|
||||
}
|
||||
|
||||
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
||||
for i := range account.RoutesG {
|
||||
route := &account.RoutesG[i]
|
||||
account.Routes[route.ID] = route
|
||||
}
|
||||
|
||||
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
|
||||
for i := range account.NameServerGroupsG {
|
||||
nsg := &account.NameServerGroupsG[i]
|
||||
nsg.AccountID = ""
|
||||
account.NameServerGroups[nsg.ID] = nsg
|
||||
}
|
||||
|
||||
account.SetupKeysG = nil
|
||||
account.PeersG = nil
|
||||
account.UsersG = nil
|
||||
account.GroupsG = nil
|
||||
account.RoutesG = nil
|
||||
account.NameServerGroupsG = nil
|
||||
|
||||
return account, nil
|
||||
}
|
||||
@@ -468,6 +468,9 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine)
|
||||
closeConnection := func() {
|
||||
cleanup()
|
||||
store.Close(ctx)
|
||||
if store.pool != nil {
|
||||
store.pool.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return store, closeConnection, nil
|
||||
@@ -487,12 +490,18 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng
|
||||
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
db, err := openDBWithRetry(dsn, kind, 5)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err)
|
||||
}
|
||||
|
||||
dsn, cleanup, err := createRandomDB(dsn, db, kind)
|
||||
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -519,12 +528,22 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine
|
||||
return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
|
||||
db, err := openDBWithRetry(dsn, kind, 5)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to get underlying sql.DB: %v", err)
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
sqlDB.SetMaxIdleConns(1)
|
||||
|
||||
dsn, cleanup, err := createRandomDB(dsn, db, kind)
|
||||
|
||||
sqlDB.Close()
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -537,6 +556,31 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine
|
||||
return store, cleanup, nil
|
||||
}
|
||||
|
||||
func openDBWithRetry(dsn string, engine types.Engine, maxRetries int) (*gorm.DB, error) {
|
||||
var db *gorm.DB
|
||||
var err error
|
||||
|
||||
for i := range maxRetries {
|
||||
switch engine {
|
||||
case types.PostgresStoreEngine:
|
||||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
case types.MysqlStoreEngine:
|
||||
db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
if i < maxRetries-1 {
|
||||
waitTime := time.Duration(100*(i+1)) * time.Millisecond
|
||||
time.Sleep(waitTime)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) {
|
||||
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
|
||||
|
||||
@@ -544,21 +588,63 @@ func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(
|
||||
return "", nil, fmt.Errorf("failed to create database: %v", err)
|
||||
}
|
||||
|
||||
var err error
|
||||
originalDSN := dsn
|
||||
|
||||
cleanup := func() {
|
||||
var dropDB *gorm.DB
|
||||
var err error
|
||||
|
||||
switch engine {
|
||||
case types.PostgresStoreEngine:
|
||||
err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error
|
||||
dropDB, err = gorm.Open(postgres.Open(originalDSN), &gorm.Config{
|
||||
SkipDefaultTransaction: true,
|
||||
PrepareStmt: false,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("failed to connect for dropping database %s: %v", dbName, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
sqlDB.SetMaxIdleConns(0)
|
||||
sqlDB.SetConnMaxLifetime(time.Second)
|
||||
}
|
||||
|
||||
err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)).Error
|
||||
|
||||
case types.MysqlStoreEngine:
|
||||
// err = killMySQLConnections(dsn, dbName)
|
||||
err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
|
||||
dropDB, err = gorm.Open(mysql.Open(originalDSN+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{
|
||||
SkipDefaultTransaction: true,
|
||||
PrepareStmt: false,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("failed to connect for dropping database %s: %v", dbName, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
sqlDB.SetMaxIdleConns(0)
|
||||
sqlDB.SetConnMaxLifetime(time.Second)
|
||||
}
|
||||
|
||||
err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)).Error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to drop database %s: %v", dbName, err)
|
||||
panic(err)
|
||||
}
|
||||
sqlDB, _ := db.DB()
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
|
||||
return replaceDBName(dsn, dbName), cleanup, nil
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -87,6 +88,13 @@ type Account struct {
|
||||
NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"`
|
||||
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
|
||||
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
|
||||
|
||||
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
|
||||
nmapInitOnce *sync.Once `gorm:"-"`
|
||||
}
|
||||
|
||||
func (a *Account) InitOnce() {
|
||||
a.nmapInitOnce = &sync.Once{}
|
||||
}
|
||||
|
||||
// this class is used by gorm only
|
||||
@@ -257,7 +265,6 @@ func (a *Account) GetPeerNetworkMap(
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
|
||||
peer := a.Peers[peerID]
|
||||
if peer == nil {
|
||||
return &NetworkMap{
|
||||
@@ -890,6 +897,8 @@ func (a *Account) Copy() *Account {
|
||||
NetworkRouters: networkRouters,
|
||||
NetworkResources: networkResources,
|
||||
Onboarding: a.Onboarding,
|
||||
NetworkMapCache: a.NetworkMapCache,
|
||||
nmapInitOnce: a.nmapInitOnce,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1049,14 +1058,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
||||
rules := make([]*FirewallRule, 0)
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
|
||||
all, err := a.GetGroupAll()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
|
||||
all = &Group{}
|
||||
}
|
||||
|
||||
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
|
||||
isAll := (len(all.Peers) - 1) == len(groupPeers)
|
||||
for _, peer := range groupPeers {
|
||||
if peer == nil {
|
||||
continue
|
||||
@@ -1075,10 +1077,6 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
||||
Protocol: string(rule.Protocol),
|
||||
}
|
||||
|
||||
if isAll {
|
||||
fr.PeerIP = "0.0.0.0"
|
||||
}
|
||||
|
||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
|
||||
if _, ok := rulesExists[ruleID]; ok {
|
||||
|
||||
43
management/server/types/holder.go
Normal file
43
management/server/types/holder.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Holder struct {
|
||||
mu sync.RWMutex
|
||||
accounts map[string]*Account
|
||||
}
|
||||
|
||||
func NewHolder() *Holder {
|
||||
return &Holder{
|
||||
accounts: make(map[string]*Account),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Holder) GetAccount(id string) *Account {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.accounts[id]
|
||||
}
|
||||
|
||||
func (h *Holder) AddAccount(account *Account) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.accounts[account.Id] = account
|
||||
}
|
||||
|
||||
func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if acc, ok := h.accounts[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
account, err := accGetter(context.Background(), id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.accounts[id] = account
|
||||
return account, nil
|
||||
}
|
||||
58
management/server/types/networkmap.go
Normal file
58
management/server/types/networkmap.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) {
|
||||
if a.NetworkMapCache != nil {
|
||||
return
|
||||
}
|
||||
a.nmapInitOnce.Do(func() {
|
||||
a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
validatedPeers map[string]struct{},
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
|
||||
if a.NetworkMapCache == nil {
|
||||
return nil
|
||||
}
|
||||
return a.NetworkMapCache.OnPeerAddedIncremental(peerId)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
|
||||
if a.NetworkMapCache == nil {
|
||||
return nil
|
||||
}
|
||||
return a.NetworkMapCache.OnPeerDeleted(peerId)
|
||||
}
|
||||
|
||||
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
|
||||
if a.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
a.NetworkMapCache.UpdatePeer(peer)
|
||||
}
|
||||
|
||||
func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
}
|
||||
1069
management/server/types/networkmap_golden_test.go
Normal file
1069
management/server/types/networkmap_golden_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1927
management/server/types/networkmapbuilder.go
Normal file
1927
management/server/types/networkmapbuilder.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -595,7 +595,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
|
||||
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() {
|
||||
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
if oldUser.IsBlocked() != newUser.IsBlocked() {
|
||||
@@ -621,6 +621,35 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac
|
||||
})
|
||||
}
|
||||
|
||||
addedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, addedGroupIDs)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get added groups for user %s update event: %v", oldUser.Id, err)
|
||||
}
|
||||
|
||||
for _, group := range addedGroups {
|
||||
meta := map[string]any{
|
||||
"group": group.Name, "group_id": group.ID,
|
||||
"is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName,
|
||||
}
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupAddedToUser, meta)
|
||||
})
|
||||
}
|
||||
|
||||
removedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, removedGroupIDs)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get removed groups for user %s update event: %v", oldUser.Id, err)
|
||||
}
|
||||
for _, group := range removedGroups {
|
||||
meta := map[string]any{
|
||||
"group": group.Name, "group_id": group.ID,
|
||||
"is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName,
|
||||
}
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta)
|
||||
})
|
||||
}
|
||||
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
@@ -667,9 +696,10 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
peersToExpire = userPeers
|
||||
}
|
||||
|
||||
var removedGroups, addedGroups []string
|
||||
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
|
||||
removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
addedGroups := util.Difference(update.AutoGroups, oldUser.AutoGroups)
|
||||
removedGroups = util.Difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
addedGroups = util.Difference(update.AutoGroups, oldUser.AutoGroups)
|
||||
for _, peer := range userPeers {
|
||||
for _, groupID := range removedGroups {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil {
|
||||
@@ -685,7 +715,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
}
|
||||
|
||||
updateAccountPeers := len(userPeers) > 0
|
||||
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole)
|
||||
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, removedGroups, addedGroups, transaction)
|
||||
|
||||
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
|
||||
}
|
||||
@@ -935,7 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dnsDomain := am.GetDNSDomain(settings)
|
||||
dnsDomain := am.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
var peerIDs []string
|
||||
for _, peer := range peers {
|
||||
@@ -961,13 +991,14 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
peer.UserID, peer.ID, accountID,
|
||||
activity.PeerLoginExpired, peer.EventMeta(dnsDomain),
|
||||
)
|
||||
|
||||
am.networkMapController.OnPeerUpdated(accountID, peer)
|
||||
}
|
||||
|
||||
if len(peerIDs) != 0 {
|
||||
// this will trigger peer disconnect from the management service
|
||||
log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID)
|
||||
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
am.networkMapController.DisconnectPeers(ctx, peerIDs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1081,6 +1112,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
||||
|
||||
var addPeerRemovedEvents []func()
|
||||
var updateAccountPeers bool
|
||||
var userPeers []*nbpeer.Peer
|
||||
var targetUser *types.User
|
||||
var err error
|
||||
|
||||
@@ -1090,7 +1122,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
||||
return fmt.Errorf("failed to get user to delete: %w", err)
|
||||
}
|
||||
|
||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
|
||||
userPeers, err = transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user peers: %w", err)
|
||||
}
|
||||
@@ -1113,6 +1145,17 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, peer := range userPeers {
|
||||
err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err)
|
||||
}
|
||||
|
||||
if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peer.ID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, addPeerRemovedEvent := range addPeerRemovedEvents {
|
||||
addPeerRemovedEvent()
|
||||
}
|
||||
|
||||
@@ -1161,7 +1161,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -1333,7 +1333,7 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
|
||||
func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
// account groups propagation is enabled
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
@@ -1357,9 +1357,9 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Creating a new regular user should not update account peers and not send peer update
|
||||
@@ -1468,9 +1468,9 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID)
|
||||
peer4UpdMsg := updateManager.CreateChannel(context.Background(), peer4.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID)
|
||||
updateManager.CloseChannel(context.Background(), peer4.ID)
|
||||
})
|
||||
|
||||
// deleting user with linked peers should update account peers and send peer update
|
||||
@@ -1748,7 +1748,7 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
|
||||
}
|
||||
|
||||
func TestApproveUser(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1807,7 +1807,7 @@ func TestApproveUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRejectUser(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user