mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
Compare commits
32 Commits
feature/us
...
use-pre-ex
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d6c5f5ead8 | ||
|
|
afbc0e65d7 | ||
|
|
31dd04e835 | ||
|
|
b6a8b1dbcd | ||
|
|
a29862182a | ||
|
|
ec4469f43d | ||
|
|
bcce1bf184 | ||
|
|
ac0d5ff9f3 | ||
|
|
54d896846b | ||
|
|
855fba8fac | ||
|
|
1802e51213 | ||
|
|
d56dfae9b8 | ||
|
|
6b930271fd | ||
|
|
059fc7c3a2 | ||
|
|
0371f529ca | ||
|
|
501fd93e47 | ||
|
|
727a4f0753 | ||
|
|
e6f7222034 | ||
|
|
bfc33a3f6f | ||
|
|
5ad4ae769a | ||
|
|
f84b606506 | ||
|
|
216d9f2ee8 | ||
|
|
57624203c9 | ||
|
|
24e031ab74 | ||
|
|
df8b8db068 | ||
|
|
3506ac4234 | ||
|
|
0c8f8a62c7 | ||
|
|
cbf9f2058e | ||
|
|
02f3105e48 | ||
|
|
5ee9c77e90 | ||
|
|
c832cef44c | ||
|
|
165988429c |
8
.editorconfig
Normal file
8
.editorconfig
Normal file
@@ -0,0 +1,8 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
end_of_line = lf
|
||||
insert_final_newline = true
|
||||
|
||||
[*.go]
|
||||
indent_style = tab
|
||||
17
.github/workflows/golang-test-freebsd.yml
vendored
17
.github/workflows/golang-test-freebsd.yml
vendored
@@ -30,6 +30,17 @@ jobs:
|
||||
# -e - to faile on first error
|
||||
run: |
|
||||
set -e -x
|
||||
go build -o netbird client/main.go
|
||||
go test -timeout 5m -p 1 -failfast ./iface/...
|
||||
go test -timeout 5m -p 1 -failfast ./client/...
|
||||
time go build -o netbird client/main.go
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
|
||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
||||
time go test -timeout 1m -failfast ./dns/...
|
||||
time go test -timeout 1m -failfast ./encryption/...
|
||||
time go test -timeout 1m -failfast ./formatter/...
|
||||
time go test -timeout 1m -failfast ./iface/...
|
||||
time go test -timeout 1m -failfast ./route/...
|
||||
time go test -timeout 1m -failfast ./sharedsock/...
|
||||
time go test -timeout 1m -failfast ./signal/...
|
||||
time go test -timeout 1m -failfast ./util/...
|
||||
time go test -timeout 1m -failfast ./version/...
|
||||
|
||||
11
.github/workflows/release.yml
vendored
11
.github/workflows/release.yml
vendored
@@ -79,15 +79,8 @@ jobs:
|
||||
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso 386
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_386.syso
|
||||
- name: Generate windows syso arm
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm.syso
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm64.syso
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_amd64.syso
|
||||
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
@@ -170,7 +163,7 @@ jobs:
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/ui/resources_windows_amd64.syso
|
||||
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
|
||||
@@ -151,10 +151,10 @@ jobs:
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
docker-compose up -d
|
||||
docker compose up -d
|
||||
sleep 5
|
||||
docker-compose ps
|
||||
docker-compose logs --tail=20
|
||||
docker compose ps
|
||||
docker compose logs --tail=20
|
||||
|
||||
- name: test running containers
|
||||
run: |
|
||||
@@ -207,7 +207,7 @@ jobs:
|
||||
|
||||
- name: Postgres run cleanup
|
||||
run: |
|
||||
docker-compose down --volumes --rmi all
|
||||
docker compose down --volumes --rmi all
|
||||
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
|
||||
|
||||
- name: run script with Zitadel CockroachDB
|
||||
|
||||
@@ -11,8 +11,6 @@ builds:
|
||||
- amd64
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
tags:
|
||||
- legacy_appindicator
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
- id: netbird-ui-windows
|
||||
|
||||
@@ -10,10 +10,12 @@
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&utm_medium=referral&utm_content=netbirdio/netbird&utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
|
||||
@@ -178,6 +178,21 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
||||
})
|
||||
}
|
||||
|
||||
// AnonymizeRoute anonymizes a route string by replacing IP addresses with anonymized versions and
|
||||
// domain names with random strings.
|
||||
func (a *Anonymizer) AnonymizeRoute(route string) string {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
domains := strings.Split(route, ", ")
|
||||
for i, domain := range domains {
|
||||
domains[i] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
return strings.Join(domains, ", ")
|
||||
}
|
||||
|
||||
func isWellKnown(addr netip.Addr) bool {
|
||||
wellKnown := []string{
|
||||
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
@@ -13,6 +14,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
)
|
||||
|
||||
const errCloseConnection = "Failed to close connection: %v"
|
||||
|
||||
var debugCmd = &cobra.Command{
|
||||
Use: "debug",
|
||||
Short: "Debugging commands",
|
||||
@@ -63,12 +66,17 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd),
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd),
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
@@ -84,7 +92,11 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
level := server.ParseLogLevel(args[0])
|
||||
@@ -113,7 +125,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
@@ -122,17 +138,20 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
|
||||
stateWasDown := stat.Status != string(internal.StatusConnected) && stat.Status != string(internal.StatusConnecting)
|
||||
|
||||
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
if stateWasDown {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird up")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
||||
if !initialLevelTrace {
|
||||
@@ -145,6 +164,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level set to trace.")
|
||||
}
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
@@ -162,21 +186,25 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
cmd.Println("\nDuration completed")
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if restoreUp {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird up")
|
||||
cmd.Println("Netbird down")
|
||||
}
|
||||
|
||||
if !initialLevelTrace {
|
||||
@@ -186,16 +214,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println(resp.GetPath())
|
||||
|
||||
return nil
|
||||
|
||||
@@ -37,6 +37,7 @@ const (
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -69,6 +70,7 @@ var (
|
||||
autoConnectDisabled bool
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
debugSystemInfoFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
@@ -91,12 +93,15 @@ func init() {
|
||||
oldDefaultConfigPathDir = "/etc/wiretrustee/"
|
||||
oldDefaultLogFileDir = "/var/log/wiretrustee/"
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
|
||||
defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
|
||||
|
||||
oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
|
||||
oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
|
||||
case "freebsd":
|
||||
defaultConfigPathDir = "/var/db/netbird/"
|
||||
}
|
||||
|
||||
defaultConfigPath = defaultConfigPathDir + "config.json"
|
||||
@@ -165,6 +170,8 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
|
||||
@@ -807,7 +807,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = anonymizeRoute(a, route)
|
||||
peer.Routes[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -843,21 +843,8 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
|
||||
}
|
||||
|
||||
for i, route := range overview.Routes {
|
||||
overview.Routes[i] = anonymizeRoute(a, route)
|
||||
overview.Routes[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
}
|
||||
|
||||
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
domains := strings.Split(route, ", ")
|
||||
for i, domain := range domains {
|
||||
domains[i] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
return strings.Join(domains, ", ")
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -71,6 +72,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
|
||||
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -88,7 +90,11 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
return nil, nil
|
||||
}
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -69,6 +69,11 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
|
||||
if runtime.GOOS == "freebsd" {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||
if err != nil {
|
||||
// fallback to device code flow
|
||||
|
||||
@@ -94,7 +94,7 @@ func NewDefaultServer(
|
||||
|
||||
var dnsService service
|
||||
if wgInterface.IsUserspaceBind() {
|
||||
dnsService = newServiceViaMemory(wgInterface)
|
||||
dnsService = NewServiceViaMemory(wgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
}
|
||||
@@ -112,7 +112,7 @@ func NewDefaultServerPermanentUpstream(
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.addHostRootZone()
|
||||
@@ -130,7 +130,7 @@ func NewDefaultServerIos(
|
||||
iosDnsManager IosDnsManager,
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
return ds
|
||||
}
|
||||
|
||||
@@ -534,7 +534,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
service: newServiceViaMemory(&mocWGIface{}),
|
||||
service: NewServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type serviceViaMemory struct {
|
||||
type ServiceViaMemory struct {
|
||||
wgInterface WGIface
|
||||
dnsMux *dns.ServeMux
|
||||
runtimeIP string
|
||||
@@ -22,8 +22,8 @@ type serviceViaMemory struct {
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
|
||||
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
||||
s := &serviceViaMemory{
|
||||
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
s := &ServiceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
@@ -33,7 +33,7 @@ func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) Listen() error {
|
||||
func (s *ServiceViaMemory) Listen() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
@@ -52,7 +52,7 @@ func (s *serviceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) Stop() {
|
||||
func (s *ServiceViaMemory) Stop() {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
@@ -67,23 +67,23 @@ func (s *serviceViaMemory) Stop() {
|
||||
s.listenerIsRunning = false
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) DeregisterMux(pattern string) {
|
||||
func (s *ServiceViaMemory) DeregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RuntimePort() int {
|
||||
func (s *ServiceViaMemory) RuntimePort() int {
|
||||
return s.runtimePort
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RuntimeIP() string {
|
||||
func (s *ServiceViaMemory) RuntimeIP() string {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||
|
||||
@@ -4,6 +4,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -17,9 +18,9 @@ import (
|
||||
|
||||
type upstreamResolverIOS struct {
|
||||
*upstreamResolverBase
|
||||
lIP net.IP
|
||||
lNet *net.IPNet
|
||||
iIndex int
|
||||
lIP net.IP
|
||||
lNet *net.IPNet
|
||||
interfaceName string
|
||||
}
|
||||
|
||||
func newUpstreamResolver(
|
||||
@@ -32,17 +33,11 @@ func newUpstreamResolver(
|
||||
) (*upstreamResolverIOS, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
if err != nil {
|
||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ios := &upstreamResolverIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
lIP: ip,
|
||||
lNet: net,
|
||||
iIndex: index,
|
||||
interfaceName: interfaceName,
|
||||
}
|
||||
ios.upstreamClient = ios
|
||||
|
||||
@@ -53,7 +48,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
client := &dns.Client{}
|
||||
upstreamHost, _, err := net.SplitHostPort(upstream)
|
||||
if err != nil {
|
||||
log.Errorf("error while parsing upstream host: %s", err)
|
||||
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
||||
}
|
||||
|
||||
timeout := upstreamTimeout
|
||||
@@ -65,26 +60,35 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
upstreamIP := net.ParseIP(upstreamHost)
|
||||
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
|
||||
log.Debugf("using private client to query upstream: %s", upstream)
|
||||
client = u.getClientPrivate(timeout)
|
||||
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||
return client.Exchange(r, upstream)
|
||||
}
|
||||
|
||||
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
// This method is needed for iOS
|
||||
func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client {
|
||||
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
if err != nil {
|
||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
LocalAddr: &net.UDPAddr{
|
||||
IP: u.lIP,
|
||||
IP: ip,
|
||||
Port: 0, // Let the OS pick a free port
|
||||
},
|
||||
Timeout: dialTimeout,
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
var operr error
|
||||
fn := func(s uintptr) {
|
||||
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
|
||||
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
|
||||
}
|
||||
|
||||
if err := c.Control(fn); err != nil {
|
||||
@@ -101,7 +105,7 @@ func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.C
|
||||
client := &dns.Client{
|
||||
Dialer: dialer,
|
||||
}
|
||||
return client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func getInterfaceIndex(interfaceName string) (int, error) {
|
||||
|
||||
@@ -36,6 +36,7 @@ import (
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
@@ -1069,7 +1070,11 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
@@ -65,7 +66,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
|
||||
routePeersNotifiers: make(map[string]chan struct{}),
|
||||
routeUpdate: make(chan routesUpdate),
|
||||
peerStateUpdate: make(chan struct{}),
|
||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
|
||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
|
||||
}
|
||||
return client
|
||||
}
|
||||
@@ -383,9 +384,10 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
}
|
||||
}
|
||||
|
||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
|
||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
|
||||
if rt.IsDynamic() {
|
||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
|
||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
|
||||
}
|
||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -47,6 +48,8 @@ type Route struct {
|
||||
currentPeerKey string
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
resolverAddr string
|
||||
}
|
||||
|
||||
func NewRoute(
|
||||
@@ -55,6 +58,8 @@ func NewRoute(
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
interval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface *iface.WGIface,
|
||||
resolverAddr string,
|
||||
) *Route {
|
||||
return &Route{
|
||||
route: rt,
|
||||
@@ -63,6 +68,8 @@ func NewRoute(
|
||||
interval: interval,
|
||||
dynamicDomains: domainMap{},
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
resolverAddr: resolverAddr,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,9 +196,14 @@ func (r *Route) startResolver(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (r *Route) update(ctx context.Context) error {
|
||||
if resolved, err := r.resolveDomains(); err != nil {
|
||||
return fmt.Errorf("resolve domains: %w", err)
|
||||
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
|
||||
resolved, err := r.resolveDomains()
|
||||
if err != nil {
|
||||
if len(resolved) == 0 {
|
||||
return fmt.Errorf("resolve domains: %w", err)
|
||||
}
|
||||
log.Warnf("Failed to resolve domains: %v", err)
|
||||
}
|
||||
if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
|
||||
return fmt.Errorf("update dynamic routes: %w", err)
|
||||
}
|
||||
|
||||
@@ -223,11 +235,17 @@ func (r *Route) resolve(results chan resolveResult) {
|
||||
wg.Add(1)
|
||||
go func(domain domain.Domain) {
|
||||
defer wg.Done()
|
||||
ips, err := net.LookupIP(string(domain))
|
||||
|
||||
ips, err := r.getIPsFromResolver(domain)
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
||||
return
|
||||
log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
|
||||
ips, err = net.LookupIP(string(domain))
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
|
||||
13
client/internal/routemanager/dynamic/route_generic.go
Normal file
13
client/internal/routemanager/dynamic/route_generic.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !ios
|
||||
|
||||
package dynamic
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||
return net.LookupIP(string(domain))
|
||||
}
|
||||
55
client/internal/routemanager/dynamic/route_ios.go
Normal file
55
client/internal/routemanager/dynamic/route_ios.go
Normal file
@@ -0,0 +1,55 @@
|
||||
//go:build ios
|
||||
|
||||
package dynamic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
const dialTimeout = 10 * time.Second
|
||||
|
||||
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||
privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while creating private client: %s", err)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA)
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
response, _, err := privateClient.Exchange(msg, r.resolverAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
|
||||
}
|
||||
|
||||
if response.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode])
|
||||
}
|
||||
|
||||
ips := make([]net.IP, 0)
|
||||
|
||||
for _, answ := range response.Answer {
|
||||
if aRecord, ok := answ.(*dns.A); ok {
|
||||
ips = append(ips, aRecord.A)
|
||||
}
|
||||
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
|
||||
ips = append(ips, aaaaRecord.AAAA)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
@@ -22,7 +22,7 @@ type Route struct {
|
||||
Interface *net.Interface
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||
tab, err := retryFetchRIB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch RIB: %v", err)
|
||||
|
||||
@@ -427,7 +427,7 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||
}
|
||||
|
||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
routes, err := GetRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
@@ -440,7 +440,7 @@ func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
}
|
||||
|
||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
routes, err := GetRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
|
||||
@@ -206,7 +206,7 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get v4 routes: %w", err)
|
||||
@@ -504,7 +504,7 @@ func getAddressFamily(prefix netip.Prefix) int {
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
if isLegacy() {
|
||||
return getRoutesFromTable()
|
||||
return GetRoutesFromTable()
|
||||
}
|
||||
return nil, ErrRoutingIsSeparate
|
||||
}
|
||||
|
||||
@@ -24,5 +24,5 @@ func EnableIPForwarding() error {
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
return getRoutesFromTable()
|
||||
return GetRoutesFromTable()
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
|
||||
|
||||
@@ -271,7 +271,14 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
}
|
||||
|
||||
routesMap := engine.GetClientRoutesWithNetID()
|
||||
routeSelector := engine.GetRouteManager().GetRouteSelector()
|
||||
routeManager := engine.GetRouteManager()
|
||||
if routeManager == nil {
|
||||
return nil, fmt.Errorf("could not get route manager")
|
||||
}
|
||||
routeSelector := routeManager.GetRouteSelector()
|
||||
if routeSelector == nil {
|
||||
return nil, fmt.Errorf("could not get route selector")
|
||||
}
|
||||
|
||||
var routes []*selectRoute
|
||||
for id, rt := range routesMap {
|
||||
|
||||
@@ -1828,8 +1828,9 @@ type DebugBundleRequest struct {
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
|
||||
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
|
||||
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
|
||||
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
|
||||
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
|
||||
}
|
||||
|
||||
func (x *DebugBundleRequest) Reset() {
|
||||
@@ -1878,6 +1879,13 @@ func (x *DebugBundleRequest) GetStatus() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *DebugBundleRequest) GetSystemInfo() bool {
|
||||
if x != nil {
|
||||
return x.SystemInfo
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type DebugBundleResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
@@ -2370,11 +2378,13 @@ var file_daemon_proto_rawDesc = []byte{
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c,
|
||||
0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
||||
0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a,
|
||||
0x02, 0x38, 0x01, 0x22, 0x4a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
|
||||
0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
|
||||
0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f,
|
||||
0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e,
|
||||
0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75,
|
||||
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22,
|
||||
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12,
|
||||
0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20,
|
||||
0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22,
|
||||
0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
|
||||
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65,
|
||||
|
||||
@@ -263,6 +263,7 @@ message Route {
|
||||
message DebugBundleRequest {
|
||||
bool anonymize = 1;
|
||||
string status = 2;
|
||||
bool systemInfo = 3;
|
||||
}
|
||||
|
||||
message DebugBundleResponse {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
@@ -6,16 +8,70 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const readmeContent = `Netbird debug bundle
|
||||
This debug bundle contains the following files:
|
||||
|
||||
status.txt: Anonymized status information of the NetBird client.
|
||||
client.log: Most recent, anonymized log file of the NetBird client.
|
||||
routes.txt: Anonymized system routes, if --system-info flag was provided.
|
||||
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
|
||||
|
||||
Anonymization Process
|
||||
The files in this bundle have been anonymized to protect sensitive information. Here's how the anonymization was applied:
|
||||
|
||||
IP Addresses
|
||||
|
||||
IPv4 addresses are replaced with addresses starting from 192.51.100.0
|
||||
IPv6 addresses are replaced with addresses starting from 100::
|
||||
|
||||
IP addresses from non public ranges and well known addresses are not anonymized (e.g. 8.8.8.8, 100.64.0.0/10, addresses starting with 192.168., 172.16., 10., etc.).
|
||||
Reoccuring IP addresses are replaced with the same anonymized address.
|
||||
|
||||
Note: The anonymized IP addresses in the status file do not match those in the log and routes files. However, the anonymized IP addresses are consistent within the status file and across the routes and log files.
|
||||
|
||||
Domains
|
||||
All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle.
|
||||
Reoccuring domain names are replaced with the same anonymized domain.
|
||||
|
||||
Routes
|
||||
For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
|
||||
Network Interfaces
|
||||
The interfaces.txt file contains information about network interfaces, including:
|
||||
- Interface name
|
||||
- Interface index
|
||||
- MTU (Maximum Transmission Unit)
|
||||
- Flags
|
||||
- IP addresses associated with each interface
|
||||
|
||||
The IP addresses in the interfaces file are anonymized using the same process as described above. Interface names, indexes, MTUs, and flags are not anonymized.
|
||||
|
||||
Configuration
|
||||
The config.txt file contains anonymized configuration information of the NetBird client. Sensitive information such as private keys and SSH keys are excluded. The following fields are anonymized:
|
||||
- ManagementURL
|
||||
- AdminURL
|
||||
- NATExternalIPs
|
||||
- CustomDNSAddress
|
||||
|
||||
Other non-sensitive configuration options are included without anonymization.
|
||||
`
|
||||
|
||||
// DebugBundle creates a debug bundle and returns the location.
|
||||
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
|
||||
s.mutex.Lock()
|
||||
@@ -30,93 +86,211 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
return nil, fmt.Errorf("create zip file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := bundlePath.Close(); err != nil {
|
||||
log.Errorf("failed to close zip file: %v", err)
|
||||
if closeErr := bundlePath.Close(); closeErr != nil && err == nil {
|
||||
err = fmt.Errorf("close zip file: %w", closeErr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err2 := os.Remove(bundlePath.Name()); err2 != nil {
|
||||
log.Errorf("Failed to remove zip file: %v", err2)
|
||||
if removeErr := os.Remove(bundlePath.Name()); removeErr != nil {
|
||||
log.Errorf("Failed to remove zip file: %v", removeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
archive := zip.NewWriter(bundlePath)
|
||||
defer func() {
|
||||
if err := archive.Close(); err != nil {
|
||||
log.Errorf("failed to close archive writer: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if status := req.GetStatus(); status != "" {
|
||||
filename := "status.txt"
|
||||
if req.GetAnonymize() {
|
||||
filename = "status.anon.txt"
|
||||
}
|
||||
statusReader := strings.NewReader(status)
|
||||
if err := addFileToZip(archive, statusReader, filename); err != nil {
|
||||
return nil, fmt.Errorf("add status file to zip: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logFile, err := os.Open(s.logFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open log file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := logFile.Close(); err != nil {
|
||||
log.Errorf("failed to close original log file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
filename := "client.log.txt"
|
||||
var logReader io.Reader
|
||||
errChan := make(chan error, 1)
|
||||
if req.GetAnonymize() {
|
||||
filename = "client.anon.log.txt"
|
||||
var writer io.WriteCloser
|
||||
logReader, writer = io.Pipe()
|
||||
|
||||
go s.anonymize(logFile, writer, errChan)
|
||||
} else {
|
||||
logReader = logFile
|
||||
}
|
||||
if err := addFileToZip(archive, logReader, filename); err != nil {
|
||||
return nil, fmt.Errorf("add log file to zip: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
if err := s.createArchive(bundlePath, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil
|
||||
}
|
||||
|
||||
func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan<- error) {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleRequest) error {
|
||||
archive := zip.NewWriter(bundlePath)
|
||||
if err := s.addReadme(req, archive); err != nil {
|
||||
return fmt.Errorf("add readme: %w", err)
|
||||
}
|
||||
|
||||
if err := s.addStatus(req, archive); err != nil {
|
||||
return fmt.Errorf("add status: %w", err)
|
||||
}
|
||||
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
status := s.statusRecorder.GetFullStatus()
|
||||
seedFromStatus(anonymizer, &status)
|
||||
|
||||
if err := s.addConfig(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add config: %w", err)
|
||||
}
|
||||
|
||||
if req.GetSystemInfo() {
|
||||
if err := s.addRoutes(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add routes: %w", err)
|
||||
}
|
||||
|
||||
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add interfaces: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.addLogfile(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add log file: %w", err)
|
||||
}
|
||||
|
||||
if err := archive.Close(); err != nil {
|
||||
return fmt.Errorf("close archive writer: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error {
|
||||
if req.GetAnonymize() {
|
||||
readmeReader := strings.NewReader(readmeContent)
|
||||
if err := addFileToZip(archive, readmeReader, "README.txt"); err != nil {
|
||||
return fmt.Errorf("add README file to zip: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addStatus(req *proto.DebugBundleRequest, archive *zip.Writer) error {
|
||||
if status := req.GetStatus(); status != "" {
|
||||
statusReader := strings.NewReader(status)
|
||||
if err := addFileToZip(archive, statusReader, "status.txt"); err != nil {
|
||||
return fmt.Errorf("add status file to zip: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addConfig(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
var configContent strings.Builder
|
||||
s.addCommonConfigFields(&configContent)
|
||||
|
||||
if req.GetAnonymize() {
|
||||
if s.config.ManagementURL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", anonymizer.AnonymizeURI(s.config.ManagementURL.String())))
|
||||
}
|
||||
if s.config.AdminURL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", anonymizer.AnonymizeURI(s.config.AdminURL.String())))
|
||||
}
|
||||
configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", anonymizeNATExternalIPs(s.config.NATExternalIPs, anonymizer)))
|
||||
if s.config.CustomDNSAddress != "" {
|
||||
configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", anonymizer.AnonymizeString(s.config.CustomDNSAddress)))
|
||||
}
|
||||
} else {
|
||||
if s.config.ManagementURL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", s.config.ManagementURL.String()))
|
||||
}
|
||||
if s.config.AdminURL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", s.config.AdminURL.String()))
|
||||
}
|
||||
configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", s.config.NATExternalIPs))
|
||||
if s.config.CustomDNSAddress != "" {
|
||||
configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", s.config.CustomDNSAddress))
|
||||
}
|
||||
}
|
||||
|
||||
// Add config content to zip file
|
||||
configReader := strings.NewReader(configContent.String())
|
||||
if err := addFileToZip(archive, configReader, "config.txt"); err != nil {
|
||||
return fmt.Errorf("add config file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addCommonConfigFields(configContent *strings.Builder) {
|
||||
configContent.WriteString("NetBird Client Configuration:\n\n")
|
||||
|
||||
// Add non-sensitive fields
|
||||
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", s.config.WgIface))
|
||||
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", s.config.WgPort))
|
||||
if s.config.NetworkMonitor != nil {
|
||||
configContent.WriteString(fmt.Sprintf("NetworkMonitor: %v\n", *s.config.NetworkMonitor))
|
||||
}
|
||||
configContent.WriteString(fmt.Sprintf("IFaceBlackList: %v\n", s.config.IFaceBlackList))
|
||||
configContent.WriteString(fmt.Sprintf("DisableIPv6Discovery: %v\n", s.config.DisableIPv6Discovery))
|
||||
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", s.config.RosenpassEnabled))
|
||||
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", s.config.RosenpassPermissive))
|
||||
if s.config.ServerSSHAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *s.config.ServerSSHAllowed))
|
||||
}
|
||||
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", s.config.DisableAutoConnect))
|
||||
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", s.config.DNSRouteInterval))
|
||||
}
|
||||
|
||||
func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
if routes, err := systemops.GetRoutesFromTable(); err != nil {
|
||||
log.Errorf("Failed to get routes: %v", err)
|
||||
} else {
|
||||
// TODO: get routes including nexthop
|
||||
routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
|
||||
routesReader := strings.NewReader(routesContent)
|
||||
if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
|
||||
return fmt.Errorf("add routes file to zip: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get interfaces: %w", err)
|
||||
}
|
||||
|
||||
interfacesContent := formatInterfaces(interfaces, req.GetAnonymize(), anonymizer)
|
||||
interfacesReader := strings.NewReader(interfacesContent)
|
||||
if err := addFileToZip(archive, interfacesReader, "interfaces.txt"); err != nil {
|
||||
return fmt.Errorf("add interfaces file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) {
|
||||
logFile, err := os.Open(s.logFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open log file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := writer.Close(); err != nil {
|
||||
log.Errorf("Failed to close writer: %v", err)
|
||||
if err := logFile.Close(); err != nil {
|
||||
log.Errorf("Failed to close original log file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var logReader io.Reader
|
||||
if req.GetAnonymize() {
|
||||
var writer *io.PipeWriter
|
||||
logReader, writer = io.Pipe()
|
||||
|
||||
go s.anonymize(logFile, writer, anonymizer)
|
||||
} else {
|
||||
logReader = logFile
|
||||
}
|
||||
if err := addFileToZip(archive, logReader, "client.log"); err != nil {
|
||||
return fmt.Errorf("add log file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
|
||||
defer func() {
|
||||
// always nil
|
||||
_ = writer.Close()
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(reader)
|
||||
for scanner.Scan() {
|
||||
line := anonymizer.AnonymizeString(scanner.Text())
|
||||
if _, err := writer.Write([]byte(line + "\n")); err != nil {
|
||||
errChan <- fmt.Errorf("write line to writer: %w", err)
|
||||
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
errChan <- fmt.Errorf("read line from scanner: %w", err)
|
||||
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -141,8 +315,22 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
|
||||
|
||||
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
|
||||
header := &zip.FileHeader{
|
||||
Name: filename,
|
||||
Method: zip.Deflate,
|
||||
Name: filename,
|
||||
Method: zip.Deflate,
|
||||
Modified: time.Now(),
|
||||
|
||||
CreatorVersion: 20, // Version 2.0
|
||||
ReaderVersion: 20, // Version 2.0
|
||||
Flags: 0x800, // UTF-8 filename
|
||||
}
|
||||
|
||||
// If the reader is a file, we can get more accurate information
|
||||
if f, ok := reader.(*os.File); ok {
|
||||
if stat, err := f.Stat(); err != nil {
|
||||
log.Tracef("Failed to get file stat for %s: %v", filename, err)
|
||||
} else {
|
||||
header.Modified = stat.ModTime()
|
||||
}
|
||||
}
|
||||
|
||||
writer, err := archive.CreateHeader(header)
|
||||
@@ -165,6 +353,13 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
|
||||
|
||||
for _, peer := range status.Peers {
|
||||
a.AnonymizeDomain(peer.FQDN)
|
||||
for route := range peer.GetRoutes() {
|
||||
a.AnonymizeRoute(route)
|
||||
}
|
||||
}
|
||||
|
||||
for route := range status.LocalPeerState.Routes {
|
||||
a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
for _, nsGroup := range status.NSGroupStates {
|
||||
@@ -179,3 +374,113 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func formatRoutes(routes []netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string {
|
||||
var ipv4Routes, ipv6Routes []netip.Prefix
|
||||
|
||||
// Separate IPv4 and IPv6 routes
|
||||
for _, route := range routes {
|
||||
if route.Addr().Is4() {
|
||||
ipv4Routes = append(ipv4Routes, route)
|
||||
} else {
|
||||
ipv6Routes = append(ipv6Routes, route)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort IPv4 and IPv6 routes separately
|
||||
sort.Slice(ipv4Routes, func(i, j int) bool {
|
||||
return ipv4Routes[i].Bits() > ipv4Routes[j].Bits()
|
||||
})
|
||||
sort.Slice(ipv6Routes, func(i, j int) bool {
|
||||
return ipv6Routes[i].Bits() > ipv6Routes[j].Bits()
|
||||
})
|
||||
|
||||
var builder strings.Builder
|
||||
|
||||
// Format IPv4 routes
|
||||
builder.WriteString("IPv4 Routes:\n")
|
||||
for _, route := range ipv4Routes {
|
||||
formatRoute(&builder, route, anonymize, anonymizer)
|
||||
}
|
||||
|
||||
// Format IPv6 routes
|
||||
builder.WriteString("\nIPv6 Routes:\n")
|
||||
for _, route := range ipv6Routes {
|
||||
formatRoute(&builder, route, anonymize, anonymizer)
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func formatRoute(builder *strings.Builder, route netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) {
|
||||
if anonymize {
|
||||
anonymizedIP := anonymizer.AnonymizeIP(route.Addr())
|
||||
builder.WriteString(fmt.Sprintf("%s/%d\n", anonymizedIP, route.Bits()))
|
||||
} else {
|
||||
builder.WriteString(fmt.Sprintf("%s\n", route))
|
||||
}
|
||||
}
|
||||
|
||||
func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string {
|
||||
sort.Slice(interfaces, func(i, j int) bool {
|
||||
return interfaces[i].Name < interfaces[j].Name
|
||||
})
|
||||
|
||||
var builder strings.Builder
|
||||
builder.WriteString("Network Interfaces:\n")
|
||||
|
||||
for _, iface := range interfaces {
|
||||
builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name))
|
||||
builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index))
|
||||
builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU))
|
||||
builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags))
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err))
|
||||
} else {
|
||||
builder.WriteString(" Addresses:\n")
|
||||
for _, addr := range addrs {
|
||||
prefix, err := netip.ParsePrefix(addr.String())
|
||||
if err != nil {
|
||||
builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err))
|
||||
continue
|
||||
}
|
||||
ip := prefix.Addr()
|
||||
if anonymize {
|
||||
ip = anonymizer.AnonymizeIP(ip)
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
|
||||
anonymizedIPs := make([]string, len(ips))
|
||||
for i, ip := range ips {
|
||||
parts := strings.SplitN(ip, "/", 2)
|
||||
|
||||
ip1, err := netip.ParseAddr(parts[0])
|
||||
if err != nil {
|
||||
anonymizedIPs[i] = ip
|
||||
continue
|
||||
}
|
||||
ip1anon := anonymizer.AnonymizeIP(ip1)
|
||||
|
||||
if len(parts) == 2 {
|
||||
ip2, err := netip.ParseAddr(parts[1])
|
||||
if err != nil {
|
||||
anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, parts[1])
|
||||
} else {
|
||||
ip2anon := anonymizer.AnonymizeIP(ip2)
|
||||
anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, ip2anon)
|
||||
}
|
||||
} else {
|
||||
anonymizedIPs[i] = ip1anon.String()
|
||||
}
|
||||
}
|
||||
return anonymizedIPs
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
)
|
||||
@@ -120,7 +121,11 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func Test_sysInfo(t *testing.T) {
|
||||
func Test_sysInfoMac(t *testing.T) {
|
||||
t.Skip("skipping darwin test")
|
||||
serialNum, prodName, manufacturer := sysInfo()
|
||||
if serialNum == "" {
|
||||
|
||||
@@ -21,6 +21,26 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
type SysInfoGetter interface {
|
||||
GetSysInfo() SysInfo
|
||||
}
|
||||
|
||||
type SysInfoWrapper struct {
|
||||
si sysinfo.SysInfo
|
||||
}
|
||||
|
||||
func (s SysInfoWrapper) GetSysInfo() SysInfo {
|
||||
s.si.GetSysInfo()
|
||||
return SysInfo{
|
||||
ChassisSerial: s.si.Chassis.Serial,
|
||||
ProductSerial: s.si.Product.Serial,
|
||||
BoardSerial: s.si.Board.Serial,
|
||||
ProductName: s.si.Product.Name,
|
||||
BoardName: s.si.Board.Name,
|
||||
ProductVendor: s.si.Product.Vendor,
|
||||
}
|
||||
}
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
info := _getInfo()
|
||||
@@ -45,7 +65,8 @@ func GetInfo(ctx context.Context) *Info {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
serialNum, prodName, manufacturer := sysInfo()
|
||||
si := SysInfoWrapper{}
|
||||
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
|
||||
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
@@ -87,20 +108,36 @@ func _getInfo() string {
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||
var si sysinfo.SysInfo
|
||||
si.GetSysInfo()
|
||||
func sysInfo(si SysInfo) (string, string, string) {
|
||||
isascii := regexp.MustCompile("^[[:ascii:]]+$")
|
||||
serial := si.Chassis.Serial
|
||||
if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
|
||||
serial = si.Product.Serial
|
||||
|
||||
serials := []string{si.ChassisSerial, si.ProductSerial}
|
||||
serial := ""
|
||||
|
||||
for _, s := range serials {
|
||||
if isascii.MatchString(s) {
|
||||
serial = s
|
||||
if s != "Default string" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!isascii.MatchString(serial)) && si.Board.Serial != "" {
|
||||
serial = si.Board.Serial
|
||||
|
||||
if serial == "" && isascii.MatchString(si.BoardSerial) {
|
||||
serial = si.BoardSerial
|
||||
}
|
||||
name := si.Product.Name
|
||||
if (!isascii.MatchString(name)) && si.Board.Name != "" {
|
||||
name = si.Board.Name
|
||||
|
||||
var name string
|
||||
for _, n := range []string{si.ProductName, si.BoardName} {
|
||||
if isascii.MatchString(n) {
|
||||
name = n
|
||||
break
|
||||
}
|
||||
}
|
||||
return serial, name, si.Product.Vendor
|
||||
|
||||
var manufacturer string
|
||||
if isascii.MatchString(si.ProductVendor) {
|
||||
manufacturer = si.ProductVendor
|
||||
}
|
||||
return serial, name, manufacturer
|
||||
}
|
||||
|
||||
12
client/system/sysinfo_linux.go
Normal file
12
client/system/sysinfo_linux.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package system
|
||||
|
||||
// SysInfo used to moc out the sysinfo getter
|
||||
type SysInfo struct {
|
||||
ChassisSerial string
|
||||
ProductSerial string
|
||||
BoardSerial string
|
||||
|
||||
ProductName string
|
||||
BoardName string
|
||||
ProductVendor string
|
||||
}
|
||||
198
client/system/sysinfo_linux_test.go
Normal file
198
client/system/sysinfo_linux_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package system
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_sysInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sysInfo SysInfo
|
||||
wantSerialNum string
|
||||
wantProdName string
|
||||
wantManufacturer string
|
||||
}{
|
||||
{
|
||||
name: "Test Case 1",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "Default string",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Empty Chassis Serial",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Empty Chassis Serial",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Fallback to Product Serial",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "Default string",
|
||||
ProductSerial: "Product serial",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Product serial",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Fallback to Product Serial with default string",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "Default string",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Non UTF-8 in Chassis Serial",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "\x80",
|
||||
ProductSerial: "Product serial",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Product serial",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Non UTF-8 in Chassis Serial and Product Serial",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "\x80",
|
||||
ProductSerial: "\x80",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "M80-G8013200245",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Non UTF-8 in Chassis Serial and Product Serial and BoardSerial",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "\x80",
|
||||
ProductSerial: "\x80",
|
||||
BoardSerial: "\x80",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
|
||||
{
|
||||
name: "Empty Product Name",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "Default string",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "",
|
||||
BoardName: "boardname",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "boardname",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Invalid Product Name",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "Default string",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "\x80",
|
||||
BoardName: "boardname",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "boardname",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Invalid BoardName Name",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "Default string",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "\x80",
|
||||
BoardName: "\x80",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Invalid chars",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "\x80",
|
||||
ProductSerial: "\x80",
|
||||
BoardSerial: "\x80",
|
||||
ProductName: "\x80",
|
||||
BoardName: "\x80",
|
||||
ProductVendor: "\x80",
|
||||
},
|
||||
wantSerialNum: "",
|
||||
wantProdName: "",
|
||||
wantManufacturer: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo)
|
||||
if gotSerialNum != tt.wantSerialNum {
|
||||
t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum)
|
||||
}
|
||||
if gotProdName != tt.wantProdName {
|
||||
t.Errorf("sysInfo() gotProdName = %v, want %v", gotProdName, tt.wantProdName)
|
||||
}
|
||||
if gotManufacturer != tt.wantManufacturer {
|
||||
t.Errorf("sysInfo() gotManufacturer = %v, want %v", gotManufacturer, tt.wantManufacturer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -22,8 +22,8 @@ import (
|
||||
"fyne.io/fyne/v2/app"
|
||||
"fyne.io/fyne/v2/dialog"
|
||||
"fyne.io/fyne/v2/widget"
|
||||
"fyne.io/systray"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/getlantern/systray"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) {
|
||||
byteResp, err := pb.Marshal(message)
|
||||
if err != nil {
|
||||
log.Errorf("failed marshalling message %v", err)
|
||||
log.Errorf("failed marshalling message %v, %+v", err, message.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -14,14 +14,29 @@ type TextFormatter struct {
|
||||
levelDesc []string
|
||||
}
|
||||
|
||||
// SyslogFormatter formats logs into text
|
||||
type SyslogFormatter struct {
|
||||
levelDesc []string
|
||||
}
|
||||
|
||||
var validLevelDesc = []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}
|
||||
|
||||
|
||||
// NewTextFormatter create new MyTextFormatter instance
|
||||
func NewTextFormatter() *TextFormatter {
|
||||
return &TextFormatter{
|
||||
levelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"},
|
||||
levelDesc: validLevelDesc,
|
||||
timestampFormat: time.RFC3339, // or RFC3339
|
||||
}
|
||||
}
|
||||
|
||||
// NewSyslogFormatter create new MySyslogFormatter instance
|
||||
func NewSyslogFormatter() *SyslogFormatter {
|
||||
return &SyslogFormatter{
|
||||
levelDesc: validLevelDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// Format renders a single log entry
|
||||
func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
||||
var fields string
|
||||
@@ -49,3 +64,20 @@ func (f *TextFormatter) parseLevel(level logrus.Level) string {
|
||||
|
||||
return f.levelDesc[level]
|
||||
}
|
||||
|
||||
// Format renders a single log entry
|
||||
func (f *SyslogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
||||
var fields string
|
||||
keys := make([]string, 0, len(entry.Data))
|
||||
for k, v := range entry.Data {
|
||||
if k == "source" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf("%s: %v", k, v))
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", "))
|
||||
}
|
||||
return []byte(fmt.Sprintf("%s%s\n", fields, entry.Message)), nil
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLogMessageFormat(t *testing.T) {
|
||||
func TestLogTextFormat(t *testing.T) {
|
||||
|
||||
someEntry := &logrus.Entry{
|
||||
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
|
||||
@@ -24,3 +24,20 @@ func TestLogMessageFormat(t *testing.T) {
|
||||
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
|
||||
assert.Regexp(t, expectedString, parsedString)
|
||||
}
|
||||
|
||||
func TestLogSyslogFormat(t *testing.T) {
|
||||
|
||||
someEntry := &logrus.Entry{
|
||||
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
|
||||
Time: time.Date(2021, time.Month(2), 21, 1, 10, 30, 0, time.UTC),
|
||||
Level: 3,
|
||||
Message: "Some Message",
|
||||
}
|
||||
|
||||
formatter := NewSyslogFormatter()
|
||||
result, _ := formatter.Format(someEntry)
|
||||
|
||||
parsedString := string(result)
|
||||
expectedString := "^\\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] Some Message\\s+$"
|
||||
assert.Regexp(t, expectedString, parsedString)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,12 @@ func SetTextFormatter(logger *logrus.Logger) {
|
||||
logger.ReportCaller = true
|
||||
logger.AddHook(NewContextHook())
|
||||
}
|
||||
// SetSyslogFormatter set the text formatter for given logger.
|
||||
func SetSyslogFormatter(logger *logrus.Logger) {
|
||||
logger.Formatter = NewSyslogFormatter()
|
||||
logger.ReportCaller = true
|
||||
logger.AddHook(NewContextHook())
|
||||
}
|
||||
|
||||
// SetJSONFormatter set the JSON formatter for given logger.
|
||||
func SetJSONFormatter(logger *logrus.Logger) {
|
||||
|
||||
12
go.mod
12
go.mod
@@ -31,6 +31,7 @@ require (
|
||||
|
||||
require (
|
||||
fyne.io/fyne/v2 v2.1.4
|
||||
fyne.io/systray v1.11.0
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
@@ -38,7 +39,6 @@ require (
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/eko/gocache/v3 v3.1.1
|
||||
github.com/fsnotify/fsnotify v1.6.0
|
||||
github.com/getlantern/systray v1.2.1
|
||||
github.com/gliderlabs/ssh v0.3.4
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang/mock v1.6.0
|
||||
@@ -115,24 +115,17 @@ require (
|
||||
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/docker v26.1.3+incompatible // indirect
|
||||
github.com/docker/docker v26.1.4+incompatible // indirect
|
||||
github.com/docker/go-connections v0.5.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fredbi/uri v0.0.0-20181227131451-3dcfdacbaaf3 // indirect
|
||||
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
|
||||
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
|
||||
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect
|
||||
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect
|
||||
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect
|
||||
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
|
||||
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
|
||||
github.com/go-logr/logr v1.4.1 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/go-redis/redis/v8 v8.11.5 // indirect
|
||||
github.com/go-stack/stack v1.8.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
@@ -164,7 +157,6 @@ require (
|
||||
github.com/nxadm/tail v1.4.8 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
||||
github.com/pegasus-kv/thrift v0.13.0 // indirect
|
||||
github.com/pion/dtls/v2 v2.2.10 // indirect
|
||||
github.com/pion/mdns v0.0.12 // indirect
|
||||
|
||||
25
go.sum
25
go.sum
@@ -12,6 +12,8 @@ dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
|
||||
dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
||||
fyne.io/fyne/v2 v2.1.4 h1:bt1+28++kAzRzPB0GM2EuSV4cnl8rXNX4cjfd8G06Rc=
|
||||
fyne.io/fyne/v2 v2.1.4/go.mod h1:p+E/Dh+wPW8JwR2DVcsZ9iXgR9ZKde80+Y+40Is54AQ=
|
||||
fyne.io/systray v1.11.0 h1:D9HISlxSkx+jHSniMBR6fCFOUjk1x/OOOJLa9lJYAKg=
|
||||
fyne.io/systray v1.11.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
|
||||
@@ -81,8 +83,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v26.1.3+incompatible h1:lLCzRbrVZrljpVNobJu1J2FHk8V0s4BawoZippkc+xo=
|
||||
github.com/docker/docker v26.1.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/docker v26.1.4+incompatible h1:vuTpXDuoga+Z38m1OZHzl7NKisKWaWlhjQk7IDPSLsU=
|
||||
github.com/docker/docker v26.1.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
@@ -111,18 +113,6 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 h1:NRUJuo3v3WGC/g5YiyF790gut6oQr5f3FBI88Wv0dx4=
|
||||
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520/go.mod h1:L+mq6/vvYHKjCX2oez0CgEAJmbq1fbb/oNJIWQkBybY=
|
||||
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 h1:6uJ+sZ/e03gkbqZ0kUG6mfKoqDb4XMAzMIwlajq19So=
|
||||
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7/go.mod h1:l+xpFBrCtDLpK9qNjxs+cHU6+BAdlBaxHqikB6Lku3A=
|
||||
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 h1:guBYzEaLz0Vfc/jv0czrr2z7qyzTOGC9hiQ0VC+hKjk=
|
||||
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7/go.mod h1:zx/1xUUeYPy3Pcmet8OSXLbF47l+3y6hIPpyLWoR9oc=
|
||||
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 h1:micT5vkcr9tOVk1FiH8SWKID8ultN44Z+yzd2y/Vyb0=
|
||||
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7/go.mod h1:dD3CgOrwlzca8ed61CsZouQS5h5jIzkK9ZWrTcf0s+o=
|
||||
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 h1:XYzSdCbkzOC0FDNrgJqGRo8PCMFOBFL9py72DRs7bmc=
|
||||
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55/go.mod h1:6mmzY2kW1TOOrVy+r41Za2MxXM+hhqTtY3oBKd2AgFA=
|
||||
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f h1:wrYrQttPS8FHIRSlsrcuKazukx/xqO/PpLZzZXsF+EA=
|
||||
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f/go.mod h1:D5ao98qkA6pxftxoqzibIBBrLSUli+kYnJqrgBf9cIA=
|
||||
github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do=
|
||||
@@ -151,8 +141,6 @@ github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7
|
||||
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
||||
@@ -337,8 +325,6 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
|
||||
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
@@ -368,8 +354,6 @@ github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQ
|
||||
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
|
||||
github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs=
|
||||
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
|
||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c h1:rp5dCmg/yLR3mgFuSOe4oEnDDmGLROTvMragMUXpTQw=
|
||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c/go.mod h1:X07ZCGwUbLaax7L0S3Tw4hpejzu63ZrrQiUe6W0hcy0=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pegasus-kv/thrift v0.13.0 h1:4ESwaNoHImfbHa9RUGJiJZ4hrxorihZHk5aarYwY8d4=
|
||||
@@ -609,7 +593,6 @@ golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
|
||||
@@ -9,7 +9,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
|
||||
@@ -71,7 +74,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ var (
|
||||
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
|
||||
}
|
||||
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build default manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -37,6 +39,7 @@ import (
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -135,8 +138,8 @@ type AccountManager interface {
|
||||
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
|
||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(account *Account) (map[string]struct{}, error)
|
||||
SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
|
||||
CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
@@ -170,6 +173,8 @@ type DefaultAccountManager struct {
|
||||
userDeleteFromIDPEnabled bool
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
metrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// Settings represents Account settings structure that can be modified via API and Dashboard
|
||||
@@ -401,8 +406,17 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
|
||||
return a.Groups[groupID]
|
||||
}
|
||||
|
||||
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
|
||||
func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
|
||||
// GetPeerNetworkMap returns the networkmap for the given peer ID.
|
||||
func (a *Account) GetPeerNetworkMap(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
expandedPolicies PolicyRuleExpandedPeers,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
|
||||
peer := a.Peers[peerID]
|
||||
if peer == nil {
|
||||
return &NetworkMap{
|
||||
@@ -416,7 +430,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
||||
}
|
||||
}
|
||||
|
||||
aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap)
|
||||
aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap, expandedPolicies)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
var expiredPeers []*nbpeer.Peer
|
||||
@@ -438,7 +452,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
peersCustomZone := getPeersCustomZone(ctx, a, dnsDomain)
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
zones = append(zones, peersCustomZone)
|
||||
}
|
||||
@@ -446,7 +460,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
}
|
||||
|
||||
return &NetworkMap{
|
||||
nm := &NetworkMap{
|
||||
Peers: peersToConnect,
|
||||
Network: a.Network.Copy(),
|
||||
Routes: routesUpdate,
|
||||
@@ -454,6 +468,60 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
}
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules))
|
||||
metrics.CountNetworkMapObjects(objectCount)
|
||||
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
|
||||
var merr *multierror.Error
|
||||
|
||||
if dnsDomain == "" {
|
||||
log.WithContext(ctx).Error("no dns domain is set, returning empty zone")
|
||||
return nbdns.CustomZone{}
|
||||
}
|
||||
|
||||
customZone := nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(dnsDomain),
|
||||
Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)),
|
||||
}
|
||||
|
||||
domainSuffix := "." + dnsDomain
|
||||
|
||||
var sb strings.Builder
|
||||
for _, peer := range a.Peers {
|
||||
if peer.DNSLabel == "" {
|
||||
merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name))
|
||||
continue
|
||||
}
|
||||
|
||||
sb.Grow(len(peer.DNSLabel) + len(domainSuffix))
|
||||
sb.WriteString(peer.DNSLabel)
|
||||
sb.WriteString(domainSuffix)
|
||||
|
||||
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
|
||||
Name: sb.String(),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: defaultTTL,
|
||||
RData: peer.IP.String(),
|
||||
})
|
||||
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
go func() {
|
||||
if merr != nil {
|
||||
log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr)
|
||||
}
|
||||
}()
|
||||
|
||||
return customZone
|
||||
}
|
||||
|
||||
// GetExpiredPeers returns peers that have been expired
|
||||
@@ -871,10 +939,18 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
}
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation,
|
||||
func BuildManager(
|
||||
ctx context.Context,
|
||||
store Store,
|
||||
peersUpdateManager *PeersUpdateManager,
|
||||
idpManager idp.Manager,
|
||||
singleAccountModeDomain string,
|
||||
dnsDomain string,
|
||||
eventStore activity.Store,
|
||||
geo *geolocation.Geolocation,
|
||||
userDeleteFromIDPEnabled bool,
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
metrics telemetry.AppMetrics,
|
||||
) (*DefaultAccountManager, error) {
|
||||
am := &DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -889,6 +965,7 @@ func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpd
|
||||
peerLoginExpiry: NewDefaultScheduler(),
|
||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
metrics: metrics,
|
||||
}
|
||||
allAccounts := store.GetAllAccounts(ctx)
|
||||
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
||||
@@ -974,7 +1051,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -1025,7 +1102,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
|
||||
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||
return func() (time.Duration, bool) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -1124,7 +1201,7 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
|
||||
|
||||
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
|
||||
func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -1584,7 +1661,7 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
account, err = am.Store.GetAccountByUser(ctx, user.Id)
|
||||
@@ -1667,7 +1744,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, newAcc.Id)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
|
||||
alreadyUnlocked := false
|
||||
defer func() {
|
||||
if !alreadyUnlocked {
|
||||
@@ -1823,7 +1900,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
|
||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
if err == nil {
|
||||
unlockAccount := am.Store.AcquireAccountWriteLock(ctx, account.Id)
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlockAccount()
|
||||
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
if err != nil {
|
||||
@@ -1843,7 +1920,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
return account, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
if domainAccount != nil {
|
||||
unlockAccount := am.Store.AcquireAccountWriteLock(ctx, domainAccount.Id)
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
|
||||
defer unlockAccount()
|
||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||
if err != nil {
|
||||
@@ -1857,17 +1934,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered")
|
||||
}
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountReadLock(ctx, accountID)
|
||||
defer unlock()
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer accountUnlock()
|
||||
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer peerUnlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -1887,26 +1958,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey
|
||||
return peer, netMap, postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peer.Key)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
return status.Errorf(status.Unauthenticated, "peer not registered")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
|
||||
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer accountUnlock()
|
||||
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer peerUnlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peer.Key, false, nil, account)
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peer.Key, err)
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1919,7 +1984,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountReadLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -410,7 +411,9 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
validatedPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
|
||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil, policyExpandedPeers)
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
@@ -2293,7 +2296,13 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
type TB interface {
|
||||
Cleanup(func())
|
||||
Helper()
|
||||
TempDir() string
|
||||
}
|
||||
|
||||
func createManager(t TB) (*DefaultAccountManager, error) {
|
||||
t.Helper()
|
||||
|
||||
store, err := createStore(t)
|
||||
@@ -2302,7 +2311,12 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2310,7 +2324,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func createStore(t *testing.T) (Store, error) {
|
||||
func createStore(t TB) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
|
||||
|
||||
@@ -56,6 +56,10 @@ type Config struct {
|
||||
func (c Config) GetAuthAudiences() []string {
|
||||
audiences := []string{c.HttpConfig.AuthAudience}
|
||||
|
||||
if c.HttpConfig.ExtraAuthAudience != "" {
|
||||
audiences = append(audiences, c.HttpConfig.ExtraAuthAudience)
|
||||
}
|
||||
|
||||
if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" {
|
||||
audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience)
|
||||
}
|
||||
@@ -90,6 +94,8 @@ type HttpServerConfig struct {
|
||||
OIDCConfigEndpoint string
|
||||
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
|
||||
IdpSignKeyRefreshEnabled bool
|
||||
// Extra audience
|
||||
ExtraAuthAudience string
|
||||
}
|
||||
|
||||
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -17,6 +17,50 @@ import (
|
||||
|
||||
const defaultTTL = 300
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
type DNSConfigCache struct {
|
||||
CustomZones sync.Map
|
||||
NameServerGroups sync.Map
|
||||
}
|
||||
|
||||
// GetCustomZone retrieves a cached custom zone
|
||||
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.CustomZones.Load(key); ok {
|
||||
return value.(*proto.CustomZone), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetCustomZone stores a custom zone in the cache
|
||||
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.CustomZones.Store(key, value)
|
||||
}
|
||||
|
||||
// GetNameServerGroup retrieves a cached name server group
|
||||
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.NameServerGroups.Load(key); ok {
|
||||
return value.(*proto.NameServerGroup), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetNameServerGroup stores a name server group in the cache
|
||||
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.NameServerGroups.Store(key, value)
|
||||
}
|
||||
|
||||
type lookupMap map[string]struct{}
|
||||
|
||||
// DNSSettings defines dns settings at the account level
|
||||
@@ -36,7 +80,7 @@ func (d DNSSettings) Copy() DNSSettings {
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -58,7 +102,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -113,69 +157,73 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable}
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := &proto.CustomZone{Domain: zone.Domain}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
cacheKey := zone.Domain
|
||||
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
|
||||
} else {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
cache.SetCustomZone(cacheKey, protoZone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
cacheKey := nsGroup.ID
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
} else {
|
||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoNS := &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
}
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, protoNS)
|
||||
}
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
|
||||
if dnsDomain == "" {
|
||||
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
|
||||
return nbdns.CustomZone{}
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
}
|
||||
|
||||
customZone := nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(dnsDomain),
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if peer.DNSLabel == "" {
|
||||
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
|
||||
Name: dns.Fqdn(peer.DNSLabel + "." + dnsDomain),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: defaultTTL,
|
||||
RData: peer.IP.String(),
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
return customZone
|
||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
|
||||
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||
|
||||
@@ -2,9 +2,14 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
@@ -195,7 +200,11 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (Store, error) {
|
||||
@@ -320,3 +329,150 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
}
|
||||
|
||||
func generateTestData(size int) nbdns.Config {
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: make([]nbdns.CustomZone, size),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, size),
|
||||
}
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
config.CustomZones[i] = nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("domain%d.com", i),
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: fmt.Sprintf("record%d", i),
|
||||
Type: 1,
|
||||
Class: "IN",
|
||||
TTL: 3600,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
||||
ID: fmt.Sprintf("group%d", i),
|
||||
Primary: i == 0,
|
||||
Domains: []string{fmt.Sprintf("domain%d.com", i)},
|
||||
SearchDomainsEnabled: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
Port: 53,
|
||||
NSType: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
sizes := []int{10, 100, 1000}
|
||||
|
||||
for _, size := range sizes {
|
||||
testData := generateTestData(size)
|
||||
|
||||
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
|
||||
cache := &DNSConfigCache{}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
var cache DNSConfigCache
|
||||
|
||||
// Create two different configs
|
||||
config1 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.com",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group1",
|
||||
Name: "Group 1",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config2 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.org",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group2",
|
||||
Name: "Group 2",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache)
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache)
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache)
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
|
||||
}
|
||||
|
||||
// Verify that result2 is different from result1 and result3
|
||||
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
|
||||
t.Errorf("Results should be different for different inputs")
|
||||
}
|
||||
|
||||
// Verify that the cache contains elements from both configs
|
||||
if _, exists := cache.GetCustomZone("example.com"); !exists {
|
||||
t.Errorf("Cache should contain custom zone for example.com")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetCustomZone("example.org"); !exists {
|
||||
t.Errorf("Cache should contain custom zone for example.org")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group1'")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group2"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group2'")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
// GetEvents returns a list of activity events of an account
|
||||
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
||||
@@ -39,8 +39,8 @@ type FileStore struct {
|
||||
mux sync.Mutex `json:"-"`
|
||||
storeFile string `json:"-"`
|
||||
|
||||
// sync.Mutex indexed by accountID
|
||||
accountLocks sync.Map `json:"-"`
|
||||
// sync.Mutex indexed by resource ID
|
||||
resourceLocks sync.Map `json:"-"`
|
||||
globalAccountLock sync.Mutex `json:"-"`
|
||||
|
||||
metrics telemetry.AppMetrics `json:"-"`
|
||||
@@ -281,26 +281,26 @@ func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
|
||||
log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID)
|
||||
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *FileStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID)
|
||||
start := time.Now()
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{})
|
||||
mtx := value.(*sync.Mutex)
|
||||
mtx.Lock()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start))
|
||||
log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
|
||||
// AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock
|
||||
// This method is still returns a write lock as file store can't handle read locks
|
||||
func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
|
||||
return s.AcquireAccountWriteLock(ctx, accountID)
|
||||
func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
return s.AcquireWriteLockByUID(ctx, uniqueID)
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"path"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
@@ -30,6 +31,8 @@ func loadGeolocationDatabases(dataDir string) error {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("geo location file %s not found , file will be downloaded", file)
|
||||
|
||||
switch file {
|
||||
case MMDBFileName:
|
||||
extractFunc := func(src string, dst string) error {
|
||||
|
||||
@@ -23,7 +23,7 @@ func (e *GroupLinkError) Error() string {
|
||||
|
||||
// GetGroup object of the peers
|
||||
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -50,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -77,7 +77,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID str
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -110,7 +110,7 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
|
||||
|
||||
// SaveGroup object of the peers
|
||||
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveGroups(account.Id, account.Groups); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -245,7 +245,7 @@ func difference(a, b []string) []string {
|
||||
|
||||
// DeleteGroup object of the peers
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountId)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
@@ -359,7 +359,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID strin
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -413,7 +413,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
||||
@@ -156,7 +156,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||
if err != nil {
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
@@ -179,11 +179,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
|
||||
return s.handleUpdates(ctx, peerKey, peer, updates, srv)
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
@@ -194,12 +194,12 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
|
||||
|
||||
if !open {
|
||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return nil
|
||||
}
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
|
||||
if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -207,7 +207,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
|
||||
case <-srv.Context().Done():
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return srv.Context().Err()
|
||||
}
|
||||
}
|
||||
@@ -215,10 +215,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
@@ -226,17 +226,17 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.CancelPeerRoutines(ctx, peer)
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
@@ -533,53 +533,46 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
|
||||
}
|
||||
}
|
||||
|
||||
func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
remotePeers := []*proto.RemotePeerConfig{}
|
||||
for _, rPeer := range peers {
|
||||
fqdn := rPeer.FQDN(dnsName)
|
||||
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: fqdn,
|
||||
})
|
||||
}
|
||||
return remotePeers
|
||||
}
|
||||
|
||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
|
||||
wtConfig := toWiretrusteeConfig(config, turnCredentials)
|
||||
|
||||
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
|
||||
|
||||
remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
|
||||
|
||||
routesUpdate := toProtocolRoutes(networkMap.Routes)
|
||||
|
||||
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
|
||||
|
||||
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||
|
||||
return &proto.SyncResponse{
|
||||
WiretrusteeConfig: wtConfig,
|
||||
PeerConfig: pConfig,
|
||||
RemotePeers: remotePeers,
|
||||
RemotePeersIsEmpty: len(remotePeers) == 0,
|
||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
PeerConfig: pConfig,
|
||||
RemotePeers: remotePeers,
|
||||
OfflinePeers: offlinePeers,
|
||||
RemotePeersIsEmpty: len(remotePeers) == 0,
|
||||
Routes: routesUpdate,
|
||||
DNSConfig: dnsUpdate,
|
||||
FirewallRules: firewallRules,
|
||||
FirewallRulesIsEmpty: len(firewallRules) == 0,
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
|
||||
response.RemotePeers = allPeers
|
||||
response.NetworkMap.RemotePeers = allPeers
|
||||
response.RemotePeersIsEmpty = len(allPeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||
response.NetworkMap.FirewallRules = firewallRules
|
||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// IsHealthy indicates whether the service is healthy
|
||||
@@ -597,7 +590,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
||||
} else {
|
||||
turnCredentials = nil
|
||||
}
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
|
||||
@@ -526,6 +526,43 @@ components:
|
||||
- revoked
|
||||
- auto_groups
|
||||
- usage_limit
|
||||
CreateSetupKeyRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: Setup Key name
|
||||
type: string
|
||||
example: Default key
|
||||
type:
|
||||
description: Setup key type, one-off for single time usage and reusable
|
||||
type: string
|
||||
example: reusable
|
||||
expires_in:
|
||||
description: Expiration time in seconds
|
||||
type: integer
|
||||
minimum: 86400
|
||||
maximum: 31536000
|
||||
example: 86400
|
||||
auto_groups:
|
||||
description: List of group IDs to auto-assign to peers registered with this key
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: "ch8i4ug6lnn4g9hqv7m0"
|
||||
usage_limit:
|
||||
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
type: integer
|
||||
example: 0
|
||||
ephemeral:
|
||||
description: Indicate that the peer will be ephemeral or not
|
||||
type: boolean
|
||||
example: true
|
||||
required:
|
||||
- name
|
||||
- type
|
||||
- expires_in
|
||||
- auto_groups
|
||||
- usage_limit
|
||||
PersonalAccessToken:
|
||||
type: object
|
||||
properties:
|
||||
@@ -1806,7 +1843,7 @@ paths:
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/SetupKeyRequest'
|
||||
$ref: '#/components/schemas/CreateSetupKeyRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A Setup Keys Object
|
||||
|
||||
@@ -254,6 +254,27 @@ type Country struct {
|
||||
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
|
||||
type CountryCode = string
|
||||
|
||||
// CreateSetupKeyRequest defines model for CreateSetupKeyRequest.
|
||||
type CreateSetupKeyRequest struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral *bool `json:"ephemeral,omitempty"`
|
||||
|
||||
// ExpiresIn Expiration time in seconds
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
// Name Setup Key name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
}
|
||||
|
||||
// DNSSettings defines model for DNSSettings.
|
||||
type DNSSettings struct {
|
||||
// DisabledManagementGroups Groups whose DNS management is disabled
|
||||
@@ -1241,7 +1262,7 @@ type PostApiRoutesJSONRequestBody = RouteRequest
|
||||
type PutApiRoutesRouteIdJSONRequestBody = RouteRequest
|
||||
|
||||
// PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType.
|
||||
type PostApiSetupKeysJSONRequestBody = SetupKeyRequest
|
||||
type PostApiSetupKeysJSONRequestBody = CreateSetupKeyRequest
|
||||
|
||||
// PutApiSetupKeysKeyIdJSONRequestBody defines body for PutApiSetupKeysKeyId for application/json ContentType.
|
||||
type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest
|
||||
|
||||
@@ -64,18 +64,28 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
||||
|
||||
groupsInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
accessiblePeers, valid, err := h.getAccessibleAndValidStatus(ctx, account, peerID, dnsDomain, peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
|
||||
}
|
||||
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
func (h *PeersHandler) getAccessibleAndValidStatus(ctx context.Context, account *server.Account, peerID string, dnsDomain string, peer *nbpeer.Peer) ([]api.AccessiblePeer, bool, error) {
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil, policyExpandedPeers)
|
||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
|
||||
return accessiblePeers, valid, nil
|
||||
}
|
||||
|
||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
@@ -109,16 +119,12 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
||||
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
accessiblePeers, valid, err := h.getAccessibleAndValidStatus(ctx, account, peerID, dnsDomain, peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
|
||||
}
|
||||
@@ -194,9 +200,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||
}
|
||||
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
@@ -210,16 +214,6 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, respBody)
|
||||
}
|
||||
|
||||
func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
|
||||
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
|
||||
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
|
||||
}
|
||||
|
||||
func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
|
||||
for _, peer := range respBody {
|
||||
_, ok := approvedPeersMap[peer.Id]
|
||||
|
||||
@@ -32,7 +32,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
|
||||
return errors.New("invalid groups")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
a, err := am.Store.GetAccountByUser(ctx, userID)
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -16,8 +17,10 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -83,7 +86,7 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
defer func() {
|
||||
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||
}()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, &Config{
|
||||
mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
@@ -399,32 +402,39 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) {
|
||||
func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
|
||||
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
|
||||
@@ -434,7 +444,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
return s, accountManager, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) {
|
||||
@@ -454,3 +464,165 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
|
||||
|
||||
return mgmtProto.NewManagementServiceClient(conn), conn, nil
|
||||
}
|
||||
func Test_SyncStatusRace(t *testing.T) {
|
||||
if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" {
|
||||
t.Skip("Skipping on CI and Postgres store")
|
||||
}
|
||||
for i := 0; i < 500; i++ {
|
||||
t.Run(fmt.Sprintf("TestRun-%d", i), func(t *testing.T) {
|
||||
testSyncStatusRace(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
func testSyncStatusRace(t *testing.T) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||
}()
|
||||
|
||||
mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
}},
|
||||
TURNConfig: &TURNConfig{
|
||||
TimeBasedCredentials: false,
|
||||
CredentialsTTL: util.Duration{},
|
||||
Secret: "whatever",
|
||||
Turns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "turn:stun.wiretrustee.com:3468",
|
||||
}},
|
||||
},
|
||||
Signal: &Host{
|
||||
Proto: "http",
|
||||
URI: "signal.wiretrustee.com:10000",
|
||||
},
|
||||
Datadir: dir,
|
||||
HttpConfig: nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, clientConn, err := createRawClient(mgmtAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
defer clientConn.Close()
|
||||
|
||||
// there are two peers already in the store, add two more
|
||||
peers, err := registerPeers(2, client)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
serverKey, err := getServerKey(client)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
concurrentPeerKey2 := peers[1]
|
||||
t.Log("Public key of concurrent peer: ", concurrentPeerKey2.PublicKey().String())
|
||||
|
||||
syncReq2 := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
|
||||
message2, err := encryption.EncryptMessage(*serverKey, *concurrentPeerKey2, syncReq2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx2, cancelFunc2 := context.WithCancel(context.Background())
|
||||
|
||||
//client.
|
||||
sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: concurrentPeerKey2.PublicKey().String(),
|
||||
Body: message2,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
resp2 := &mgmtProto.EncryptedMessage{}
|
||||
err = sync2.RecvMsg(resp2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
peerWithInvalidStatus := peers[0]
|
||||
t.Log("Public key of peer with invalid status: ", peerWithInvalidStatus.PublicKey().String())
|
||||
|
||||
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
|
||||
message, err := encryption.EncryptMessage(*serverKey, *peerWithInvalidStatus, syncReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
|
||||
//client.
|
||||
sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
|
||||
Body: message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
// take the first registered peer as a base for the test. Total four.
|
||||
|
||||
resp := &mgmtProto.EncryptedMessage{}
|
||||
err = sync.RecvMsg(resp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
cancelFunc2()
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
cancelFunc()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
ctx, cancelFunc = context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
sync, err = client.Sync(ctx, &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
|
||||
Body: message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
resp = &mgmtProto.EncryptedMessage{}
|
||||
err = sync.RecvMsg(resp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
if !peer.Status.Connected {
|
||||
t.Fatal("Peer should be connected")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -541,8 +542,13 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating metrics: %v", err)
|
||||
}
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating a manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ type MockAccountManager struct {
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error)
|
||||
@@ -105,14 +105,14 @@ type MockAccountManager struct {
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
if am.SyncAndMarkPeerFunc != nil {
|
||||
return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP)
|
||||
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peer *nbpeer.Peer) error {
|
||||
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -48,7 +48,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -95,7 +95,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
// SaveNameServerGroup saves nameserver group
|
||||
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if nsGroupToSave == nil {
|
||||
@@ -130,7 +130,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -160,7 +160,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
// ListNameServerGroups returns a list of nameserver groups from account
|
||||
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -762,7 +763,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (Store, error) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
@@ -65,12 +66,14 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
|
||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
||||
regularUser := !user.HasAdminPower() && !user.IsServiceUser
|
||||
|
||||
if regularUser && account.Settings.RegularUsersViewBlocked {
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != peer.UserID {
|
||||
if regularUser && user.Id != peer.UserID {
|
||||
// only display peers that belong to the current user if the current user is not an admin
|
||||
continue
|
||||
}
|
||||
@@ -79,9 +82,14 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
peersMap[peer.ID] = p
|
||||
}
|
||||
|
||||
if !regularUser {
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// fetch all the peers that have access to the user's peers
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
for _, peer := range peers {
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap, policyExpandedPeers)
|
||||
for _, p := range aclPeers {
|
||||
peersMap[p.ID] = p
|
||||
}
|
||||
@@ -150,7 +158,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
|
||||
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -272,7 +280,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
|
||||
|
||||
// DeletePeer removes peer from the account by its IP
|
||||
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -316,7 +324,9 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil, policyExpandedPeers), nil
|
||||
}
|
||||
|
||||
// GetPeerNetwork returns the Network for a given peer
|
||||
@@ -356,7 +366,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
@@ -380,7 +390,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow)
|
||||
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
||||
// and the peer disconnects with a timeout and tries to register again.
|
||||
// We just check if this machine has been registered before and reject the second registration.
|
||||
// The connecting peer should be able to recover with a retry.
|
||||
@@ -529,7 +539,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap)
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics(), policyExpandedPeers)
|
||||
return newPeer, networkMap, postureChecks, nil
|
||||
}
|
||||
|
||||
@@ -585,7 +597,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics(), policyExpandedPeers), postureChecks, nil
|
||||
}
|
||||
|
||||
// LoginPeer logs in or registers a peer.
|
||||
@@ -620,7 +634,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
@@ -732,7 +746,9 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics(), policyExpandedPeers), postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
|
||||
@@ -811,7 +827,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID st
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
@@ -846,7 +862,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID st
|
||||
|
||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -885,8 +901,9 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
for _, p := range userPeers {
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap)
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap, policyExpandedPeers)
|
||||
for _, aclPeer := range aclPeers {
|
||||
if aclPeer.ID == peerID {
|
||||
return peer, nil
|
||||
@@ -908,22 +925,45 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco
|
||||
// updateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
if am.metrics != nil {
|
||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start))
|
||||
}
|
||||
}()
|
||||
|
||||
peers := account.GetPeers()
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
dnsCache := &DNSConfigCache{}
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
expandedPolicies := account.GetPolicyExpandedPeers()
|
||||
for _, peer := range peers {
|
||||
if !am.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap)
|
||||
update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, p)
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics(), expandedPolicies)
|
||||
update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
@@ -241,7 +240,7 @@ func (p *Peer) FQDN(dnsDomain string) string {
|
||||
if dnsDomain == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
|
||||
return p.DNSLabel + "." + dnsDomain
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to the peer
|
||||
|
||||
31
management/server/peer/peer_test.go
Normal file
31
management/server/peer/peer_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// FQDNOld is the original implementation for benchmarking purposes
|
||||
func (p *Peer) FQDNOld(dnsDomain string) string {
|
||||
if dnsDomain == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
|
||||
}
|
||||
|
||||
func BenchmarkFQDN(b *testing.B) {
|
||||
p := &Peer{DNSLabel: "test-peer"}
|
||||
dnsDomain := "example.com"
|
||||
|
||||
b.Run("Old", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
p.FQDNOld(dnsDomain)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("New", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
p.FQDN(dnsDomain)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -2,15 +2,26 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestPeer_LoginExpired(t *testing.T) {
|
||||
@@ -633,3 +644,354 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
|
||||
b.Helper()
|
||||
|
||||
manager, err := createManager(b)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
regularUser := "regular_user"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[regularUser] = &User{
|
||||
Id: regularUser,
|
||||
Role: UserRoleUser,
|
||||
}
|
||||
|
||||
// Create peers
|
||||
for i := 0; i < peers; i++ {
|
||||
peerKey, _ := wgtypes.GeneratePrivateKey()
|
||||
peer := &nbpeer.Peer{
|
||||
ID: fmt.Sprintf("peer-%d", i),
|
||||
DNSLabel: fmt.Sprintf("peer-%d", i),
|
||||
Key: peerKey.PublicKey().String(),
|
||||
IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
UserID: regularUser,
|
||||
}
|
||||
account.Peers[peer.ID] = peer
|
||||
}
|
||||
|
||||
// Create groups and policies
|
||||
account.Policies = make([]*Policy, 0, groups)
|
||||
for i := 0; i < groups; i++ {
|
||||
groupID := fmt.Sprintf("group-%d", i)
|
||||
group := &nbgroup.Group{
|
||||
ID: groupID,
|
||||
Name: fmt.Sprintf("Group %d", i),
|
||||
}
|
||||
for j := 0; j < peers/groups; j++ {
|
||||
peerIndex := i*(peers/groups) + j
|
||||
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
|
||||
}
|
||||
account.Groups[groupID] = group
|
||||
|
||||
// Create a policy for this group
|
||||
policy := &Policy{
|
||||
ID: fmt.Sprintf("policy-%d", i),
|
||||
Name: fmt.Sprintf("Policy for Group %d", i),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: fmt.Sprintf("rule-%d", i),
|
||||
Name: fmt.Sprintf("Rule for Group %d", i),
|
||||
Enabled: true,
|
||||
Sources: []string{groupID},
|
||||
Destinations: []string{groupID},
|
||||
Bidirectional: true,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
account.Policies = append(account.Policies, policy)
|
||||
}
|
||||
|
||||
account.PostureChecks = []*posture.Checks{
|
||||
{
|
||||
ID: "PostureChecksAll",
|
||||
Name: "All",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
return manager, accountID, regularUser, nil
|
||||
}
|
||||
|
||||
func BenchmarkGetPeers(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
}{
|
||||
{"Small", 50, 5},
|
||||
{"Medium", 500, 10},
|
||||
{"Large", 5000, 20},
|
||||
{"Small single", 50, 1},
|
||||
{"Medium single", 500, 1},
|
||||
{"Large 5", 5000, 5},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := manager.GetPeers(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
b.Fatalf("GetPeers failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
}{
|
||||
{"Small", 50, 5},
|
||||
{"Medium", 500, 10},
|
||||
{"Large", 5000, 20},
|
||||
{"Small single", 50, 1},
|
||||
{"Medium single", 500, 1},
|
||||
{"Large 5", 5000, 5},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
}
|
||||
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
|
||||
b.ReportMetric(0, "ns/op")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToSyncResponse(t *testing.T) {
|
||||
_, ipnet, err := net.ParseCIDR("192.168.1.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
domainList, err := domain.FromStringList([]string{"example.com"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Signal: &Host{
|
||||
Proto: "https",
|
||||
URI: "signal.uri",
|
||||
Username: "",
|
||||
Password: "",
|
||||
},
|
||||
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
|
||||
TURNConfig: &TURNConfig{
|
||||
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
|
||||
},
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer-key",
|
||||
DNSLabel: "peer1",
|
||||
SSHKey: "peer1-ssh-key",
|
||||
}
|
||||
turnCredentials := &TURNCredentials{
|
||||
Username: "turn-user",
|
||||
Password: "turn-pass",
|
||||
}
|
||||
networkMap := &NetworkMap{
|
||||
Network: &Network{Net: *ipnet, Serial: 1000},
|
||||
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
|
||||
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
|
||||
Routes: []*nbroute.Route{
|
||||
{
|
||||
ID: "route1",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
Domains: domainList,
|
||||
KeepRoute: true,
|
||||
NetID: "route1",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
DNSConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
{
|
||||
ID: "ns1",
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Groups: []string{"group1"},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
},
|
||||
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
|
||||
},
|
||||
FirewallRules: []*FirewallRule{
|
||||
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
|
||||
},
|
||||
}
|
||||
dnsName := "example.com"
|
||||
checks := []*posture.Checks{
|
||||
{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dnsCache := &DNSConfigCache{}
|
||||
|
||||
response := toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache)
|
||||
|
||||
assert.NotNil(t, response)
|
||||
// assert peer config
|
||||
assert.Equal(t, "192.168.1.1/24", response.PeerConfig.Address)
|
||||
assert.Equal(t, "peer1.example.com", response.PeerConfig.Fqdn)
|
||||
assert.Equal(t, true, response.PeerConfig.SshConfig.SshEnabled)
|
||||
// assert wiretrustee config
|
||||
assert.Equal(t, "signal.uri", response.WiretrusteeConfig.Signal.Uri)
|
||||
assert.Equal(t, proto.HostConfig_HTTPS, response.WiretrusteeConfig.Signal.GetProtocol())
|
||||
assert.Equal(t, "stun.uri", response.WiretrusteeConfig.Stuns[0].Uri)
|
||||
assert.Equal(t, "turn.uri", response.WiretrusteeConfig.Turns[0].HostConfig.GetUri())
|
||||
assert.Equal(t, "turn-user", response.WiretrusteeConfig.Turns[0].User)
|
||||
assert.Equal(t, "turn-pass", response.WiretrusteeConfig.Turns[0].Password)
|
||||
// assert RemotePeers
|
||||
assert.Equal(t, 1, len(response.RemotePeers))
|
||||
assert.Equal(t, "192.168.1.2/32", response.RemotePeers[0].AllowedIps[0])
|
||||
assert.Equal(t, "peer2-key", response.RemotePeers[0].WgPubKey)
|
||||
assert.Equal(t, "peer2.example.com", response.RemotePeers[0].GetFqdn())
|
||||
assert.Equal(t, false, response.RemotePeers[0].GetSshConfig().GetSshEnabled())
|
||||
assert.Equal(t, []byte("peer2-ssh-key"), response.RemotePeers[0].GetSshConfig().GetSshPubKey())
|
||||
// assert network map
|
||||
assert.Equal(t, uint64(1000), response.NetworkMap.Serial)
|
||||
assert.Equal(t, "192.168.1.1/24", response.NetworkMap.PeerConfig.Address)
|
||||
assert.Equal(t, "peer1.example.com", response.NetworkMap.PeerConfig.Fqdn)
|
||||
assert.Equal(t, true, response.NetworkMap.PeerConfig.SshConfig.SshEnabled)
|
||||
// assert network map RemotePeers
|
||||
assert.Equal(t, 1, len(response.NetworkMap.RemotePeers))
|
||||
assert.Equal(t, "192.168.1.2/32", response.NetworkMap.RemotePeers[0].AllowedIps[0])
|
||||
assert.Equal(t, "peer2-key", response.NetworkMap.RemotePeers[0].WgPubKey)
|
||||
assert.Equal(t, "peer2.example.com", response.NetworkMap.RemotePeers[0].GetFqdn())
|
||||
assert.Equal(t, []byte("peer2-ssh-key"), response.NetworkMap.RemotePeers[0].GetSshConfig().GetSshPubKey())
|
||||
// assert network map OfflinePeers
|
||||
assert.Equal(t, 1, len(response.NetworkMap.OfflinePeers))
|
||||
assert.Equal(t, "192.168.1.3/32", response.NetworkMap.OfflinePeers[0].AllowedIps[0])
|
||||
assert.Equal(t, "peer3-key", response.NetworkMap.OfflinePeers[0].WgPubKey)
|
||||
assert.Equal(t, "peer3.example.com", response.NetworkMap.OfflinePeers[0].GetFqdn())
|
||||
assert.Equal(t, []byte("peer3-ssh-key"), response.NetworkMap.OfflinePeers[0].GetSshConfig().GetSshPubKey())
|
||||
// assert network map Routes
|
||||
assert.Equal(t, 1, len(response.NetworkMap.Routes))
|
||||
assert.Equal(t, "10.0.0.0/24", response.NetworkMap.Routes[0].Network)
|
||||
assert.Equal(t, "route1", response.NetworkMap.Routes[0].ID)
|
||||
assert.Equal(t, "peer1", response.NetworkMap.Routes[0].Peer)
|
||||
assert.Equal(t, "example.com", response.NetworkMap.Routes[0].Domains[0])
|
||||
assert.Equal(t, true, response.NetworkMap.Routes[0].KeepRoute)
|
||||
assert.Equal(t, true, response.NetworkMap.Routes[0].Masquerade)
|
||||
assert.Equal(t, int64(9999), response.NetworkMap.Routes[0].Metric)
|
||||
assert.Equal(t, int64(1), response.NetworkMap.Routes[0].NetworkType)
|
||||
assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID)
|
||||
// assert network map DNSConfig
|
||||
assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable)
|
||||
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones))
|
||||
assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups))
|
||||
// assert network map DNSConfig.CustomZones
|
||||
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Domain)
|
||||
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones[0].Records))
|
||||
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Name)
|
||||
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Type)
|
||||
assert.Equal(t, "IN", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Class)
|
||||
assert.Equal(t, int64(60), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].TTL)
|
||||
assert.Equal(t, "100.64.0.1", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData)
|
||||
// assert network map DNSConfig.NameServerGroups
|
||||
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].Primary)
|
||||
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].SearchDomainsEnabled)
|
||||
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.NameServerGroups[0].Domains[0])
|
||||
assert.Equal(t, "8.8.8.8", response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetIP())
|
||||
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetNSType())
|
||||
assert.Equal(t, int64(53), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetPort())
|
||||
// assert network map Firewall
|
||||
assert.Equal(t, 1, len(response.NetworkMap.FirewallRules))
|
||||
assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP)
|
||||
assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction)
|
||||
assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
|
||||
assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol)
|
||||
assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port)
|
||||
// assert posture checks
|
||||
assert.Equal(t, 1, len(response.Checks))
|
||||
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
|
||||
}
|
||||
|
||||
@@ -212,21 +212,20 @@ type FirewallRule struct {
|
||||
// getPeerConnectionResources for a given peer
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
|
||||
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}, expandedPolicies PolicyRuleExpandedPeers) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
for n, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap)
|
||||
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, expandedPolicies[policy.ID][n].sourcePeers, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, expandedPolicies[policy.ID][n].destinationPeers, peerID, nil, validatedPeersMap)
|
||||
|
||||
if rule.Bidirectional {
|
||||
if peerInSources {
|
||||
@@ -290,8 +289,8 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
||||
fr.PeerIP = "0.0.0.0"
|
||||
}
|
||||
|
||||
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ","))
|
||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
|
||||
if _, ok := rulesExists[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
@@ -315,7 +314,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
||||
|
||||
// GetPolicy from the store
|
||||
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -343,7 +342,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
||||
|
||||
// SavePolicy in the store
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -371,7 +370,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
|
||||
// DeletePolicy from the store
|
||||
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -398,7 +397,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
|
||||
// ListPolicies from the store
|
||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -491,38 +490,26 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
||||
//
|
||||
// Important: Posture checks are applicable only to source group peers,
|
||||
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
||||
func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
func (a *Account) getAllPeersFromGroups(ctx context.Context, peerMap peerMap, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
peerInGroups := false
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
group, ok := account.Groups[g]
|
||||
if !ok {
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(peerMap))
|
||||
for _, peer := range peerMap {
|
||||
|
||||
if _, ok := validatedPeersMap[peer.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, p := range group.Peers {
|
||||
peer, ok := account.Peers[p]
|
||||
if !ok || peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// validate the peer based on policy posture checks applied
|
||||
isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
if !isValid {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peer.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if peer.ID == peerID {
|
||||
peerInGroups = true
|
||||
continue
|
||||
}
|
||||
|
||||
filteredPeers = append(filteredPeers, peer)
|
||||
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
if !isValid {
|
||||
continue
|
||||
}
|
||||
|
||||
if peer.ID == peerID {
|
||||
peerInGroups = true
|
||||
continue
|
||||
}
|
||||
|
||||
filteredPeers = append(filteredPeers, peer)
|
||||
}
|
||||
return filteredPeers, peerInGroups
|
||||
}
|
||||
@@ -535,7 +522,7 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
|
||||
}
|
||||
|
||||
for _, postureChecksID := range sourcePostureChecksID {
|
||||
postureChecks := getPostureChecks(a, postureChecksID)
|
||||
postureChecks := a.getPostureChecks(postureChecksID)
|
||||
if postureChecks == nil {
|
||||
continue
|
||||
}
|
||||
@@ -553,11 +540,53 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
|
||||
return true
|
||||
}
|
||||
|
||||
func getPostureChecks(account *Account, postureChecksID string) *posture.Checks {
|
||||
for _, postureChecks := range account.PostureChecks {
|
||||
func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
||||
for _, postureChecks := range a.PostureChecks {
|
||||
if postureChecks.ID == postureChecksID {
|
||||
return postureChecks
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type expandedRuleGroups struct {
|
||||
sourcePeers peerMap
|
||||
destinationPeers peerMap
|
||||
}
|
||||
|
||||
type peerMap map[string]*nbpeer.Peer
|
||||
|
||||
// PolicyRuleExpandedPeers is a map with the peers of each policy rule source and destination groups
|
||||
type PolicyRuleExpandedPeers map[string]map[int]expandedRuleGroups
|
||||
|
||||
// GetPolicyExpandedPeers returns a map with the peers of each policy rule source and destination groups
|
||||
func (a *Account) GetPolicyExpandedPeers() PolicyRuleExpandedPeers {
|
||||
policyMap := make(PolicyRuleExpandedPeers)
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
ruleMap := make(map[int]expandedRuleGroups)
|
||||
policyMap[policy.ID] = ruleMap
|
||||
for ruleID, rule := range policy.Rules {
|
||||
policyMap[policy.ID][ruleID] = expandedRuleGroups{
|
||||
sourcePeers: make(peerMap),
|
||||
destinationPeers: make(peerMap),
|
||||
}
|
||||
a.processGroups(rule.Sources, policyMap[policy.ID][ruleID].sourcePeers)
|
||||
a.processGroups(rule.Destinations, policyMap[policy.ID][ruleID].destinationPeers)
|
||||
}
|
||||
}
|
||||
return policyMap
|
||||
}
|
||||
|
||||
func (a *Account) processGroups(groupIDs []string, peerMap peerMap) {
|
||||
for _, gid := range groupIDs {
|
||||
for _, pid := range a.Groups[gid].Peers {
|
||||
p, ok := a.Peers[pid]
|
||||
if ok {
|
||||
peerMap[pid] = p
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,15 +143,17 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("check that all peers get map", func(t *testing.T) {
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
for _, p := range account.Peers {
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers)
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers, policyExpandedPeers)
|
||||
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
|
||||
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check first peer map details", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers)
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers, policyExpandedPeers)
|
||||
assert.Len(t, peers, 7)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
@@ -387,7 +389,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("check first peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
@@ -415,7 +418,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check second peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
@@ -445,7 +449,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
account.Policies[1].Rules[0].Bidirectional = false
|
||||
|
||||
t.Run("check first peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
@@ -466,7 +471,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check second peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
@@ -661,9 +667,10 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
approvedPeers[p] = struct{}{}
|
||||
}
|
||||
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
||||
// will establish a connection with all source peers satisfying the NB posture check.
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -673,7 +680,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers)
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, 1)
|
||||
expectedFirewallRules := []*FirewallRule{
|
||||
@@ -689,7 +696,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers, policyExpandedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -699,7 +706,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers, policyExpandedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@@ -711,22 +718,23 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
t.Run("verify peer's network map with modified group peer list", func(t *testing.T) {
|
||||
// Removing peerB as the part of destination group Swarm
|
||||
account.Groups["GroupSwarm"].Peers = []string{"peerA", "peerD", "peerE", "peerG", "peerH"}
|
||||
policyExpandedPeers := account.GetPolicyExpandedPeers()
|
||||
|
||||
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers)
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers, policyExpandedPeers)
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers)
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
||||
|
||||
@@ -738,17 +746,18 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// Removing peerF as the part of source group All
|
||||
account.Groups["GroupAll"].Peers = []string{"peerB", "peerA", "peerD", "peerC", "peerG", "peerH"}
|
||||
policyExpandedPeers = account.GetPolicyExpandedPeers()
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers, policyExpandedPeers)
|
||||
assert.Len(t, peers, 3)
|
||||
assert.Len(t, firewallRules, 3)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers, policyExpandedPeers)
|
||||
assert.Len(t, peers, 5)
|
||||
// assert peers from Group Swarm
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
@@ -15,7 +15,7 @@ const (
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -42,7 +42,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -89,7 +89,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -121,7 +121,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
|
||||
// GetRoute gets a route object from account and route IDs
|
||||
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -126,7 +126,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
|
||||
|
||||
// CreateRoute creates and saves a new route
|
||||
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -214,7 +214,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
|
||||
// SaveRoute saves route
|
||||
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if routeToSave == nil {
|
||||
@@ -283,7 +283,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -311,7 +311,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -1233,7 +1234,11 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
}
|
||||
|
||||
func createRouterStore(t *testing.T) (Store, error) {
|
||||
|
||||
@@ -210,7 +210,7 @@ func Hash(s string) uint32 {
|
||||
// and adds it to the specified account. A list of autoGroups IDs can be empty.
|
||||
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
keyDuration := DefaultSetupKeyDuration
|
||||
@@ -256,7 +256,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
// (e.g. the key itself, creation date, ID, etc).
|
||||
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
||||
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if keyToSave == nil {
|
||||
@@ -328,7 +328,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
|
||||
// ListSetupKeys returns a list of all setup keys of the account
|
||||
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -360,7 +360,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
|
||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
||||
@@ -40,7 +40,7 @@ const (
|
||||
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
||||
type SqlStore struct {
|
||||
db *gorm.DB
|
||||
accountLocks sync.Map
|
||||
resourceLocks sync.Map
|
||||
globalAccountLock sync.Mutex
|
||||
metrics telemetry.AppMetrics
|
||||
installationPK int
|
||||
@@ -98,33 +98,35 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqlStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring write lock for account %s", accountID)
|
||||
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Tracef("released write lock for account %s in %v", accountID, time.Since(start))
|
||||
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqlStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring read lock for account %s", accountID)
|
||||
// AcquireReadLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.RLock()
|
||||
|
||||
unlock = func() {
|
||||
mtx.RUnlock()
|
||||
log.WithContext(ctx).Tracef("released read lock for account %s in %v", accountID, time.Since(start))
|
||||
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
|
||||
@@ -49,10 +49,10 @@ type Store interface {
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
GetInstallationID() string
|
||||
SaveInstallationID(ctx context.Context, ID string) error
|
||||
// AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock
|
||||
AcquireAccountWriteLock(ctx context.Context, accountID string) func()
|
||||
// AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock
|
||||
AcquireAccountReadLock(ctx context.Context, accountID string) func()
|
||||
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
|
||||
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
|
||||
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
|
||||
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
|
||||
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
||||
AcquireGlobalLock(ctx context.Context) func()
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
|
||||
69
management/server/telemetry/accountmanager_metrics.go
Normal file
69
management/server/telemetry/accountmanager_metrics.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
// AccountManagerMetrics represents all metrics related to the AccountManager
|
||||
type AccountManagerMetrics struct {
|
||||
ctx context.Context
|
||||
updateAccountPeersDurationMs metric.Float64Histogram
|
||||
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||
networkMapObjectCount metric.Int64Histogram
|
||||
}
|
||||
|
||||
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
|
||||
func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*AccountManagerMetrics, error) {
|
||||
updateAccountPeersDurationMs, err := meter.Float64Histogram("management.account.update.account.peers.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithExplicitBucketBoundaries(
|
||||
0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 30000,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getPeerNetworkMapDurationMs, err := meter.Float64Histogram("management.account.get.peer.network.map.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithExplicitBucketBoundaries(
|
||||
0.1, 0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
networkMapObjectCount, err := meter.Int64Histogram("management.account.network.map.object.count",
|
||||
metric.WithUnit("objects"),
|
||||
metric.WithExplicitBucketBoundaries(
|
||||
50, 100, 200, 500, 1000, 2500, 5000, 10000,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountManagerMetrics{
|
||||
ctx: ctx,
|
||||
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
|
||||
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
|
||||
networkMapObjectCount: networkMapObjectCount,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// CountUpdateAccountPeersDuration counts the duration of updating account peers
|
||||
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
|
||||
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
}
|
||||
|
||||
// CountGetPeerNetworkMapDuration counts the duration of getting the peer network map
|
||||
func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration time.Duration) {
|
||||
metrics.getPeerNetworkMapDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
}
|
||||
|
||||
// CountNetworkMapObjects counts the number of network map objects
|
||||
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
|
||||
metrics.networkMapObjectCount.Record(metrics.ctx, count)
|
||||
}
|
||||
@@ -20,14 +20,15 @@ const defaultEndpoint = "/metrics"
|
||||
|
||||
// MockAppMetrics mocks the AppMetrics interface
|
||||
type MockAppMetrics struct {
|
||||
GetMeterFunc func() metric2.Meter
|
||||
CloseFunc func() error
|
||||
ExposeFunc func(ctx context.Context, port int, endpoint string) error
|
||||
IDPMetricsFunc func() *IDPMetrics
|
||||
HTTPMiddlewareFunc func() *HTTPMiddleware
|
||||
GRPCMetricsFunc func() *GRPCMetrics
|
||||
StoreMetricsFunc func() *StoreMetrics
|
||||
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
|
||||
GetMeterFunc func() metric2.Meter
|
||||
CloseFunc func() error
|
||||
ExposeFunc func(ctx context.Context, port int, endpoint string) error
|
||||
IDPMetricsFunc func() *IDPMetrics
|
||||
HTTPMiddlewareFunc func() *HTTPMiddleware
|
||||
GRPCMetricsFunc func() *GRPCMetrics
|
||||
StoreMetricsFunc func() *StoreMetrics
|
||||
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
|
||||
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
|
||||
}
|
||||
|
||||
// GetMeter mocks the GetMeter function of the AppMetrics interface
|
||||
@@ -94,6 +95,14 @@ func (mock *MockAppMetrics) UpdateChannelMetrics() *UpdateChannelMetrics {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AccountManagerMetrics mocks the MockAppMetrics function of the AccountManagerMetrics interface
|
||||
func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
|
||||
if mock.AddAccountManagerMetricsFunc != nil {
|
||||
return mock.AddAccountManagerMetricsFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AppMetrics is metrics interface
|
||||
type AppMetrics interface {
|
||||
GetMeter() metric2.Meter
|
||||
@@ -104,19 +113,21 @@ type AppMetrics interface {
|
||||
GRPCMetrics() *GRPCMetrics
|
||||
StoreMetrics() *StoreMetrics
|
||||
UpdateChannelMetrics() *UpdateChannelMetrics
|
||||
AccountManagerMetrics() *AccountManagerMetrics
|
||||
}
|
||||
|
||||
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
|
||||
type defaultAppMetrics struct {
|
||||
// Meter can be used by different application parts to create counters and measure things
|
||||
Meter metric2.Meter
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
idpMetrics *IDPMetrics
|
||||
httpMiddleware *HTTPMiddleware
|
||||
grpcMetrics *GRPCMetrics
|
||||
storeMetrics *StoreMetrics
|
||||
updateChannelMetrics *UpdateChannelMetrics
|
||||
Meter metric2.Meter
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
idpMetrics *IDPMetrics
|
||||
httpMiddleware *HTTPMiddleware
|
||||
grpcMetrics *GRPCMetrics
|
||||
storeMetrics *StoreMetrics
|
||||
updateChannelMetrics *UpdateChannelMetrics
|
||||
accountManagerMetrics *AccountManagerMetrics
|
||||
}
|
||||
|
||||
// IDPMetrics returns metrics for the idp package
|
||||
@@ -144,6 +155,11 @@ func (appMetrics *defaultAppMetrics) UpdateChannelMetrics() *UpdateChannelMetric
|
||||
return appMetrics.updateChannelMetrics
|
||||
}
|
||||
|
||||
// AccountManagerMetrics returns metrics for the account manager
|
||||
func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
|
||||
return appMetrics.accountManagerMetrics
|
||||
}
|
||||
|
||||
// Close stop application metrics HTTP handler and closes listener.
|
||||
func (appMetrics *defaultAppMetrics) Close() error {
|
||||
if appMetrics.listener == nil {
|
||||
@@ -220,13 +236,19 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &defaultAppMetrics{
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
idpMetrics: idpMetrics,
|
||||
httpMiddleware: middleware,
|
||||
grpcMetrics: grpcMetrics,
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
idpMetrics: idpMetrics,
|
||||
httpMiddleware: middleware,
|
||||
grpcMetrics: grpcMetrics,
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
accountManagerMetrics: accountManagerMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -211,7 +211,7 @@ func NewOwnerUser(id string) *User {
|
||||
|
||||
// createServiceUser creates a new service user under the given account.
|
||||
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -267,7 +267,7 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user
|
||||
|
||||
// inviteNewUser Invites a USer to a given account and creates reference in datastore
|
||||
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
@@ -368,7 +368,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
|
||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
account, err = am.Store.GetAccount(ctx, account.Id)
|
||||
@@ -401,7 +401,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
|
||||
// ListUsers returns lists of all users under the account.
|
||||
// It doesn't populate user information such as email or name.
|
||||
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -428,7 +428,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
if initiatorUserID == targetUserID {
|
||||
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
|
||||
}
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -538,7 +538,7 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
|
||||
|
||||
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
|
||||
func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
@@ -578,7 +578,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
||||
|
||||
// CreatePAT creates a new PAT for the given user
|
||||
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if tokenName == "" {
|
||||
@@ -628,7 +628,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
||||
|
||||
// DeletePAT deletes a specific PAT from a user
|
||||
func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -678,7 +678,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
|
||||
|
||||
// GetPAT returns a specific PAT from a user
|
||||
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -710,7 +710,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
|
||||
|
||||
// GetAllPATs returns all PATs for a user
|
||||
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@@ -752,7 +752,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
|
||||
@@ -859,7 +859,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveUsers(account.Id, account.Users); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -35,8 +35,11 @@ func InitLog(logLevel string, logPath string) error {
|
||||
AddSyslogHook()
|
||||
}
|
||||
|
||||
//nolint:gocritic
|
||||
if os.Getenv("NB_LOG_FORMAT") == "json" {
|
||||
formatter.SetJSONFormatter(log.StandardLogger())
|
||||
} else if logPath == "syslog" {
|
||||
formatter.SetSyslogFormatter(log.StandardLogger())
|
||||
} else {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user