Compare commits

...

31 Commits

Author SHA1 Message Date
Pascal Fischer
d063fbb8b9 [management] merge update account peers in sync call (#2978) 2024-12-03 16:41:19 +01:00
Viktor Liu
e5d42bc963 [client] Add state handling cmdline options (#2821) 2024-12-03 16:07:18 +01:00
Viktor Liu
8866394eb6 [client] Don't choke on non-existent interface in route updates (#2922) 2024-12-03 15:33:41 +01:00
Viktor Liu
17c20b45ce [client] Add network map to debug bundle (#2966) 2024-12-03 14:50:12 +01:00
Joakim Nohlgård
7dacd9cb23 [management] Add missing parentheses on iphone hostname generation condition (#2977) 2024-12-03 13:49:02 +01:00
Viktor Liu
6285e0d23e [client] Add netbird.err and netbird.out to debug bundle (#2971) 2024-12-03 12:43:17 +01:00
Maycon Santos
a4826cfb5f [client] Get static system info once (#2965)
Get static system info once for Windows, Darwin, and Linux nodes

This should improve startup and peer authentication times
2024-12-03 10:22:04 +01:00
Zoltan Papp
a0bf0bdcc0 Pass IP instead of net to Rosenpass (#2975) 2024-12-03 10:13:27 +01:00
Viktor Liu
dffce78a8c [client] Fix debug bundle state anonymization test (#2976) 2024-12-02 20:19:34 +01:00
Viktor Liu
c7e7ad5030 [client] Add state file to debug bundle (#2969) 2024-12-02 18:04:02 +01:00
Viktor Liu
5142dc52c1 [client] Persist route selection (#2810) 2024-12-02 17:55:02 +01:00
Zoltan Papp
ecb44ff306 [client] Add pprof build tag (#2964)
* Add pprof build tag

* Change env handling
2024-12-01 19:22:52 +01:00
victorserbu2709
e4a5fb3e91 Unspecified address: default NetworkTypeUDP4+NetworkTypeUDP6 (#2804) 2024-11-30 10:34:52 +01:00
v1rusnl
e52d352a48 Update Caddyfile and Docker Compose to support HTTP3 (#2822) 2024-11-30 10:26:31 +01:00
Maycon Santos
f9723c9266 [client] Account different policiy rules for routes firewall rules (#2939)
* Account different policies rules for routes firewall rules

This change ensures that route firewall rules will consider source group peers in the rules generation for access control policies.

This fixes the behavior where multiple policies with different levels of access was being applied to all peers in a distribution group

* split function

* avoid unnecessary allocation

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>

---------

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2024-11-29 17:50:35 +01:00
Maycon Santos
8efad1d170 Add guide when signing key is not found (#2942)
Some users face issues with their IdP due to signing key not being refreshed

With this change we advise users to configure key refresh

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* removing leftover

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2024-11-29 10:06:40 +01:00
Pascal Fischer
c6641be94b [tests] Enable benchmark tests on github actions (#2961) 2024-11-28 19:22:01 +01:00
Pascal Fischer
89cf8a55e2 [management] Add performance test for login and sync calls (#2960) 2024-11-28 14:59:53 +01:00
Pascal Fischer
00c3b67182 [management] refactor to use account object instead of separate db calls for peer update (#2957) 2024-11-28 11:13:01 +01:00
Zoltan Papp
9203690033 [client] Code cleaning in net pkg and fix exit node feature on Android(#2932)
Code cleaning around the util/net package. The goal was to write a more understandable source code but modify nothing on the logic.
Protect the WireGuard UDP listeners with marks.
The implementation can support the VPN permission revocation events in thread safe way. It will be important if we start to support the running time route and DNS update features.

- uniformize the file name convention: [struct_name] _ [functions] _ [os].go
- code cleaning in net_linux.go
- move env variables to env.go file
2024-11-26 23:34:27 +01:00
Bethuel Mmbaga
9683da54b0 [management] Refactor nameserver groups to use store methods (#2888) 2024-11-26 17:39:04 +01:00
Bethuel Mmbaga
0e48a772ff [management] Refactor DNS settings to use store methods (#2883) 2024-11-26 13:43:05 +01:00
Bethuel Mmbaga
f118d81d32 [management] Refactor policy to use store methods (#2878) 2024-11-26 10:46:05 +01:00
Bethuel Mmbaga
ca12bc6953 [management] Refactor posture check to use store methods (#2874) 2024-11-25 16:26:24 +01:00
Viktor Liu
9810386937 [client] Allow routing to fallback to exclusion routes if rules are not supported (#2909) 2024-11-25 15:19:56 +01:00
Viktor Liu
f1625b32bd [client] Set up sysctl and routing table name only if routing rules are available (#2933) 2024-11-25 15:12:16 +01:00
Viktor Liu
0ecd5f2118 [client] Test nftables for incompatible iptables rules (#2948) 2024-11-25 15:11:56 +01:00
Viktor Liu
940d0c48c6 [client] Don't return error in userspace mode without firewall (#2924) 2024-11-25 15:11:31 +01:00
Maycon Santos
56cecf849e Import time package (#2940) 2024-11-22 20:40:30 +01:00
Maycon Santos
05c4aa7c2c [misc] Renew slack link (#2938) 2024-11-22 18:50:47 +01:00
Zoltan Papp
2a5cb16494 [relay] Refactor initial Relay connection (#2800)
Can support firewalls with restricted WS rules

allow to run engine without Relay servers
keep up to date Relay address changes
2024-11-22 18:12:34 +01:00
98 changed files with 5664 additions and 1395 deletions

View File

@@ -52,6 +52,47 @@ jobs:
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
benchmark:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
test_client_on_docker: test_client_on_docker:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
steps: steps:

View File

@@ -17,7 +17,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
<br> <br>
@@ -34,7 +34,7 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a> Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
<br/> <br/>
</strong> </strong>

View File

@@ -12,6 +12,8 @@ import (
"strings" "strings"
) )
const anonTLD = ".domain"
type Anonymizer struct { type Anonymizer struct {
ipAnonymizer map[netip.Addr]netip.Addr ipAnonymizer map[netip.Addr]netip.Addr
domainAnonymizer map[string]string domainAnonymizer map[string]string
@@ -83,29 +85,39 @@ func (a *Anonymizer) AnonymizeIPString(ip string) string {
} }
func (a *Anonymizer) AnonymizeDomain(domain string) string { func (a *Anonymizer) AnonymizeDomain(domain string) string {
if strings.HasSuffix(domain, "netbird.io") || baseDomain := domain
strings.HasSuffix(domain, "netbird.selfhosted") || hasDot := strings.HasSuffix(domain, ".")
strings.HasSuffix(domain, "netbird.cloud") || if hasDot {
strings.HasSuffix(domain, "netbird.stage") || baseDomain = domain[:len(domain)-1]
strings.HasSuffix(domain, ".domain") { }
if strings.HasSuffix(baseDomain, "netbird.io") ||
strings.HasSuffix(baseDomain, "netbird.selfhosted") ||
strings.HasSuffix(baseDomain, "netbird.cloud") ||
strings.HasSuffix(baseDomain, "netbird.stage") ||
strings.HasSuffix(baseDomain, anonTLD) {
return domain return domain
} }
parts := strings.Split(domain, ".") parts := strings.Split(baseDomain, ".")
if len(parts) < 2 { if len(parts) < 2 {
return domain return domain
} }
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1] baseForLookup := parts[len(parts)-2] + "." + parts[len(parts)-1]
anonymized, ok := a.domainAnonymizer[baseDomain] anonymized, ok := a.domainAnonymizer[baseForLookup]
if !ok { if !ok {
anonymizedBase := "anon-" + generateRandomString(5) + ".domain" anonymizedBase := "anon-" + generateRandomString(5) + anonTLD
a.domainAnonymizer[baseDomain] = anonymizedBase a.domainAnonymizer[baseForLookup] = anonymizedBase
anonymized = anonymizedBase anonymized = anonymizedBase
} }
return strings.Replace(domain, baseDomain, anonymized, 1) result := strings.Replace(baseDomain, baseForLookup, anonymized, 1)
if hasDot {
result += "."
}
return result
} }
func (a *Anonymizer) AnonymizeURI(uri string) string { func (a *Anonymizer) AnonymizeURI(uri string) string {
@@ -152,9 +164,9 @@ func (a *Anonymizer) AnonymizeString(str string) string {
return str return str
} }
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes. // AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string { func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`) re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`)
return re.ReplaceAllStringFunc(text, a.AnonymizeURI) return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
} }
@@ -168,10 +180,10 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
parts := strings.Split(match, `"`) parts := strings.Split(match, `"`)
if len(parts) >= 2 { if len(parts) >= 2 {
domain := parts[1] domain := parts[1]
if strings.HasSuffix(domain, ".domain") { if strings.HasSuffix(domain, anonTLD) {
return match return match
} }
randomDomain := generateRandomString(10) + ".domain" randomDomain := generateRandomString(10) + anonTLD
return strings.Replace(match, domain, randomDomain, 1) return strings.Replace(match, domain, randomDomain, 1)
} }
return match return match

View File

@@ -67,18 +67,36 @@ func TestAnonymizeDomain(t *testing.T) {
`^anon-[a-zA-Z0-9]+\.domain$`, `^anon-[a-zA-Z0-9]+\.domain$`,
true, true,
}, },
{
"Domain with Trailing Dot",
"example.com.",
`^anon-[a-zA-Z0-9]+\.domain.$`,
true,
},
{ {
"Subdomain", "Subdomain",
"sub.example.com", "sub.example.com",
`^sub\.anon-[a-zA-Z0-9]+\.domain$`, `^sub\.anon-[a-zA-Z0-9]+\.domain$`,
true, true,
}, },
{
"Subdomain with Trailing Dot",
"sub.example.com.",
`^sub\.anon-[a-zA-Z0-9]+\.domain.$`,
true,
},
{ {
"Protected Domain", "Protected Domain",
"netbird.io", "netbird.io",
`^netbird\.io$`, `^netbird\.io$`,
false, false,
}, },
{
"Protected Domain with Trailing Dot",
"netbird.io.",
`^netbird\.io.$`,
false,
},
} }
for _, tc := range tests { for _, tc := range tests {
@@ -140,8 +158,16 @@ func TestAnonymizeSchemeURI(t *testing.T) {
expect string expect string
}{ }{
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`}, {"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
{"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`},
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`}, {"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
{"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`},
{"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`},
{"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`}, {"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
{"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`},
{"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`},
{"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`},
{"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`},
} }
for _, tc := range tests { for _, tc := range tests {

View File

@@ -3,6 +3,7 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -61,6 +62,15 @@ var forCmd = &cobra.Command{
RunE: runForDuration, RunE: runForDuration,
} }
var persistenceCmd = &cobra.Command{
Use: "persistence [on|off]",
Short: "Set network map memory persistence",
Long: `Configure whether the latest network map should persist in memory. When enabled, the last known network map will be kept in memory.`,
Example: " netbird debug persistence on",
Args: cobra.ExactArgs(1),
RunE: setNetworkMapPersistence,
}
func debugBundle(cmd *cobra.Command, _ []string) error { func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd)
if err != nil { if err != nil {
@@ -171,6 +181,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
// Enable network map persistence before bringing the service up
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("failed to enable network map persistence: %v", status.Convert(err).Message())
}
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
} }
@@ -200,6 +217,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
// Disable network map persistence after creating the debug bundle
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: false,
}); err != nil {
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
}
if stateWasDown { if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
@@ -219,6 +243,34 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
persistence := strings.ToLower(args[0])
if persistence != "on" && persistence != "off" {
return fmt.Errorf("invalid persistence value: %s. Use 'on' or 'off'", args[0])
}
client := proto.NewDaemonServiceClient(conn)
_, err = client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: persistence == "on",
})
if err != nil {
return fmt.Errorf("failed to set network map persistence: %v", status.Convert(err).Message())
}
cmd.Printf("Network map persistence set to: %s\n", persistence)
return nil
}
func getStatusOutput(cmd *cobra.Command) string { func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string var statusOutputString string
statusResp, err := getStatus(cmd.Context()) statusResp, err := getStatus(cmd.Context())

33
client/cmd/pprof.go Normal file
View File

@@ -0,0 +1,33 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

View File

@@ -155,6 +155,7 @@ func init() {
debugCmd.AddCommand(logCmd) debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd) logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd) debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil, upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+ `Sets external IPs maps between local addresses and interfaces.`+

181
client/cmd/state.go Normal file
View File

@@ -0,0 +1,181 @@
package cmd
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var (
allFlag bool
)
var stateCmd = &cobra.Command{
Use: "state",
Short: "Manage daemon state",
Long: "Provides commands for managing and inspecting the Netbird daemon state.",
}
var stateListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all stored states",
Long: "Lists all registered states with their status and basic information.",
Example: " netbird state list",
RunE: stateList,
}
var stateCleanCmd = &cobra.Command{
Use: "clean [state-name]",
Short: "Clean stored states",
Long: `Clean specific state or all states. The daemon must not be running.
This will perform cleanup operations and remove the state.`,
Example: ` netbird state clean dns_state
netbird state clean --all`,
RunE: stateClean,
PreRunE: func(cmd *cobra.Command, args []string) error {
// Check mutual exclusivity between --all flag and state-name argument
if allFlag && len(args) > 0 {
return fmt.Errorf("cannot specify both --all flag and state name")
}
if !allFlag && len(args) != 1 {
return fmt.Errorf("requires a state name argument or --all flag")
}
return nil
},
}
var stateDeleteCmd = &cobra.Command{
Use: "delete [state-name]",
Short: "Delete stored states",
Long: `Delete specific state or all states from storage. The daemon must not be running.
This will remove the state without performing any cleanup operations.`,
Example: ` netbird state delete dns_state
netbird state delete --all`,
RunE: stateDelete,
PreRunE: func(cmd *cobra.Command, args []string) error {
// Check mutual exclusivity between --all flag and state-name argument
if allFlag && len(args) > 0 {
return fmt.Errorf("cannot specify both --all flag and state name")
}
if !allFlag && len(args) != 1 {
return fmt.Errorf("requires a state name argument or --all flag")
}
return nil
},
}
func init() {
rootCmd.AddCommand(stateCmd)
stateCmd.AddCommand(stateListCmd, stateCleanCmd, stateDeleteCmd)
stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states")
stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states")
}
func stateList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListStates(cmd.Context(), &proto.ListStatesRequest{})
if err != nil {
return fmt.Errorf("failed to list states: %v", status.Convert(err).Message())
}
cmd.Printf("\nStored states:\n\n")
for _, state := range resp.States {
cmd.Printf("- %s\n", state.Name)
}
return nil
}
func stateClean(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {
stateName = args[0]
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.CleanState(cmd.Context(), &proto.CleanStateRequest{
StateName: stateName,
All: allFlag,
})
if err != nil {
return fmt.Errorf("failed to clean state: %v", status.Convert(err).Message())
}
if resp.CleanedStates == 0 {
cmd.Println("No states were cleaned")
return nil
}
if allFlag {
cmd.Printf("Successfully cleaned %d states\n", resp.CleanedStates)
} else {
cmd.Printf("Successfully cleaned state %q\n", stateName)
}
return nil
}
func stateDelete(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {
stateName = args[0]
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DeleteState(cmd.Context(), &proto.DeleteStateRequest{
StateName: stateName,
All: allFlag,
})
if err != nil {
return fmt.Errorf("failed to delete state: %v", status.Convert(err).Message())
}
if resp.DeletedStates == 0 {
cmd.Println("No states were deleted")
return nil
}
if allFlag {
cmd.Printf("Successfully deleted %d states\n", resp.DeletedStates)
} else {
cmd.Printf("Successfully deleted state %q\n", stateName)
}
return nil
}

View File

@@ -37,6 +37,11 @@ func (s *ipList) UnmarshalJSON(data []byte) error {
return err return err
} }
s.ips = temp.IPs s.ips = temp.IPs
if temp.IPs == nil {
temp.IPs = make(map[string]struct{})
}
return nil return nil
} }
@@ -89,5 +94,10 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error {
return err return err
} }
s.ipsets = temp.IPSets s.ipsets = temp.IPSets
if temp.IPSets == nil {
temp.IPSets = make(map[string]*ipList)
}
return nil return nil
} }

View File

@@ -1,9 +1,11 @@
package nftables package nftables
import ( import (
"bytes"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os/exec"
"testing" "testing"
"time" "time"
@@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) {
}) })
} }
} }
func runIptablesSave(t *testing.T) (string, string) {
t.Helper()
var stdout, stderr bytes.Buffer
cmd := exec.Command("iptables-save")
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
require.NoError(t, err, "iptables-save failed to run")
return stdout.String(), stderr.String()
}
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
t.Helper()
// Check for any incompatibility warnings
require.NotContains(t,
stderr,
"incompatible",
"iptables-save produced compatibility warning. Full stderr: %s",
stderr,
)
// Verify standard tables are present
expectedTables := []string{
"*filter",
"*nat",
"*mangle",
}
for _, table := range expectedTables {
require.Contains(t,
stdout,
table,
"iptables-save output missing expected table: %s\nFull stdout: %s",
table,
stdout,
)
}
}
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Reset(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"test rule",
)
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"),
Destination: netip.MustParsePrefix("10.0.0.0/24"),
Masquerade: true,
}
err = manager.AddNatRule(pair)
require.NoError(t, err, "failed to add NAT rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}

View File

@@ -1 +0,0 @@
package nftables

View File

@@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
// SetLegacyManagement doesn't need to be implemented for this manager // SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(isLegacy bool) error { func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errRouteNotSupported return nil
} }
return m.nativeFirewall.SetLegacyManagement(isLegacy) return m.nativeFirewall.SetLegacyManagement(isLegacy)
} }

View File

@@ -0,0 +1,12 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -162,12 +162,13 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType var networks []ice.NetworkType
switch { switch {
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
case addr.IP.To16() != nil: case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6} networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default: default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr()) params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
} }

View File

@@ -40,6 +40,8 @@ type ConnectClient struct {
statusRecorder *peer.Status statusRecorder *peer.Status
engine *Engine engine *Engine
engineMutex sync.Mutex engineMutex sync.Mutex
persistNetworkMap bool
} }
func NewConnectClient( func NewConnectClient(
@@ -232,6 +234,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
relayURLs, token := parseRelayInfo(loginResp) relayURLs, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String()) relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
c.statusRecorder.SetRelayMgr(relayManager)
if len(relayURLs) > 0 { if len(relayURLs) > 0 {
if token != nil { if token != nil {
if err := relayManager.UpdateToken(token); err != nil { if err := relayManager.UpdateToken(token); err != nil {
@@ -242,9 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", ")) log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
if err = relayManager.Serve(); err != nil { if err = relayManager.Serve(); err != nil {
log.Error(err) log.Error(err)
return wrapErr(err)
} }
c.statusRecorder.SetRelayMgr(relayManager)
} }
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
@@ -259,7 +260,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
c.engineMutex.Lock() c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
c.engineMutex.Unlock() c.engineMutex.Unlock()
if err := c.engine.Start(); err != nil { if err := c.engine.Start(); err != nil {
@@ -337,6 +338,19 @@ func (c *ConnectClient) Engine() *Engine {
return e return e
} }
// Status returns the current client status
func (c *ConnectClient) Status() StatusType {
if c == nil {
return StatusIdle
}
status, err := CtxGetState(c.ctx).Status()
if err != nil {
return StatusIdle
}
return status
}
func (c *ConnectClient) Stop() error { func (c *ConnectClient) Stop() error {
if c == nil { if c == nil {
return nil return nil
@@ -363,6 +377,22 @@ func (c *ConnectClient) isContextCancelled() bool {
} }
} }
// SetNetworkMapPersistence enables or disables network map persistence.
// When enabled, the last received network map will be stored and can be retrieved
// through the Engine's getLatestNetworkMap method. When disabled, any stored
// network map will be cleared. This functionality is primarily used for debugging
// and should not be enabled during normal operation.
func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
c.engineMutex.Lock()
c.persistNetworkMap = enabled
c.engineMutex.Unlock()
engine := c.Engine()
if engine != nil {
engine.SetNetworkMapPersistence(enabled)
}
}
// createEngineConfig converts configuration received from Management Service to EngineConfig // createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false nm := false

View File

@@ -21,6 +21,7 @@ import (
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/proto"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
@@ -172,6 +173,10 @@ type Engine struct {
relayManager *relayClient.Manager relayManager *relayClient.Manager
stateManager *statemanager.Manager stateManager *statemanager.Manager
srWatcher *guard.SRWatcher srWatcher *guard.SRWatcher
// Network map persistence
persistNetworkMap bool
latestNetworkMap *mgmProto.NetworkMap
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -349,8 +354,17 @@ func (e *Engine) Start() error {
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) e.routeManager = routemanager.NewManager(
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager) e.ctx,
e.config.WgPrivateKey.PublicKey().String(),
e.config.DNSRouteInterval,
e.wgInterface,
e.statusRecorder,
e.relayManager,
initialRoutes,
e.stateManager,
)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil { if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
} else { } else {
@@ -538,6 +552,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
relayMsg := wCfg.GetRelay() relayMsg := wCfg.GetRelay()
if relayMsg != nil { if relayMsg != nil {
// when we receive token we expect valid address list too
c := &auth.Token{ c := &auth.Token{
Payload: relayMsg.GetTokenPayload(), Payload: relayMsg.GetTokenPayload(),
Signature: relayMsg.GetTokenSignature(), Signature: relayMsg.GetTokenSignature(),
@@ -546,9 +561,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
log.Errorf("failed to update relay token: %v", err) log.Errorf("failed to update relay token: %v", err)
return fmt.Errorf("update relay token: %w", err) return fmt.Errorf("update relay token: %w", err)
} }
e.relayManager.UpdateServerURLs(relayMsg.Urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries.
_ = e.relayManager.Serve()
} else {
e.relayManager.UpdateServerURLs(nil)
} }
// todo update relay address in the relay manager
// todo update signal // todo update signal
} }
@@ -556,13 +578,22 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err return err
} }
if update.GetNetworkMap() != nil { nm := update.GetNetworkMap()
// only apply new changes and ignore old ones if nm == nil {
err := e.updateNetworkMap(update.GetNetworkMap()) return nil
if err != nil {
return err
}
} }
// Store network map if persistence is enabled
if e.persistNetworkMap {
e.latestNetworkMap = nm
log.Debugf("network map persisted with serial %d", nm.GetSerial())
}
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
return err
}
return nil return nil
} }
@@ -1483,6 +1514,46 @@ func (e *Engine) stopDNSServer() {
e.statusRecorder.UpdateDNSStates(nsGroupStates) e.statusRecorder.UpdateDNSStates(nsGroupStates)
} }
// SetNetworkMapPersistence enables or disables network map persistence
func (e *Engine) SetNetworkMapPersistence(enabled bool) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if enabled == e.persistNetworkMap {
return
}
e.persistNetworkMap = enabled
log.Debugf("Network map persistence is set to %t", enabled)
if !enabled {
e.latestNetworkMap = nil
}
}
// GetLatestNetworkMap returns the stored network map if persistence is enabled
func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if !e.persistNetworkMap {
return nil, errors.New("network map persistence is disabled")
}
if e.latestNetworkMap == nil {
//nolint:nilnil
return nil, nil
}
// Create a deep copy to avoid external modifications
nm, ok := proto.Clone(e.latestNetworkMap).(*mgmProto.NetworkMap)
if !ok {
return nil, fmt.Errorf("failed to clone network map")
}
return nm, nil
}
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks { for _, check := range checks {

View File

@@ -245,12 +245,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
nil) nil)
wgIface := &iface.MockWGIface{ wgIface := &iface.MockWGIface{
NameFunc: func() string { return "utun102" },
RemovePeerFunc: func(peerKey string) error { RemovePeerFunc: func(peerKey string) error {
return nil return nil
}, },
} }
engine.wgInterface = wgIface engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
_, _, err = engine.routeManager.Init()
require.NoError(t, err)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
} }

View File

@@ -83,7 +83,6 @@ type Conn struct {
signaler *Signaler signaler *Signaler
relayManager *relayClient.Manager relayManager *relayClient.Manager
allowedIP net.IP allowedIP net.IP
allowedNet string
handshaker *Handshaker handshaker *Handshaker
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
@@ -111,7 +110,7 @@ type Conn struct {
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) {
allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
if err != nil { if err != nil {
log.Errorf("failed to parse allowedIPS: %v", err) log.Errorf("failed to parse allowedIPS: %v", err)
return nil, err return nil, err
@@ -129,7 +128,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
signaler: signaler, signaler: signaler,
relayManager: relayManager, relayManager: relayManager,
allowedIP: allowedIP, allowedIP: allowedIP,
allowedNet: allowedNet.String(),
statusRelay: NewAtomicConnStatus(), statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(),
} }
@@ -594,7 +592,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
} }
if conn.onConnected != nil { if conn.onConnected != nil {
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr) conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr)
} }
} }

View File

@@ -676,25 +676,23 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
// extend the list of stun, turn servers with relay address // extend the list of stun, turn servers with relay address
relayStates := slices.Clone(d.relayStates) relayStates := slices.Clone(d.relayStates)
var relayState relay.ProbeResult
// if the server connection is not established then we will use the general address // if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address // in case of connection we will use the instance specific address
instanceAddr, err := d.relayMgr.RelayInstanceAddress() instanceAddr, err := d.relayMgr.RelayInstanceAddress()
if err != nil { if err != nil {
// TODO add their status // TODO add their status
if errors.Is(err, relayClient.ErrRelayClientNotConnected) { for _, r := range d.relayMgr.ServerURLs() {
for _, r := range d.relayMgr.ServerURLs() { relayStates = append(relayStates, relay.ProbeResult{
relayStates = append(relayStates, relay.ProbeResult{ URI: r,
URI: r, Err: err,
}) })
}
return relayStates
} }
relayState.Err = err return relayStates
} }
relayState.URI = instanceAddr relayState := relay.ProbeResult{
URI: instanceAddr,
}
return append(relayStates, relayState) return append(relayStates, relayState)
} }

View File

@@ -46,8 +46,6 @@ type WorkerICE struct {
hasRelayOnLocally bool hasRelayOnLocally bool
conn WorkerICECallbacks conn WorkerICECallbacks
selectedPriority ConnPriority
agent *ice.Agent agent *ice.Agent
muxAgent sync.Mutex muxAgent sync.Mutex
@@ -95,10 +93,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
var preferredCandidateTypes []ice.CandidateType var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
w.selectedPriority = connPriorityICEP2P
preferredCandidateTypes = icemaker.CandidateTypesP2P() preferredCandidateTypes = icemaker.CandidateTypesP2P()
} else { } else {
w.selectedPriority = connPriorityICETurn
preferredCandidateTypes = icemaker.CandidateTypes() preferredCandidateTypes = icemaker.CandidateTypes()
} }
@@ -159,7 +155,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RelayedOnLocal: isRelayCandidate(pair.Local), RelayedOnLocal: isRelayCandidate(pair.Local),
} }
w.log.Debugf("on ICE conn read to use ready") w.log.Debugf("on ICE conn read to use ready")
go w.conn.OnConnReady(w.selectedPriority, ci) go w.conn.OnConnReady(selectedPriority(pair), ci)
} }
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
@@ -394,3 +390,11 @@ func isRelayed(pair *ice.CandidatePair) bool {
} }
return false return false
} }
func selectedPriority(pair *ice.CandidatePair) ConnPriority {
if isRelayed(pair) {
return connPriorityICETurn
} else {
return connPriorityICEP2P
}
}

View File

@@ -32,7 +32,7 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector GetRouteSelector() *routeselector.RouteSelector
@@ -59,6 +59,7 @@ type DefaultManager struct {
routeRefCounter *refcounter.RouteRefCounter routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
stateManager *statemanager.Manager
} }
func NewManager( func NewManager(
@@ -69,6 +70,7 @@ func NewManager(
statusRecorder *peer.Status, statusRecorder *peer.Status,
relayMgr *relayClient.Manager, relayMgr *relayClient.Manager,
initialRoutes []*route.Route, initialRoutes []*route.Route,
stateManager *statemanager.Manager,
) *DefaultManager { ) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx) mCTX, cancel := context.WithCancel(ctx)
notifier := notifier.NewNotifier() notifier := notifier.NewNotifier()
@@ -80,12 +82,12 @@ func NewManager(
dnsRouteInterval: dnsRouteInterval, dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork), clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
relayMgr: relayMgr, relayMgr: relayMgr,
routeSelector: routeselector.NewRouteSelector(),
sysOps: sysOps, sysOps: sysOps,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
pubKey: pubKey, pubKey: pubKey,
notifier: notifier, notifier: notifier,
stateManager: stateManager,
} }
dm.routeRefCounter = refcounter.New( dm.routeRefCounter = refcounter.New(
@@ -121,7 +123,7 @@ func NewManager(
} }
// Init sets up the routing // Init sets up the routing
func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
if nbnet.CustomRoutingDisabled() { if nbnet.CustomRoutingDisabled() {
return nil, nil, nil return nil, nil, nil
} }
@@ -137,14 +139,38 @@ func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHook
ips := resolveURLsToIPs(initialAddresses) ips := resolveURLsToIPs(initialAddresses)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager) beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err) return nil, nil, fmt.Errorf("setup routing: %w", err)
} }
m.routeSelector = m.initSelector()
log.Info("Routing setup complete") log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil return beforePeerHook, afterPeerHook, nil
} }
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
var state *SelectorState
m.stateManager.RegisterState(state)
// restore selector state if it exists
if err := m.stateManager.LoadState(state); err != nil {
log.Warnf("failed to load state: %v", err)
return routeselector.NewRouteSelector()
}
if state := m.stateManager.GetState(state); state != nil {
if selector, ok := state.(*SelectorState); ok {
return (*routeselector.RouteSelector)(selector)
}
log.Warnf("failed to convert state with type %T to SelectorState", state)
}
return routeselector.NewRouteSelector()
}
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
var err error var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
@@ -252,6 +278,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
} }
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
log.Errorf("failed to update state: %v", err)
}
} }
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list // stopObsoleteClients stops the client network watcher for the networks that are not in the new list

View File

@@ -424,9 +424,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm") statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO() ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil) routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)
_, _, err = routeManager.Init(nil) _, _, err = routeManager.Init()
require.NoError(t, err, "should init route manager") require.NoError(t, err, "should init route manager")
defer routeManager.Stop(nil) defer routeManager.Stop(nil)

View File

@@ -21,7 +21,7 @@ type MockManager struct {
StopFunc func(manager *statemanager.Manager) StopFunc func(manager *statemanager.Manager)
} }
func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) { func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
return nil, nil, nil return nil, nil, nil
} }

View File

@@ -71,11 +71,14 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
} }
// LoadData loads the data from the existing counter // LoadData loads the data from the existing counter
// The passed counter should not be used any longer after calling this function.
func (rm *Counter[Key, I, O]) LoadData( func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O], existingCounter *Counter[Key, I, O],
) { ) {
rm.mu.Lock() rm.mu.Lock()
defer rm.mu.Unlock() defer rm.mu.Unlock()
existingCounter.mu.Lock()
defer existingCounter.mu.Unlock()
rm.refCountMap = existingCounter.refCountMap rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap rm.idMap = existingCounter.idMap
@@ -231,6 +234,9 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements the json.Unmarshaler interface for Counter. // UnmarshalJSON implements the json.Unmarshaler interface for Counter.
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
rm.mu.Lock()
defer rm.mu.Unlock()
var temp struct { var temp struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"` RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"` IDMap map[string][]Key `json:"idMap"`
@@ -241,6 +247,13 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
rm.refCountMap = temp.RefCountMap rm.refCountMap = temp.RefCountMap
rm.idMap = temp.IDMap rm.idMap = temp.IDMap
if temp.RefCountMap == nil {
temp.RefCountMap = map[Key]Ref[O]{}
}
if temp.IDMap == nil {
temp.IDMap = map[string][]Key{}
}
return nil return nil
} }

View File

@@ -0,0 +1,19 @@
package routemanager
import (
"github.com/netbirdio/netbird/client/internal/routeselector"
)
type SelectorState routeselector.RouteSelector
func (s *SelectorState) Name() string {
return "routeselector_state"
}
func (s *SelectorState) MarshalJSON() ([]byte, error) {
return (*routeselector.RouteSelector)(s).MarshalJSON()
}
func (s *SelectorState) UnmarshalJSON(data []byte) error {
return (*routeselector.RouteSelector)(s).UnmarshalJSON(data)
}

View File

@@ -55,7 +55,7 @@ type ruleParams struct {
// isLegacy determines whether to use the legacy routing setup // isLegacy determines whether to use the legacy routing setup
func isLegacy() bool { func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true" return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
} }
// setIsLegacy sets the legacy routing setup // setIsLegacy sets the legacy routing setup
@@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := sysctl.Setup(r.wgInterface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() { defer func() {
if err != nil { if err != nil {
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
@@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
} }
} }
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := sysctl.Setup(r.wgInterface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
return nil, nil, nil return nil, nil, nil
} }
@@ -450,7 +450,7 @@ func addRule(params ruleParams) error {
rule.Invert = params.invert rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) {
return fmt.Errorf("add routing rule: %w", err) return fmt.Errorf("add routing rule: %w", err)
} }
@@ -467,7 +467,7 @@ func removeRule(params ruleParams) error {
rule.Priority = params.priority rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) { if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) {
return fmt.Errorf("remove routing rule: %w", err) return fmt.Errorf("remove routing rule: %w", err)
} }

View File

@@ -230,10 +230,13 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI
if idx != 0 { if idx != 0 {
intf, err := net.InterfaceByIndex(idx) intf, err := net.InterfaceByIndex(idx)
if err != nil { if err != nil {
return update, fmt.Errorf("get interface name: %w", err) log.Warnf("failed to get interface name for index %d: %v", idx, err)
update.Interface = &net.Interface{
Index: idx,
}
} else {
update.Interface = intf
} }
update.Interface = intf
} }
log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface) log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface)

View File

@@ -1,8 +1,10 @@
package routeselector package routeselector
import ( import (
"encoding/json"
"fmt" "fmt"
"slices" "slices"
"sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@@ -12,6 +14,7 @@ import (
) )
type RouteSelector struct { type RouteSelector struct {
mu sync.RWMutex
selectedRoutes map[route.NetID]struct{} selectedRoutes map[route.NetID]struct{}
selectAll bool selectAll bool
} }
@@ -26,6 +29,9 @@ func NewRouteSelector() *RouteSelector {
// SelectRoutes updates the selected routes based on the provided route IDs. // SelectRoutes updates the selected routes based on the provided route IDs.
func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error { func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
rs.mu.Lock()
defer rs.mu.Unlock()
if !appendRoute { if !appendRoute {
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
@@ -46,6 +52,9 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
// SelectAllRoutes sets the selector to select all routes. // SelectAllRoutes sets the selector to select all routes.
func (rs *RouteSelector) SelectAllRoutes() { func (rs *RouteSelector) SelectAllRoutes() {
rs.mu.Lock()
defer rs.mu.Unlock()
rs.selectAll = true rs.selectAll = true
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
@@ -53,6 +62,9 @@ func (rs *RouteSelector) SelectAllRoutes() {
// DeselectRoutes removes specific routes from the selection. // DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode. // If the selector is in "select all" mode, it will transition to "select specific" mode.
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.selectAll { if rs.selectAll {
rs.selectAll = false rs.selectAll = false
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
@@ -76,12 +88,18 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
// DeselectAllRoutes deselects all routes, effectively disabling route selection. // DeselectAllRoutes deselects all routes, effectively disabling route selection.
func (rs *RouteSelector) DeselectAllRoutes() { func (rs *RouteSelector) DeselectAllRoutes() {
rs.mu.Lock()
defer rs.mu.Unlock()
rs.selectAll = false rs.selectAll = false
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
// IsSelected checks if a specific route is selected. // IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
if rs.selectAll { if rs.selectAll {
return true return true
} }
@@ -91,6 +109,9 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
// FilterSelected removes unselected routes from the provided map. // FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
rs.mu.RLock()
defer rs.mu.RUnlock()
if rs.selectAll { if rs.selectAll {
return maps.Clone(routes) return maps.Clone(routes)
} }
@@ -103,3 +124,49 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
} }
return filtered return filtered
} }
// MarshalJSON implements the json.Marshaler interface
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
rs.mu.RLock()
defer rs.mu.RUnlock()
return json.Marshal(struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
SelectAll bool `json:"select_all"`
}{
SelectAll: rs.selectAll,
SelectedRoutes: rs.selectedRoutes,
})
}
// UnmarshalJSON implements the json.Unmarshaler interface
// If the JSON is empty or null, it will initialize like a NewRouteSelector.
func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
rs.mu.Lock()
defer rs.mu.Unlock()
// Check for null or empty JSON
if len(data) == 0 || string(data) == "null" {
rs.selectedRoutes = map[route.NetID]struct{}{}
rs.selectAll = true
return nil
}
var temp struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
SelectAll bool `json:"select_all"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
rs.selectedRoutes = temp.SelectedRoutes
rs.selectAll = temp.SelectAll
if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
return nil
}

View File

@@ -19,12 +19,36 @@ import (
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
const (
errStateNotRegistered = "state %s not registered"
errLoadStateFile = "load state file: %w"
)
// State interface defines the methods that all state types must implement // State interface defines the methods that all state types must implement
type State interface { type State interface {
Name() string Name() string
}
// CleanableState interface extends State with cleanup capability
type CleanableState interface {
State
Cleanup() error Cleanup() error
} }
// RawState wraps raw JSON data for unregistered states
type RawState struct {
data json.RawMessage
}
func (r *RawState) Name() string {
return "" // This is a placeholder implementation
}
// MarshalJSON implements json.Marshaler to preserve the original JSON
func (r *RawState) MarshalJSON() ([]byte, error) {
return r.data, nil
}
// Manager handles the persistence and management of various states // Manager handles the persistence and management of various states
type Manager struct { type Manager struct {
mu sync.Mutex mu sync.Mutex
@@ -140,7 +164,7 @@ func (m *Manager) setState(name string, state State) error {
defer m.mu.Unlock() defer m.mu.Unlock()
if _, exists := m.states[name]; !exists { if _, exists := m.states[name]; !exists {
return fmt.Errorf("state %s not registered", name) return fmt.Errorf(errStateNotRegistered, name)
} }
m.states[name] = state m.states[name] = state
@@ -149,6 +173,63 @@ func (m *Manager) setState(name string, state State) error {
return nil return nil
} }
// DeleteStateByName handles deletion of states without cleanup.
// It doesn't require the state to be registered.
func (m *Manager) DeleteStateByName(stateName string) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile(false)
if err != nil {
return fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return nil
}
if _, exists := rawStates[stateName]; !exists {
return fmt.Errorf("state %s not found", stateName)
}
// Mark state as deleted by setting it to nil and marking it dirty
m.states[stateName] = nil
m.dirty[stateName] = struct{}{}
return nil
}
// DeleteAllStates removes all states.
func (m *Manager) DeleteAllStates() (int, error) {
if m == nil {
return 0, nil
}
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile(false)
if err != nil {
return 0, fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return 0, nil
}
count := len(rawStates)
// Mark all states as deleted and dirty
for name := range rawStates {
m.states[name] = nil
m.dirty[name] = struct{}{}
}
return count, nil
}
func (m *Manager) periodicStateSave(ctx context.Context) { func (m *Manager) periodicStateSave(ctx context.Context) {
ticker := time.NewTicker(10 * time.Second) ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop() defer ticker.Stop()
@@ -202,63 +283,175 @@ func (m *Manager) PersistState(ctx context.Context) error {
} }
} }
log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start)) log.Debugf("persisted states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
clear(m.dirty) clear(m.dirty)
return nil return nil
} }
// loadState loads the existing state from the state file // loadStateFile reads and unmarshals the state file into a map of raw JSON messages
func (m *Manager) loadState() error { func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, error) {
data, err := os.ReadFile(m.filePath) data, err := os.ReadFile(m.filePath)
if err != nil { if err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {
log.Debug("state file does not exist") log.Debug("state file does not exist")
return nil return nil, nil // nolint:nilnil
} }
return fmt.Errorf("read state file: %w", err) return nil, fmt.Errorf("read state file: %w", err)
} }
var rawStates map[string]json.RawMessage var rawStates map[string]json.RawMessage
if err := json.Unmarshal(data, &rawStates); err != nil { if err := json.Unmarshal(data, &rawStates); err != nil {
log.Warn("State file appears to be corrupted, attempting to delete it") if deleteCorrupt {
if err := os.Remove(m.filePath); err != nil { log.Warn("State file appears to be corrupted, attempting to delete it", err)
log.Errorf("Failed to delete corrupted state file: %v", err) if err := os.Remove(m.filePath); err != nil {
} else { log.Errorf("Failed to delete corrupted state file: %v", err)
log.Info("State file deleted") } else {
log.Info("State file deleted")
}
} }
return fmt.Errorf("unmarshal states: %w", err) return nil, fmt.Errorf("unmarshal states: %w", err)
} }
var merr *multierror.Error return rawStates, nil
}
for name, rawState := range rawStates { // loadSingleRawState unmarshals a raw state into a concrete state object
stateType, ok := m.stateTypes[name] func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
if !ok { stateType, ok := m.stateTypes[name]
merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name)) if !ok {
continue return nil, fmt.Errorf(errStateNotRegistered, name)
} }
if string(rawState) == "null" { if string(rawState) == "null" {
continue return nil, nil //nolint:nilnil
} }
statePtr := reflect.New(stateType).Interface().(State) statePtr := reflect.New(stateType).Interface().(State)
if err := json.Unmarshal(rawState, statePtr); err != nil { if err := json.Unmarshal(rawState, statePtr); err != nil {
merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err)) return nil, fmt.Errorf("unmarshal state %s: %w", name, err)
continue }
}
m.states[name] = statePtr return statePtr, nil
}
// LoadState loads a specific state from the state file
func (m *Manager) LoadState(state State) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile(false)
if err != nil {
return err
}
if rawStates == nil {
return nil
}
name := state.Name()
rawState, exists := rawStates[name]
if !exists {
return nil
}
loadedState, err := m.loadSingleRawState(name, rawState)
if err != nil {
return err
}
m.states[name] = loadedState
if loadedState != nil {
log.Debugf("loaded state: %s", name) log.Debugf("loaded state: %s", name)
} }
return nberrors.FormatErrorOrNil(merr) return nil
} }
// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them. // cleanupSingleState handles the cleanup of a specific state and returns any error.
// If the cleanup is successful, the state is marked for deletion. // The caller must hold the mutex.
func (m *Manager) cleanupSingleState(name string, rawState json.RawMessage) error {
// For unregistered states, preserve the raw JSON
if _, registered := m.stateTypes[name]; !registered {
m.states[name] = &RawState{data: rawState}
return nil
}
// Load the state
loadedState, err := m.loadSingleRawState(name, rawState)
if err != nil {
return err
}
if loadedState == nil {
return nil
}
// Check if state supports cleanup
cleanableState, isCleanable := loadedState.(CleanableState)
if !isCleanable {
// If it doesn't support cleanup, keep it as-is
m.states[name] = loadedState
return nil
}
// Perform cleanup
log.Infof("cleaning up state %s", name)
if err := cleanableState.Cleanup(); err != nil {
// On cleanup error, preserve the state
m.states[name] = loadedState
return fmt.Errorf("cleanup state: %w", err)
}
// Successfully cleaned up - mark for deletion
m.states[name] = nil
m.dirty[name] = struct{}{}
return nil
}
// CleanupStateByName loads and cleans up a specific state by name if it implements CleanableState.
// Returns an error if the state doesn't exist, isn't registered, or cleanup fails.
func (m *Manager) CleanupStateByName(name string) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
// Check if state is registered
if _, registered := m.stateTypes[name]; !registered {
return fmt.Errorf(errStateNotRegistered, name)
}
// Load raw states from file
rawStates, err := m.loadStateFile(false)
if err != nil {
return err
}
if rawStates == nil {
return nil
}
// Check if state exists in file
rawState, exists := rawStates[name]
if !exists {
return nil
}
if err := m.cleanupSingleState(name, rawState); err != nil {
return fmt.Errorf("%s: %w", name, err)
}
return nil
}
// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it.
// Unregistered states are preserved in their original state.
func (m *Manager) PerformCleanup() error { func (m *Manager) PerformCleanup() error {
if m == nil { if m == nil {
return nil return nil
@@ -267,30 +460,51 @@ func (m *Manager) PerformCleanup() error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if err := m.loadState(); err != nil { // Load raw states from file
log.Warnf("Failed to load state during cleanup: %v", err) rawStates, err := m.loadStateFile(true)
if err != nil {
return fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return nil
} }
var merr *multierror.Error var merr *multierror.Error
for name, state := range m.states {
if state == nil {
// If no state was found in the state file, we don't mark the state dirty nor return an error
continue
}
log.Infof("client was not shut down properly, cleaning up %s", name) // Process each state in the file
if err := state.Cleanup(); err != nil { for name, rawState := range rawStates {
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err)) if err := m.cleanupSingleState(name, rawState); err != nil {
} else { merr = multierror.Append(merr, fmt.Errorf("%s: %w", name, err))
// mark for deletion on cleanup success
m.states[name] = nil
m.dirty[name] = struct{}{}
} }
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// GetSavedStateNames returns all state names that are currently saved in the state file.
func (m *Manager) GetSavedStateNames() ([]string, error) {
if m == nil {
return nil, nil
}
rawStates, err := m.loadStateFile(false)
if err != nil {
return nil, fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return nil, nil
}
var states []string
for name, state := range rawStates {
if len(state) != 0 && string(state) != "null" {
states = append(states, name)
}
}
return states, nil
}
func marshalWithPanicRecovery(v any) ([]byte, error) { func marshalWithPanicRecovery(v any) ([]byte, error) {
var bs []byte var bs []byte
var err error var err error

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.26.0 // protoc-gen-go v1.26.0
// protoc v3.21.12 // protoc v4.23.4
// source: daemon.proto // source: daemon.proto
package proto package proto
@@ -2103,6 +2103,434 @@ func (*SetLogLevelResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{30} return file_daemon_proto_rawDescGZIP(), []int{30}
} }
// State represents a daemon state entry
type State struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
}
func (x *State) Reset() {
*x = State{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[31]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *State) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*State) ProtoMessage() {}
func (x *State) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[31]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use State.ProtoReflect.Descriptor instead.
func (*State) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{31}
}
func (x *State) GetName() string {
if x != nil {
return x.Name
}
return ""
}
// ListStatesRequest is empty as it requires no parameters
type ListStatesRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *ListStatesRequest) Reset() {
*x = ListStatesRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[32]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ListStatesRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListStatesRequest) ProtoMessage() {}
func (x *ListStatesRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[32]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead.
func (*ListStatesRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{32}
}
// ListStatesResponse contains a list of states
type ListStatesResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
States []*State `protobuf:"bytes,1,rep,name=states,proto3" json:"states,omitempty"`
}
func (x *ListStatesResponse) Reset() {
*x = ListStatesResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[33]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ListStatesResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListStatesResponse) ProtoMessage() {}
func (x *ListStatesResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[33]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead.
func (*ListStatesResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{33}
}
func (x *ListStatesResponse) GetStates() []*State {
if x != nil {
return x.States
}
return nil
}
// CleanStateRequest for cleaning states
type CleanStateRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
StateName string `protobuf:"bytes,1,opt,name=state_name,json=stateName,proto3" json:"state_name,omitempty"`
All bool `protobuf:"varint,2,opt,name=all,proto3" json:"all,omitempty"`
}
func (x *CleanStateRequest) Reset() {
*x = CleanStateRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[34]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CleanStateRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CleanStateRequest) ProtoMessage() {}
func (x *CleanStateRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[34]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead.
func (*CleanStateRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{34}
}
func (x *CleanStateRequest) GetStateName() string {
if x != nil {
return x.StateName
}
return ""
}
func (x *CleanStateRequest) GetAll() bool {
if x != nil {
return x.All
}
return false
}
// CleanStateResponse contains the result of the clean operation
type CleanStateResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
CleanedStates int32 `protobuf:"varint,1,opt,name=cleaned_states,json=cleanedStates,proto3" json:"cleaned_states,omitempty"`
}
func (x *CleanStateResponse) Reset() {
*x = CleanStateResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[35]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CleanStateResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CleanStateResponse) ProtoMessage() {}
func (x *CleanStateResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[35]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead.
func (*CleanStateResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{35}
}
func (x *CleanStateResponse) GetCleanedStates() int32 {
if x != nil {
return x.CleanedStates
}
return 0
}
// DeleteStateRequest for deleting states
type DeleteStateRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
StateName string `protobuf:"bytes,1,opt,name=state_name,json=stateName,proto3" json:"state_name,omitempty"`
All bool `protobuf:"varint,2,opt,name=all,proto3" json:"all,omitempty"`
}
func (x *DeleteStateRequest) Reset() {
*x = DeleteStateRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[36]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DeleteStateRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DeleteStateRequest) ProtoMessage() {}
func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[36]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead.
func (*DeleteStateRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{36}
}
func (x *DeleteStateRequest) GetStateName() string {
if x != nil {
return x.StateName
}
return ""
}
func (x *DeleteStateRequest) GetAll() bool {
if x != nil {
return x.All
}
return false
}
// DeleteStateResponse contains the result of the delete operation
type DeleteStateResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
DeletedStates int32 `protobuf:"varint,1,opt,name=deleted_states,json=deletedStates,proto3" json:"deleted_states,omitempty"`
}
func (x *DeleteStateResponse) Reset() {
*x = DeleteStateResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[37]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DeleteStateResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DeleteStateResponse) ProtoMessage() {}
func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[37]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead.
func (*DeleteStateResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{37}
}
func (x *DeleteStateResponse) GetDeletedStates() int32 {
if x != nil {
return x.DeletedStates
}
return 0
}
type SetNetworkMapPersistenceRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"`
}
func (x *SetNetworkMapPersistenceRequest) Reset() {
*x = SetNetworkMapPersistenceRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[38]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetNetworkMapPersistenceRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetNetworkMapPersistenceRequest) ProtoMessage() {}
func (x *SetNetworkMapPersistenceRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[38]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetNetworkMapPersistenceRequest.ProtoReflect.Descriptor instead.
func (*SetNetworkMapPersistenceRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{38}
}
func (x *SetNetworkMapPersistenceRequest) GetEnabled() bool {
if x != nil {
return x.Enabled
}
return false
}
type SetNetworkMapPersistenceResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *SetNetworkMapPersistenceResponse) Reset() {
*x = SetNetworkMapPersistenceResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[39]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetNetworkMapPersistenceResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetNetworkMapPersistenceResponse) ProtoMessage() {}
func (x *SetNetworkMapPersistenceResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[39]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetNetworkMapPersistenceResponse.ProtoReflect.Descriptor instead.
func (*SetNetworkMapPersistenceResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{39}
}
var File_daemon_proto protoreflect.FileDescriptor var File_daemon_proto protoreflect.FileDescriptor
var file_daemon_proto_rawDesc = []byte{ var file_daemon_proto_rawDesc = []byte{
@@ -2399,66 +2827,116 @@ var file_daemon_proto_rawDesc = []byte{
0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76,
0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74,
0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d,
0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a,
0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73,
0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74,
0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71,
0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e,
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08,
0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63,
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74,
0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74,
0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02,
0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c,
0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65,
0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e,
0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61,
0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f,
0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65,
0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c,
0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05,
0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52,
0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04,
0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10,
0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0x81, 0x09, 0x0a,
0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36,
0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53,
0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69,
0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a,
0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44,
0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66,
0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f,
0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69,
0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75,
0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a,
0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75,
0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65,
0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74,
0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62,
0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65,
0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76,
0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c,
0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65,
0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a,
0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53,
0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c,
0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74,
0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45,
0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53,
0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65,
0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65,
0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70,
0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61,
0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d,
0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71,
0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69,
0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x6f, 0x33,
0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47,
0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c,
0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67,
0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42,
0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
} }
var ( var (
@@ -2474,50 +2952,59 @@ func file_daemon_proto_rawDescGZIP() []byte {
} }
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1) var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 32) var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 41)
var file_daemon_proto_goTypes = []interface{}{ var file_daemon_proto_goTypes = []interface{}{
(LogLevel)(0), // 0: daemon.LogLevel (LogLevel)(0), // 0: daemon.LogLevel
(*LoginRequest)(nil), // 1: daemon.LoginRequest (*LoginRequest)(nil), // 1: daemon.LoginRequest
(*LoginResponse)(nil), // 2: daemon.LoginResponse (*LoginResponse)(nil), // 2: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest (*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse (*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 5: daemon.UpRequest (*UpRequest)(nil), // 5: daemon.UpRequest
(*UpResponse)(nil), // 6: daemon.UpResponse (*UpResponse)(nil), // 6: daemon.UpResponse
(*StatusRequest)(nil), // 7: daemon.StatusRequest (*StatusRequest)(nil), // 7: daemon.StatusRequest
(*StatusResponse)(nil), // 8: daemon.StatusResponse (*StatusResponse)(nil), // 8: daemon.StatusResponse
(*DownRequest)(nil), // 9: daemon.DownRequest (*DownRequest)(nil), // 9: daemon.DownRequest
(*DownResponse)(nil), // 10: daemon.DownResponse (*DownResponse)(nil), // 10: daemon.DownResponse
(*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest (*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse (*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse
(*PeerState)(nil), // 13: daemon.PeerState (*PeerState)(nil), // 13: daemon.PeerState
(*LocalPeerState)(nil), // 14: daemon.LocalPeerState (*LocalPeerState)(nil), // 14: daemon.LocalPeerState
(*SignalState)(nil), // 15: daemon.SignalState (*SignalState)(nil), // 15: daemon.SignalState
(*ManagementState)(nil), // 16: daemon.ManagementState (*ManagementState)(nil), // 16: daemon.ManagementState
(*RelayState)(nil), // 17: daemon.RelayState (*RelayState)(nil), // 17: daemon.RelayState
(*NSGroupState)(nil), // 18: daemon.NSGroupState (*NSGroupState)(nil), // 18: daemon.NSGroupState
(*FullStatus)(nil), // 19: daemon.FullStatus (*FullStatus)(nil), // 19: daemon.FullStatus
(*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest (*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest
(*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse (*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse
(*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest (*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest
(*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse (*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse
(*IPList)(nil), // 24: daemon.IPList (*IPList)(nil), // 24: daemon.IPList
(*Route)(nil), // 25: daemon.Route (*Route)(nil), // 25: daemon.Route
(*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest (*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest
(*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse (*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse
(*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest (*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest
(*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse (*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse
(*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest (*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse (*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse
nil, // 32: daemon.Route.ResolvedIPsEntry (*State)(nil), // 32: daemon.State
(*durationpb.Duration)(nil), // 33: google.protobuf.Duration (*ListStatesRequest)(nil), // 33: daemon.ListStatesRequest
(*timestamppb.Timestamp)(nil), // 34: google.protobuf.Timestamp (*ListStatesResponse)(nil), // 34: daemon.ListStatesResponse
(*CleanStateRequest)(nil), // 35: daemon.CleanStateRequest
(*CleanStateResponse)(nil), // 36: daemon.CleanStateResponse
(*DeleteStateRequest)(nil), // 37: daemon.DeleteStateRequest
(*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse
(*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest
(*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse
nil, // 41: daemon.Route.ResolvedIPsEntry
(*durationpb.Duration)(nil), // 42: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp
} }
var file_daemon_proto_depIdxs = []int32{ var file_daemon_proto_depIdxs = []int32{
33, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 42, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus 19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
34, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp 43, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
34, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp 43, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
33, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration 42, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState 15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState 14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
@@ -2525,39 +3012,48 @@ var file_daemon_proto_depIdxs = []int32{
17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState 17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState
18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route 25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route
32, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry 41, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry
0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel 0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
24, // 15: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList 32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State
1, // 16: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 24, // 16: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList
3, // 17: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest 1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
5, // 18: daemon.DaemonService.Up:input_type -> daemon.UpRequest 3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
7, // 19: daemon.DaemonService.Status:input_type -> daemon.StatusRequest 5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest
9, // 20: daemon.DaemonService.Down:input_type -> daemon.DownRequest 7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
11, // 21: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest 9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest
20, // 22: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest 11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
22, // 23: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest 20, // 23: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
22, // 24: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest 22, // 24: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
26, // 25: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest 22, // 25: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
28, // 26: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest 26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
30, // 27: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest 28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
2, // 28: daemon.DaemonService.Login:output_type -> daemon.LoginResponse 30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
4, // 29: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse 33, // 29: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
6, // 30: daemon.DaemonService.Up:output_type -> daemon.UpResponse 35, // 30: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
8, // 31: daemon.DaemonService.Status:output_type -> daemon.StatusResponse 37, // 31: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
10, // 32: daemon.DaemonService.Down:output_type -> daemon.DownResponse 39, // 32: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest
12, // 33: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse 2, // 33: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
21, // 34: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse 4, // 34: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
23, // 35: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse 6, // 35: daemon.DaemonService.Up:output_type -> daemon.UpResponse
23, // 36: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse 8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
27, // 37: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse 10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse
29, // 38: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse 12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
31, // 39: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse 21, // 39: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
28, // [28:40] is the sub-list for method output_type 23, // 40: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
16, // [16:28] is the sub-list for method input_type 23, // 41: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
16, // [16:16] is the sub-list for extension type_name 27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
16, // [16:16] is the sub-list for extension extendee 29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
0, // [0:16] is the sub-list for field type_name 31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
34, // 45: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
36, // 46: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
38, // 47: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
40, // 48: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse
33, // [33:49] is the sub-list for method output_type
17, // [17:33] is the sub-list for method input_type
17, // [17:17] is the sub-list for extension type_name
17, // [17:17] is the sub-list for extension extendee
0, // [0:17] is the sub-list for field type_name
} }
func init() { file_daemon_proto_init() } func init() { file_daemon_proto_init() }
@@ -2938,6 +3434,114 @@ func file_daemon_proto_init() {
return nil return nil
} }
} }
file_daemon_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*State); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ListStatesRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ListStatesResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CleanStateRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CleanStateResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DeleteStateRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DeleteStateResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetNetworkMapPersistenceRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetNetworkMapPersistenceResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{} file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{}
type x struct{} type x struct{}
@@ -2946,7 +3550,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_daemon_proto_rawDesc, RawDescriptor: file_daemon_proto_rawDesc,
NumEnums: 1, NumEnums: 1,
NumMessages: 32, NumMessages: 41,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@@ -45,7 +45,20 @@ service DaemonService {
// SetLogLevel sets the log level of the daemon // SetLogLevel sets the log level of the daemon
rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {} rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {}
};
// List all states
rpc ListStates(ListStatesRequest) returns (ListStatesResponse) {}
// Clean specific state or all states
rpc CleanState(CleanStateRequest) returns (CleanStateResponse) {}
// Delete specific state or all states
rpc DeleteState(DeleteStateRequest) returns (DeleteStateResponse) {}
// SetNetworkMapPersistence enables or disables network map persistence
rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {}
}
message LoginRequest { message LoginRequest {
// setupKey wiretrustee setup key. // setupKey wiretrustee setup key.
@@ -293,4 +306,46 @@ message SetLogLevelRequest {
} }
message SetLogLevelResponse { message SetLogLevelResponse {
} }
// State represents a daemon state entry
message State {
string name = 1;
}
// ListStatesRequest is empty as it requires no parameters
message ListStatesRequest {}
// ListStatesResponse contains a list of states
message ListStatesResponse {
repeated State states = 1;
}
// CleanStateRequest for cleaning states
message CleanStateRequest {
string state_name = 1;
bool all = 2;
}
// CleanStateResponse contains the result of the clean operation
message CleanStateResponse {
int32 cleaned_states = 1;
}
// DeleteStateRequest for deleting states
message DeleteStateRequest {
string state_name = 1;
bool all = 2;
}
// DeleteStateResponse contains the result of the delete operation
message DeleteStateResponse {
int32 deleted_states = 1;
}
message SetNetworkMapPersistenceRequest {
bool enabled = 1;
}
message SetNetworkMapPersistenceResponse {}

View File

@@ -43,6 +43,14 @@ type DaemonServiceClient interface {
GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error)
// SetLogLevel sets the log level of the daemon // SetLogLevel sets the log level of the daemon
SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error)
// List all states
ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error)
// Clean specific state or all states
CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error)
// Delete specific state or all states
DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error)
// SetNetworkMapPersistence enables or disables network map persistence
SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
} }
type daemonServiceClient struct { type daemonServiceClient struct {
@@ -161,6 +169,42 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe
return out, nil return out, nil
} }
func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) {
out := new(ListStatesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) {
out := new(CleanStateResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) {
out := new(DeleteStateResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) {
out := new(SetNetworkMapPersistenceResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service. // DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer // All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility // for forward compatibility
@@ -190,6 +234,14 @@ type DaemonServiceServer interface {
GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error) GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error)
// SetLogLevel sets the log level of the daemon // SetLogLevel sets the log level of the daemon
SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error)
// List all states
ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error)
// Clean specific state or all states
CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error)
// Delete specific state or all states
DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error)
// SetNetworkMapPersistence enables or disables network map persistence
SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
mustEmbedUnimplementedDaemonServiceServer() mustEmbedUnimplementedDaemonServiceServer()
} }
@@ -233,6 +285,18 @@ func (UnimplementedDaemonServiceServer) GetLogLevel(context.Context, *GetLogLeve
func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) { func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented") return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented")
} }
func (UnimplementedDaemonServiceServer) ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListStates not implemented")
}
func (UnimplementedDaemonServiceServer) CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method CleanState not implemented")
}
func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DeleteState not implemented")
}
func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -462,6 +526,78 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListStatesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).ListStates(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/ListStates",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CleanStateRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).CleanState(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/CleanState",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DeleteStateRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).DeleteState(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DeleteState",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SetNetworkMapPersistenceRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SetNetworkMapPersistence",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@@ -517,6 +653,22 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "SetLogLevel", MethodName: "SetLogLevel",
Handler: _DaemonService_SetLogLevel_Handler, Handler: _DaemonService_SetLogLevel_Handler,
}, },
{
MethodName: "ListStates",
Handler: _DaemonService_ListStates_Handler,
},
{
MethodName: "CleanState",
Handler: _DaemonService_CleanState_Handler,
},
{
MethodName: "DeleteState",
Handler: _DaemonService_DeleteState_Handler,
},
{
MethodName: "SetNetworkMapPersistence",
Handler: _DaemonService_SetNetworkMapPersistence_Handler,
},
}, },
Streams: []grpc.StreamDesc{}, Streams: []grpc.StreamDesc{},
Metadata: "daemon.proto", Metadata: "daemon.proto",

View File

@@ -5,32 +5,44 @@ package server
import ( import (
"archive/zip" "archive/zip"
"bufio" "bufio"
"bytes"
"context" "context"
"encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"net" "net"
"net/netip" "net/netip"
"os" "os"
"path/filepath"
"sort" "sort"
"strings" "strings"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
mgmProto "github.com/netbirdio/netbird/management/proto"
) )
const readmeContent = `Netbird debug bundle const readmeContent = `Netbird debug bundle
This debug bundle contains the following files: This debug bundle contains the following files:
status.txt: Anonymized status information of the NetBird client. status.txt: Anonymized status information of the NetBird client.
client.log: Most recent, anonymized log file of the NetBird client. client.log: Most recent, anonymized client log file of the NetBird client.
netbird.err: Most recent, anonymized stderr log file of the NetBird client.
netbird.out: Most recent, anonymized stdout log file of the NetBird client.
routes.txt: Anonymized system routes, if --system-info flag was provided. routes.txt: Anonymized system routes, if --system-info flag was provided.
interfaces.txt: Anonymized network interface information, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
config.txt: Anonymized configuration information of the NetBird client. config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states.
Anonymization Process Anonymization Process
@@ -50,8 +62,32 @@ Domains
All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle. All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle.
Reoccuring domain names are replaced with the same anonymized domain. Reoccuring domain names are replaced with the same anonymized domain.
Network Map
The network_map.json file contains the following anonymized information:
- Peer configurations (addresses, FQDNs, DNS settings)
- Remote and offline peer information (allowed IPs, FQDNs)
- Routes (network ranges, associated domains)
- DNS configuration (nameservers, domains, custom zones)
- Firewall rules (peer IPs, source/destination ranges)
SSH keys in the network map are replaced with a placeholder value. All IP addresses and domains in the network map follow the same anonymization rules as described above.
State File
The state.json file contains anonymized internal state information of the NetBird client, including:
- DNS settings and configuration
- Firewall rules
- Exclusion routes
- Route selection
- Other internal states that may be present
The state file follows the same anonymization rules as other files:
- IP addresses (both individual and CIDR ranges) are anonymized while preserving their structure
- Domain names are consistently anonymized
- Technical identifiers and non-sensitive data remain unchanged
Routes Routes
For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct. For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
Network Interfaces Network Interfaces
The interfaces.txt file contains information about network interfaces, including: The interfaces.txt file contains information about network interfaces, including:
- Interface name - Interface name
@@ -72,6 +108,12 @@ The config.txt file contains anonymized configuration information of the NetBird
Other non-sensitive configuration options are included without anonymization. Other non-sensitive configuration options are included without anonymization.
` `
const (
clientLogFile = "client.log"
errorLogFile = "netbird.err"
stdoutLogFile = "netbird.out"
)
// DebugBundle creates a debug bundle and returns the location. // DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) { func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock() s.mutex.Lock()
@@ -119,19 +161,27 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
seedFromStatus(anonymizer, &status) seedFromStatus(anonymizer, &status)
if err := s.addConfig(req, anonymizer, archive); err != nil { if err := s.addConfig(req, anonymizer, archive); err != nil {
return fmt.Errorf("add config: %w", err) log.Errorf("Failed to add config to debug bundle: %v", err)
} }
if req.GetSystemInfo() { if req.GetSystemInfo() {
if err := s.addRoutes(req, anonymizer, archive); err != nil { if err := s.addRoutes(req, anonymizer, archive); err != nil {
return fmt.Errorf("add routes: %w", err) log.Errorf("Failed to add routes to debug bundle: %v", err)
} }
if err := s.addInterfaces(req, anonymizer, archive); err != nil { if err := s.addInterfaces(req, anonymizer, archive); err != nil {
return fmt.Errorf("add interfaces: %w", err) log.Errorf("Failed to add interfaces to debug bundle: %v", err)
} }
} }
if err := s.addNetworkMap(req, anonymizer, archive); err != nil {
return fmt.Errorf("add network map: %w", err)
}
if err := s.addStateFile(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add state file to debug bundle: %v", err)
}
if err := s.addLogfile(req, anonymizer, archive); err != nil { if err := s.addLogfile(req, anonymizer, archive); err != nil {
return fmt.Errorf("add log file: %w", err) return fmt.Errorf("add log file: %w", err)
} }
@@ -220,15 +270,16 @@ func (s *Server) addCommonConfigFields(configContent *strings.Builder) {
} }
func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
if routes, err := systemops.GetRoutesFromTable(); err != nil { routes, err := systemops.GetRoutesFromTable()
log.Errorf("Failed to get routes: %v", err) if err != nil {
} else { return fmt.Errorf("get routes: %w", err)
// TODO: get routes including nexthop }
routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
routesReader := strings.NewReader(routesContent) // TODO: get routes including nexthop
if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil { routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
return fmt.Errorf("add routes file to zip: %w", err) routesReader := strings.NewReader(routesContent)
} if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
return fmt.Errorf("add routes file to zip: %w", err)
} }
return nil return nil
} }
@@ -248,14 +299,106 @@ func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonym
return nil return nil
} }
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) { func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
logFile, err := os.Open(s.logFile) networkMap, err := s.getLatestNetworkMap()
if err != nil { if err != nil {
return fmt.Errorf("open log file: %w", err) // Skip if network map is not available, but log it
log.Debugf("skipping empty network map in debug bundle: %v", err)
return nil
}
if req.GetAnonymize() {
if err := anonymizeNetworkMap(networkMap, anonymizer); err != nil {
return fmt.Errorf("anonymize network map: %w", err)
}
}
options := protojson.MarshalOptions{
EmitUnpopulated: true,
UseProtoNames: true,
Indent: " ",
AllowPartial: true,
}
jsonBytes, err := options.Marshal(networkMap)
if err != nil {
return fmt.Errorf("generate json: %w", err)
}
if err := addFileToZip(archive, bytes.NewReader(jsonBytes), "network_map.json"); err != nil {
return fmt.Errorf("add network map to zip: %w", err)
}
return nil
}
func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
path := statemanager.GetDefaultStatePath()
if path == "" {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return fmt.Errorf("read state file: %w", err)
}
if req.GetAnonymize() {
var rawStates map[string]json.RawMessage
if err := json.Unmarshal(data, &rawStates); err != nil {
return fmt.Errorf("unmarshal states: %w", err)
}
if err := anonymizeStateFile(&rawStates, anonymizer); err != nil {
return fmt.Errorf("anonymize state file: %w", err)
}
bs, err := json.MarshalIndent(rawStates, "", " ")
if err != nil {
return fmt.Errorf("marshal states: %w", err)
}
data = bs
}
if err := addFileToZip(archive, bytes.NewReader(data), "state.json"); err != nil {
return fmt.Errorf("add state file to zip: %w", err)
}
return nil
}
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
logDir := filepath.Dir(s.logFile)
if err := s.addSingleLogfile(s.logFile, clientLogFile, req, anonymizer, archive); err != nil {
return fmt.Errorf("add client log file to zip: %w", err)
}
errLogPath := filepath.Join(logDir, errorLogFile)
if err := s.addSingleLogfile(errLogPath, errorLogFile, req, anonymizer, archive); err != nil {
log.Warnf("Failed to add %s to zip: %v", errorLogFile, err)
}
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
if err := s.addSingleLogfile(stdoutLogPath, stdoutLogFile, req, anonymizer, archive); err != nil {
log.Warnf("Failed to add %s to zip: %v", stdoutLogFile, err)
}
return nil
}
// addSingleLogfile adds a single log file to the archive
func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
logFile, err := os.Open(logPath)
if err != nil {
return fmt.Errorf("open log file %s: %w", targetName, err)
} }
defer func() { defer func() {
if err := logFile.Close(); err != nil { if err := logFile.Close(); err != nil {
log.Errorf("Failed to close original log file: %v", err) log.Errorf("Failed to close log file %s: %v", targetName, err)
} }
}() }()
@@ -264,45 +407,55 @@ func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize
var writer *io.PipeWriter var writer *io.PipeWriter
logReader, writer = io.Pipe() logReader, writer = io.Pipe()
go s.anonymize(logFile, writer, anonymizer) go anonymizeLog(logFile, writer, anonymizer)
} else { } else {
logReader = logFile logReader = logFile
} }
if err := addFileToZip(archive, logReader, "client.log"); err != nil {
return fmt.Errorf("add log file to zip: %w", err) if err := addFileToZip(archive, logReader, targetName); err != nil {
return fmt.Errorf("add %s to zip: %w", targetName, err)
} }
return nil return nil
} }
func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) { // getLatestNetworkMap returns the latest network map from the engine if network map persistence is enabled
defer func() { func (s *Server) getLatestNetworkMap() (*mgmProto.NetworkMap, error) {
// always nil if s.connectClient == nil {
_ = writer.Close() return nil, errors.New("connect client is not initialized")
}() }
scanner := bufio.NewScanner(reader) engine := s.connectClient.Engine()
for scanner.Scan() { if engine == nil {
line := anonymizer.AnonymizeString(scanner.Text()) return nil, errors.New("engine is not initialized")
if _, err := writer.Write([]byte(line + "\n")); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
return
}
} }
if err := scanner.Err(); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err)) networkMap, err := engine.GetLatestNetworkMap()
return if err != nil {
return nil, fmt.Errorf("get latest network map: %w", err)
} }
if networkMap == nil {
return nil, errors.New("network map is not available")
}
return networkMap, nil
} }
// GetLogLevel gets the current logging level for the server. // GetLogLevel gets the current logging level for the server.
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) { func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
level := ParseLogLevel(log.GetLevel().String()) level := ParseLogLevel(log.GetLevel().String())
return &proto.GetLogLevelResponse{Level: level}, nil return &proto.GetLogLevelResponse{Level: level}, nil
} }
// SetLogLevel sets the logging level for the server. // SetLogLevel sets the logging level for the server.
func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) { func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
level, err := log.ParseLevel(req.Level.String()) level, err := log.ParseLevel(req.Level.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid log level: %w", err) return nil, fmt.Errorf("invalid log level: %w", err)
@@ -313,6 +466,20 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
return &proto.SetLogLevelResponse{}, nil return &proto.SetLogLevelResponse{}, nil
} }
// SetNetworkMapPersistence sets the network map persistence for the server.
func (s *Server) SetNetworkMapPersistence(_ context.Context, req *proto.SetNetworkMapPersistenceRequest) (*proto.SetNetworkMapPersistenceResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
enabled := req.GetEnabled()
s.persistNetworkMap = enabled
if s.connectClient != nil {
s.connectClient.SetNetworkMapPersistence(enabled)
}
return &proto.SetNetworkMapPersistenceResponse{}, nil
}
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error { func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
header := &zip.FileHeader{ header := &zip.FileHeader{
Name: filename, Name: filename,
@@ -458,6 +625,26 @@ func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *an
return builder.String() return builder.String()
} }
func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
defer func() {
// always nil
_ = writer.Close()
}()
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := anonymizer.AnonymizeString(scanner.Text())
if _, err := writer.Write([]byte(line + "\n")); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
return
}
}
if err := scanner.Err(); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
return
}
}
func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string { func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
anonymizedIPs := make([]string, len(ips)) anonymizedIPs := make([]string, len(ips))
for i, ip := range ips { for i, ip := range ips {
@@ -484,3 +671,248 @@ func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []s
} }
return anonymizedIPs return anonymizedIPs
} }
func anonymizeNetworkMap(networkMap *mgmProto.NetworkMap, anonymizer *anonymize.Anonymizer) error {
if networkMap.PeerConfig != nil {
anonymizePeerConfig(networkMap.PeerConfig, anonymizer)
}
for _, peer := range networkMap.RemotePeers {
anonymizeRemotePeer(peer, anonymizer)
}
for _, peer := range networkMap.OfflinePeers {
anonymizeRemotePeer(peer, anonymizer)
}
for _, r := range networkMap.Routes {
anonymizeRoute(r, anonymizer)
}
if networkMap.DNSConfig != nil {
anonymizeDNSConfig(networkMap.DNSConfig, anonymizer)
}
for _, rule := range networkMap.FirewallRules {
anonymizeFirewallRule(rule, anonymizer)
}
for _, rule := range networkMap.RoutesFirewallRules {
anonymizeRouteFirewallRule(rule, anonymizer)
}
return nil
}
func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anonymizer) {
if config == nil {
return
}
if addr, err := netip.ParseAddr(config.Address); err == nil {
config.Address = anonymizer.AnonymizeIP(addr).String()
}
if config.SshConfig != nil && len(config.SshConfig.SshPubKey) > 0 {
config.SshConfig.SshPubKey = []byte("ssh-placeholder-key")
}
config.Dns = anonymizer.AnonymizeString(config.Dns)
config.Fqdn = anonymizer.AnonymizeDomain(config.Fqdn)
}
func anonymizeRemotePeer(peer *mgmProto.RemotePeerConfig, anonymizer *anonymize.Anonymizer) {
if peer == nil {
return
}
for i, ip := range peer.AllowedIps {
// Try to parse as prefix first (CIDR)
if prefix, err := netip.ParsePrefix(ip); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
peer.AllowedIps[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
} else if addr, err := netip.ParseAddr(ip); err == nil {
peer.AllowedIps[i] = anonymizer.AnonymizeIP(addr).String()
}
}
peer.Fqdn = anonymizer.AnonymizeDomain(peer.Fqdn)
if peer.SshConfig != nil && len(peer.SshConfig.SshPubKey) > 0 {
peer.SshConfig.SshPubKey = []byte("ssh-placeholder-key")
}
}
func anonymizeRoute(route *mgmProto.Route, anonymizer *anonymize.Anonymizer) {
if route == nil {
return
}
if prefix, err := netip.ParsePrefix(route.Network); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
route.Network = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
for i, domain := range route.Domains {
route.Domains[i] = anonymizer.AnonymizeDomain(domain)
}
route.NetID = anonymizer.AnonymizeString(route.NetID)
}
func anonymizeDNSConfig(config *mgmProto.DNSConfig, anonymizer *anonymize.Anonymizer) {
if config == nil {
return
}
anonymizeNameServerGroups(config.NameServerGroups, anonymizer)
anonymizeCustomZones(config.CustomZones, anonymizer)
}
func anonymizeNameServerGroups(groups []*mgmProto.NameServerGroup, anonymizer *anonymize.Anonymizer) {
for _, group := range groups {
anonymizeServers(group.NameServers, anonymizer)
anonymizeDomains(group.Domains, anonymizer)
}
}
func anonymizeServers(servers []*mgmProto.NameServer, anonymizer *anonymize.Anonymizer) {
for _, server := range servers {
if addr, err := netip.ParseAddr(server.IP); err == nil {
server.IP = anonymizer.AnonymizeIP(addr).String()
}
}
}
func anonymizeDomains(domains []string, anonymizer *anonymize.Anonymizer) {
for i, domain := range domains {
domains[i] = anonymizer.AnonymizeDomain(domain)
}
}
func anonymizeCustomZones(zones []*mgmProto.CustomZone, anonymizer *anonymize.Anonymizer) {
for _, zone := range zones {
zone.Domain = anonymizer.AnonymizeDomain(zone.Domain)
anonymizeRecords(zone.Records, anonymizer)
}
}
func anonymizeRecords(records []*mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
for _, record := range records {
record.Name = anonymizer.AnonymizeDomain(record.Name)
anonymizeRData(record, anonymizer)
}
}
func anonymizeRData(record *mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
switch record.Type {
case 1, 28: // A or AAAA record
if addr, err := netip.ParseAddr(record.RData); err == nil {
record.RData = anonymizer.AnonymizeIP(addr).String()
}
default:
record.RData = anonymizer.AnonymizeString(record.RData)
}
}
func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.Anonymizer) {
if rule == nil {
return
}
if addr, err := netip.ParseAddr(rule.PeerIP); err == nil {
rule.PeerIP = anonymizer.AnonymizeIP(addr).String()
}
}
func anonymizeRouteFirewallRule(rule *mgmProto.RouteFirewallRule, anonymizer *anonymize.Anonymizer) {
if rule == nil {
return
}
for i, sourceRange := range rule.SourceRanges {
if prefix, err := netip.ParsePrefix(sourceRange); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
rule.SourceRanges[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
}
if prefix, err := netip.ParsePrefix(rule.Destination); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
rule.Destination = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
}
func anonymizeStateFile(rawStates *map[string]json.RawMessage, anonymizer *anonymize.Anonymizer) error {
for name, rawState := range *rawStates {
if string(rawState) == "null" {
continue
}
var state map[string]any
if err := json.Unmarshal(rawState, &state); err != nil {
return fmt.Errorf("unmarshal state %s: %w", name, err)
}
state = anonymizeValue(state, anonymizer).(map[string]any)
bs, err := json.Marshal(state)
if err != nil {
return fmt.Errorf("marshal state %s: %w", name, err)
}
(*rawStates)[name] = bs
}
return nil
}
func anonymizeValue(value any, anonymizer *anonymize.Anonymizer) any {
switch v := value.(type) {
case string:
return anonymizeString(v, anonymizer)
case map[string]any:
return anonymizeMap(v, anonymizer)
case []any:
return anonymizeSlice(v, anonymizer)
}
return value
}
func anonymizeString(v string, anonymizer *anonymize.Anonymizer) string {
if prefix, err := netip.ParsePrefix(v); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
if ip, err := netip.ParseAddr(v); err == nil {
return anonymizer.AnonymizeIP(ip).String()
}
return anonymizer.AnonymizeString(v)
}
func anonymizeMap(v map[string]any, anonymizer *anonymize.Anonymizer) map[string]any {
result := make(map[string]any, len(v))
for key, val := range v {
newKey := anonymizeMapKey(key, anonymizer)
result[newKey] = anonymizeValue(val, anonymizer)
}
return result
}
func anonymizeMapKey(key string, anonymizer *anonymize.Anonymizer) string {
if prefix, err := netip.ParsePrefix(key); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
if ip, err := netip.ParseAddr(key); err == nil {
return anonymizer.AnonymizeIP(ip).String()
}
return key
}
func anonymizeSlice(v []any, anonymizer *anonymize.Anonymizer) []any {
for i, val := range v {
v[i] = anonymizeValue(val, anonymizer)
}
return v
}

430
client/server/debug_test.go Normal file
View File

@@ -0,0 +1,430 @@
package server
import (
"encoding/json"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
func TestAnonymizeStateFile(t *testing.T) {
testState := map[string]json.RawMessage{
"null_state": json.RawMessage("null"),
"test_state": mustMarshal(map[string]any{
// Test simple fields
"public_ip": "203.0.113.1",
"private_ip": "192.168.1.1",
"protected_ip": "100.64.0.1",
"well_known_ip": "8.8.8.8",
"ipv6_addr": "2001:db8::1",
"private_ipv6": "fd00::1",
"domain": "test.example.com",
"uri": "stun:stun.example.com:3478",
"uri_with_ip": "turn:203.0.113.1:3478",
"netbird_domain": "device.netbird.cloud",
// Test CIDR ranges
"public_cidr": "203.0.113.0/24",
"private_cidr": "192.168.0.0/16",
"protected_cidr": "100.64.0.0/10",
"ipv6_cidr": "2001:db8::/32",
"private_ipv6_cidr": "fd00::/8",
// Test nested structures
"nested": map[string]any{
"ip": "203.0.113.2",
"domain": "nested.example.com",
"more_nest": map[string]any{
"ip": "203.0.113.3",
"domain": "deep.example.com",
},
},
// Test arrays
"string_array": []any{
"203.0.113.4",
"test1.example.com",
"test2.example.com",
},
"object_array": []any{
map[string]any{
"ip": "203.0.113.5",
"domain": "array1.example.com",
},
map[string]any{
"ip": "203.0.113.6",
"domain": "array2.example.com",
},
},
// Test multiple occurrences of same value
"duplicate_ip": "203.0.113.1", // Same as public_ip
"duplicate_domain": "test.example.com", // Same as domain
// Test URIs with various schemes
"stun_uri": "stun:stun.example.com:3478",
"turns_uri": "turns:turns.example.com:5349",
"http_uri": "http://web.example.com:80",
"https_uri": "https://secure.example.com:443",
// Test strings that might look like IPs but aren't
"not_ip": "300.300.300.300",
"partial_ip": "192.168",
"ip_like_string": "1234.5678",
// Test mixed content strings
"mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80",
// Test empty and special values
"empty_string": "",
"null_value": nil,
"numeric_value": 42,
"boolean_value": true,
}),
"route_state": mustMarshal(map[string]any{
"routes": []any{
map[string]any{
"network": "203.0.113.0/24",
"gateway": "203.0.113.1",
"domains": []any{
"route1.example.com",
"route2.example.com",
},
},
map[string]any{
"network": "2001:db8::/32",
"gateway": "2001:db8::1",
"domains": []any{
"route3.example.com",
"route4.example.com",
},
},
},
// Test map with IP/CIDR keys
"refCountMap": map[string]any{
"203.0.113.1/32": map[string]any{
"Count": 1,
"Out": map[string]any{
"IP": "192.168.0.1",
"Intf": map[string]any{
"Name": "eth0",
"Index": 1,
},
},
},
"2001:db8::1/128": map[string]any{
"Count": 1,
"Out": map[string]any{
"IP": "fe80::1",
"Intf": map[string]any{
"Name": "eth0",
"Index": 1,
},
},
},
"10.0.0.1/32": map[string]any{ // private IP should remain unchanged
"Count": 1,
"Out": map[string]any{
"IP": "192.168.0.1",
},
},
},
}),
}
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Pre-seed the domains we need to verify in the test assertions
anonymizer.AnonymizeDomain("test.example.com")
anonymizer.AnonymizeDomain("nested.example.com")
anonymizer.AnonymizeDomain("deep.example.com")
anonymizer.AnonymizeDomain("array1.example.com")
err := anonymizeStateFile(&testState, anonymizer)
require.NoError(t, err)
// Helper function to unmarshal and get nested values
var state map[string]any
err = json.Unmarshal(testState["test_state"], &state)
require.NoError(t, err)
// Test null state remains unchanged
require.Equal(t, "null", string(testState["null_state"]))
// Basic assertions
assert.NotEqual(t, "203.0.113.1", state["public_ip"])
assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged
assert.NotEqual(t, "test.example.com", state["domain"])
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
// CIDR ranges
assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"])
assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved
assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged
assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged
assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"])
assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved
// Nested structures
nested := state["nested"].(map[string]any)
assert.NotEqual(t, "203.0.113.2", nested["ip"])
assert.NotEqual(t, "nested.example.com", nested["domain"])
moreNest := nested["more_nest"].(map[string]any)
assert.NotEqual(t, "203.0.113.3", moreNest["ip"])
assert.NotEqual(t, "deep.example.com", moreNest["domain"])
// Arrays
strArray := state["string_array"].([]any)
assert.NotEqual(t, "203.0.113.4", strArray[0])
assert.NotEqual(t, "test1.example.com", strArray[1])
assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain"))
objArray := state["object_array"].([]any)
firstObj := objArray[0].(map[string]any)
assert.NotEqual(t, "203.0.113.5", firstObj["ip"])
assert.NotEqual(t, "array1.example.com", firstObj["domain"])
// Duplicate values should be anonymized consistently
assert.Equal(t, state["public_ip"], state["duplicate_ip"])
assert.Equal(t, state["domain"], state["duplicate_domain"])
// URIs
assert.NotContains(t, state["stun_uri"], "stun.example.com")
assert.NotContains(t, state["turns_uri"], "turns.example.com")
assert.NotContains(t, state["http_uri"], "web.example.com")
assert.NotContains(t, state["https_uri"], "secure.example.com")
// Non-IP strings should remain unchanged
assert.Equal(t, "300.300.300.300", state["not_ip"])
assert.Equal(t, "192.168", state["partial_ip"])
assert.Equal(t, "1234.5678", state["ip_like_string"])
// Mixed content should have IPs and domains replaced
mixedContent := state["mixed_content"].(string)
assert.NotContains(t, mixedContent, "203.0.113.1")
assert.NotContains(t, mixedContent, "test.example.com")
assert.Contains(t, mixedContent, "Server at ")
assert.Contains(t, mixedContent, " on port 80")
// Special values should remain unchanged
assert.Equal(t, "", state["empty_string"])
assert.Nil(t, state["null_value"])
assert.Equal(t, float64(42), state["numeric_value"])
assert.Equal(t, true, state["boolean_value"])
// Check route state
var routeState map[string]any
err = json.Unmarshal(testState["route_state"], &routeState)
require.NoError(t, err)
routes := routeState["routes"].([]any)
route1 := routes[0].(map[string]any)
assert.NotEqual(t, "203.0.113.0/24", route1["network"])
assert.Contains(t, route1["network"], "/24")
assert.NotEqual(t, "203.0.113.1", route1["gateway"])
domains := route1["domains"].([]any)
assert.True(t, strings.HasSuffix(domains[0].(string), ".domain"))
assert.True(t, strings.HasSuffix(domains[1].(string), ".domain"))
// Check map keys are anonymized
refCountMap := routeState["refCountMap"].(map[string]any)
hasPublicIPKey := false
hasIPv6Key := false
hasPrivateIPKey := false
for key := range refCountMap {
if strings.Contains(key, "203.0.113.1") {
hasPublicIPKey = true
}
if strings.Contains(key, "2001:db8::1") {
hasIPv6Key = true
}
if key == "10.0.0.1/32" {
hasPrivateIPKey = true
}
}
assert.False(t, hasPublicIPKey, "public IP in key should be anonymized")
assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized")
assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged")
}
func mustMarshal(v any) json.RawMessage {
data, err := json.Marshal(v)
if err != nil {
panic(err)
}
return data
}
func TestAnonymizeNetworkMap(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{
Address: "203.0.113.5",
Dns: "1.2.3.4",
Fqdn: "peer1.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."),
},
},
RemotePeers: []*mgmProto.RemotePeerConfig{
{
AllowedIps: []string{
"203.0.113.1/32",
"2001:db8:1234::1/128",
"192.168.1.1/32",
"100.64.0.1/32",
"10.0.0.1/32",
},
Fqdn: "peer2.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC2..."),
},
},
},
Routes: []*mgmProto.Route{
{
Network: "197.51.100.0/24",
Domains: []string{"prod.example.com", "staging.example.com"},
NetID: "net-123abc",
},
},
DNSConfig: &mgmProto.DNSConfig{
NameServerGroups: []*mgmProto.NameServerGroup{
{
NameServers: []*mgmProto.NameServer{
{IP: "8.8.8.8"},
{IP: "1.1.1.1"},
{IP: "203.0.113.53"},
},
Domains: []string{"example.com", "internal.example.com"},
},
},
CustomZones: []*mgmProto.CustomZone{
{
Domain: "custom.example.com",
Records: []*mgmProto.SimpleRecord{
{
Name: "www.custom.example.com",
Type: 1,
RData: "203.0.113.10",
},
{
Name: "internal.custom.example.com",
Type: 1,
RData: "192.168.1.10",
},
},
},
},
},
}
// Create anonymizer with test addresses
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Anonymize the network map
err := anonymizeNetworkMap(networkMap, anonymizer)
require.NoError(t, err)
// Test PeerConfig anonymization
peerCfg := networkMap.PeerConfig
require.NotEqual(t, "203.0.113.5", peerCfg.Address)
// Verify DNS and FQDN are properly anonymized
require.NotEqual(t, "1.2.3.4", peerCfg.Dns)
require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn)
require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain"))
// Verify SSH key is replaced
require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey)
// Test RemotePeers anonymization
remotePeer := networkMap.RemotePeers[0]
// Verify FQDN is anonymized
require.NotEqual(t, "peer2.corp.example.com", remotePeer.Fqdn)
require.True(t, strings.HasSuffix(remotePeer.Fqdn, ".domain"))
// Check that public IPs are anonymized but private IPs are preserved
for _, allowedIP := range remotePeer.AllowedIps {
ip, _, err := net.ParseCIDR(allowedIP)
require.NoError(t, err)
if ip.IsPrivate() || isInCGNATRange(ip) {
require.Contains(t, []string{
"192.168.1.1/32",
"100.64.0.1/32",
"10.0.0.1/32",
}, allowedIP)
} else {
require.NotContains(t, []string{
"203.0.113.1/32",
"2001:db8:1234::1/128",
}, allowedIP)
}
}
// Test Routes anonymization
route := networkMap.Routes[0]
require.NotEqual(t, "197.51.100.0/24", route.Network)
for _, domain := range route.Domains {
require.True(t, strings.HasSuffix(domain, ".domain"))
require.NotContains(t, domain, "example.com")
}
// Test DNS config anonymization
dnsConfig := networkMap.DNSConfig
nameServerGroup := dnsConfig.NameServerGroups[0]
// Verify well-known DNS servers are preserved
require.Equal(t, "8.8.8.8", nameServerGroup.NameServers[0].IP)
require.Equal(t, "1.1.1.1", nameServerGroup.NameServers[1].IP)
// Verify public DNS server is anonymized
require.NotEqual(t, "203.0.113.53", nameServerGroup.NameServers[2].IP)
// Verify domains are anonymized
for _, domain := range nameServerGroup.Domains {
require.True(t, strings.HasSuffix(domain, ".domain"))
require.NotContains(t, domain, "example.com")
}
// Test CustomZones anonymization
customZone := dnsConfig.CustomZones[0]
require.True(t, strings.HasSuffix(customZone.Domain, ".domain"))
require.NotContains(t, customZone.Domain, "example.com")
// Verify records are properly anonymized
for _, record := range customZone.Records {
require.True(t, strings.HasSuffix(record.Name, ".domain"))
require.NotContains(t, record.Name, "example.com")
ip := net.ParseIP(record.RData)
if ip != nil {
if !ip.IsPrivate() {
require.NotEqual(t, "203.0.113.10", record.RData)
} else {
require.Equal(t, "192.168.1.10", record.RData)
}
}
}
}
// Helper function to check if IP is in CGNAT range
func isInCGNATRange(ip net.IP) bool {
cgnat := net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
return cgnat.Contains(ip)
}

View File

@@ -68,6 +68,8 @@ type Server struct {
relayProbe *internal.Probe relayProbe *internal.Probe
wgProbe *internal.Probe wgProbe *internal.Probe
lastProbe time.Time lastProbe time.Time
persistNetworkMap bool
} }
type oauthAuthFlow struct { type oauthAuthFlow struct {
@@ -196,6 +198,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
runOperation := func() error { runOperation := func() error {
log.Tracef("running client connection") log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetNetworkMapPersistence(s.persistNetworkMap)
probes := internal.ProbeHolder{ probes := internal.ProbeHolder{
MgmProbe: s.mgmProbe, MgmProbe: s.mgmProbe,

View File

@@ -5,12 +5,112 @@ import (
"fmt" "fmt"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/proto"
) )
// restoreResidualConfig checks if the client was not shut down in a clean way and restores residual state if required. // ListStates returns a list of all saved states
func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*proto.ListStatesResponse, error) {
mgr := statemanager.New(statemanager.GetDefaultStatePath())
stateNames, err := mgr.GetSavedStateNames()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get saved state names: %v", err)
}
states := make([]*proto.State, 0, len(stateNames))
for _, name := range stateNames {
states = append(states, &proto.State{
Name: name,
})
}
return &proto.ListStatesResponse{
States: states,
}, nil
}
// CleanState handles cleaning of states (performing cleanup operations)
func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) {
if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
}
if req.All {
// Reuse existing cleanup logic for all states
if err := restoreResidualState(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clean all states: %v", err)
}
// Get count of cleaned states
mgr := statemanager.New(statemanager.GetDefaultStatePath())
stateNames, err := mgr.GetSavedStateNames()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get state count: %v", err)
}
return &proto.CleanStateResponse{
CleanedStates: int32(len(stateNames)),
}, nil
}
// Handle single state cleanup
mgr := statemanager.New(statemanager.GetDefaultStatePath())
registerStates(mgr)
if err := mgr.CleanupStateByName(req.StateName); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clean state %s: %v", req.StateName, err)
}
if err := mgr.PersistState(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to persist state changes: %v", err)
}
return &proto.CleanStateResponse{
CleanedStates: 1,
}, nil
}
// DeleteState handles deletion of states without cleanup
func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) {
if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
}
mgr := statemanager.New(statemanager.GetDefaultStatePath())
var count int
var err error
if req.All {
count, err = mgr.DeleteAllStates()
} else {
err = mgr.DeleteStateByName(req.StateName)
if err == nil {
count = 1
}
}
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete state: %v", err)
}
// Persist the changes
if err := mgr.PersistState(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to persist state changes: %v", err)
}
return &proto.DeleteStateResponse{
DeletedStates: int32(count),
}, nil
}
// restoreResidualState checks if the client was not shut down in a clean way and restores residual if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config. // Otherwise, we might not be able to connect to the management server to retrieve new config.
func restoreResidualState(ctx context.Context) error { func restoreResidualState(ctx context.Context) error {
path := statemanager.GetDefaultStatePath() path := statemanager.GetDefaultStatePath()
@@ -24,6 +124,7 @@ func restoreResidualState(ctx context.Context) error {
registerStates(mgr) registerStates(mgr)
var merr *multierror.Error var merr *multierror.Error
if err := mgr.PerformCleanup(); err != nil { if err := mgr.PerformCleanup(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err))
} }

View File

@@ -61,6 +61,14 @@ type Info struct {
Files []File // for posture checks Files []File // for posture checks
} }
// StaticInfo is an object that contains machine information that does not change
type StaticInfo struct {
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string { func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx) md, hasMeta := metadata.FromOutgoingContext(ctx)

View File

@@ -10,13 +10,12 @@ import (
"os/exec" "os/exec"
"runtime" "runtime"
"strings" "strings"
"time"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -41,11 +40,10 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err) log.Warnf("failed to discover network addresses: %s", err)
} }
serialNum, prodName, manufacturer := sysInfo() start := time.Now()
si := updateStaticInfo()
env := Environment{ if time.Since(start) > 1*time.Second {
Cloud: detect_cloud.Detect(ctx), log.Warnf("updateStaticInfo took %s", time.Since(start))
Platform: detect_platform.Detect(ctx),
} }
gio := &Info{ gio := &Info{
@@ -57,10 +55,10 @@ func GetInfo(ctx context.Context) *Info {
CPUs: runtime.NumCPU(), CPUs: runtime.NumCPU(),
KernelVersion: release, KernelVersion: release,
NetworkAddresses: addrs, NetworkAddresses: addrs,
SystemSerialNumber: serialNum, SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: prodName, SystemProductName: si.SystemProductName,
SystemManufacturer: manufacturer, SystemManufacturer: si.SystemManufacturer,
Environment: env, Environment: si.Environment,
} }
systemHostname, _ := os.Hostname() systemHostname, _ := os.Hostname()

View File

@@ -1,5 +1,4 @@
//go:build !android //go:build !android
// +build !android
package system package system
@@ -16,30 +15,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/zcalusic/sysinfo" "github.com/zcalusic/sysinfo"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
type SysInfoGetter interface { var (
GetSysInfo() SysInfo // it is override in tests
} getSystemInfo = defaultSysInfoImplementation
)
type SysInfoWrapper struct {
si sysinfo.SysInfo
}
func (s SysInfoWrapper) GetSysInfo() SysInfo {
s.si.GetSysInfo()
return SysInfo{
ChassisSerial: s.si.Chassis.Serial,
ProductSerial: s.si.Product.Serial,
BoardSerial: s.si.Board.Serial,
ProductName: s.si.Product.Name,
BoardName: s.si.Board.Name,
ProductVendor: s.si.Product.Vendor,
}
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
@@ -65,12 +47,10 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err) log.Warnf("failed to discover network addresses: %s", err)
} }
si := SysInfoWrapper{} start := time.Now()
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo()) si := updateStaticInfo()
if time.Since(start) > 1*time.Second {
env := Environment{ log.Warnf("updateStaticInfo took %s", time.Since(start))
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
} }
gio := &Info{ gio := &Info{
@@ -85,10 +65,10 @@ func GetInfo(ctx context.Context) *Info {
UIVersion: extractUserAgent(ctx), UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1], KernelVersion: osInfo[1],
NetworkAddresses: addrs, NetworkAddresses: addrs,
SystemSerialNumber: serialNum, SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: prodName, SystemProductName: si.SystemProductName,
SystemManufacturer: manufacturer, SystemManufacturer: si.SystemManufacturer,
Environment: env, Environment: si.Environment,
} }
return gio return gio
@@ -108,9 +88,9 @@ func _getInfo() string {
return out.String() return out.String()
} }
func sysInfo(si SysInfo) (string, string, string) { func sysInfo() (string, string, string) {
isascii := regexp.MustCompile("^[[:ascii:]]+$") isascii := regexp.MustCompile("^[[:ascii:]]+$")
si := getSystemInfo()
serials := []string{si.ChassisSerial, si.ProductSerial} serials := []string{si.ChassisSerial, si.ProductSerial}
serial := "" serial := ""
@@ -141,3 +121,16 @@ func sysInfo(si SysInfo) (string, string, string) {
} }
return serial, name, manufacturer return serial, name, manufacturer
} }
func defaultSysInfoImplementation() SysInfo {
si := sysinfo.SysInfo{}
si.GetSysInfo()
return SysInfo{
ChassisSerial: si.Chassis.Serial,
ProductSerial: si.Product.Serial,
BoardSerial: si.Board.Serial,
ProductName: si.Product.Name,
BoardName: si.Board.Name,
ProductVendor: si.Product.Vendor,
}
}

View File

@@ -6,13 +6,12 @@ import (
"os" "os"
"runtime" "runtime"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi" "github.com/yusufpapurcu/wmi"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -42,24 +41,10 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err) log.Warnf("failed to discover network addresses: %s", err)
} }
serialNum, err := sysNumber() start := time.Now()
if err != nil { si := updateStaticInfo()
log.Warnf("failed to get system serial number: %s", err) if time.Since(start) > 1*time.Second {
} log.Warnf("updateStaticInfo took %s", time.Since(start))
prodName, err := sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err := sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
} }
gio := &Info{ gio := &Info{
@@ -71,10 +56,10 @@ func GetInfo(ctx context.Context) *Info {
CPUs: runtime.NumCPU(), CPUs: runtime.NumCPU(),
KernelVersion: buildVersion, KernelVersion: buildVersion,
NetworkAddresses: addrs, NetworkAddresses: addrs,
SystemSerialNumber: serialNum, SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: prodName, SystemProductName: si.SystemProductName,
SystemManufacturer: manufacturer, SystemManufacturer: si.SystemManufacturer,
Environment: env, Environment: si.Environment,
} }
systemHostname, _ := os.Hostname() systemHostname, _ := os.Hostname()
@@ -85,6 +70,26 @@ func GetInfo(ctx context.Context) *Info {
return gio return gio
} }
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var err error
serialNumber, err = sysNumber()
if err != nil {
log.Warnf("failed to get system serial number: %s", err)
}
productName, err = sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err = sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
return serialNumber, productName, manufacturer
}
func getOSNameAndVersion() (string, string) { func getOSNameAndVersion() (string, string) {
var dst []Win32_OperatingSystem var dst []Win32_OperatingSystem
query := wmi.CreateQuery(&dst, "") query := wmi.CreateQuery(&dst, "")

View File

@@ -0,0 +1,46 @@
//go:build (linux && !android) || windows || (darwin && !ios)
package system
import (
"context"
"sync"
"time"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
)
var (
staticInfo StaticInfo
once sync.Once
)
func init() {
go func() {
_ = updateStaticInfo()
}()
}
func updateStaticInfo() StaticInfo {
once.Do(func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo()
wg.Done()
}()
go func() {
staticInfo.Environment.Cloud = detect_cloud.Detect(ctx)
wg.Done()
}()
go func() {
staticInfo.Environment.Platform = detect_platform.Detect(ctx)
wg.Done()
}()
wg.Wait()
})
return staticInfo
}

View File

@@ -183,7 +183,10 @@ func Test_sysInfo(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo) getSystemInfo = func() SysInfo {
return tt.sysInfo
}
gotSerialNum, gotProdName, gotManufacturer := sysInfo()
if gotSerialNum != tt.wantSerialNum { if gotSerialNum != tt.wantSerialNum {
t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum) t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum)
} }

6
go.mod
View File

@@ -25,7 +25,7 @@ require (
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.64.1 google.golang.org/grpc v1.64.1
google.golang.org/protobuf v1.34.1 google.golang.org/protobuf v1.34.2
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
) )
@@ -224,7 +224,7 @@ require (
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
@@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

12
go.sum
View File

@@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g= github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
@@ -1151,8 +1151,8 @@ google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaE
google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0=
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U= google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U=
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y= google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 h1:AgADTJarZTBqgjiUzRgfaBchgYB3/WFTC80GPwsMcRI= google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0= google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
@@ -1189,8 +1189,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -530,7 +530,7 @@ renderCaddyfile() {
{ {
debug debug
servers :80,:443 { servers :80,:443 {
protocols h1 h2c protocols h1 h2c h3
} }
} }
@@ -788,6 +788,7 @@ services:
networks: [ netbird ] networks: [ netbird ]
ports: ports:
- '443:443' - '443:443'
- '443:443/udp'
- '80:80' - '80:80'
- '8080:8080' - '8080:8080'
volumes: volumes:

View File

@@ -113,7 +113,7 @@ type AccountManager interface {
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error)
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
@@ -139,7 +139,7 @@ type AccountManager interface {
HasConnectedChannel(peerID string) bool HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManager() idp.Manager GetIdpManager() idp.Manager

View File

@@ -6,13 +6,17 @@ import (
b64 "encoding/base64" b64 "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net" "net"
"os"
"reflect" "reflect"
"strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -1038,7 +1042,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
} }
b.Run("public without account ID", func(b *testing.B) { b.Run("public without account ID", func(b *testing.B) {
//b.ResetTimer() // b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims) _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
if err != nil { if err != nil {
@@ -1048,7 +1052,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
}) })
b.Run("private without account ID", func(b *testing.B) { b.Run("private without account ID", func(b *testing.B) {
//b.ResetTimer() // b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims) _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil { if err != nil {
@@ -1059,7 +1063,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
b.Run("private with account ID", func(b *testing.B) { b.Run("private with account ID", func(b *testing.B) {
claims.AccountId = id claims.AccountId = id
//b.ResetTimer() // b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims) _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil { if err != nil {
@@ -1238,8 +1242,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
return return
} }
policy := Policy{ _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -1250,8 +1253,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
require.NoError(t, err) require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -1320,19 +1322,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
policy := Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
go func() { go func() {
@@ -1345,7 +1334,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
} }
}() }()
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
})
if err != nil {
t.Errorf("delete default rule: %v", err) t.Errorf("delete default rule: %v", err)
return return
} }
@@ -1366,7 +1367,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
return return
} }
policy := Policy{ _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -1377,9 +1378,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
if err != nil {
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("save policy: %v", err) t.Errorf("save policy: %v", err)
return return
} }
@@ -1421,7 +1421,12 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
require.NoError(t, err, "failed to save group") require.NoError(t, err, "failed to save group")
policy := Policy{ if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -1432,14 +1437,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
if err != nil {
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("save policy: %v", err) t.Errorf("save policy: %v", err)
return return
} }
@@ -2987,3 +2986,218 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage)
t.Error("Timed out waiting for update message") t.Error("Timed out waiting for update message")
} }
} }
func BenchmarkSyncAndMarkPeer(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 1, 3, 4, 10},
{"Medium", 500, 100, 7, 13, 10, 60},
{"Large", 5000, 200, 65, 80, 60, 170},
{"Small single", 50, 10, 1, 3, 4, 60},
{"Medium single", 500, 10, 7, 13, 10, 26},
{"Large 5", 5000, 15, 65, 80, 60, 170},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
assert.NoError(b, err)
}
duration := time.Since(start)
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}
func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 102, 110, 102, 120},
{"Medium", 500, 100, 105, 140, 105, 170},
{"Large", 5000, 200, 160, 200, 160, 270},
{"Small single", 50, 10, 102, 110, 102, 120},
{"Medium single", 500, 10, 105, 140, 105, 170},
{"Large 5", 5000, 15, 160, 200, 160, 270},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
WireGuardPubKey: account.Peers["peer-1"].Key,
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
UserID: "regular_user",
SetupKey: "",
ConnectionIP: net.IP{1, 1, 1, 1},
})
assert.NoError(b, err)
}
duration := time.Since(start)
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}
func BenchmarkLoginPeer_NewPeer(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 107, 120, 107, 140},
{"Medium", 500, 100, 105, 140, 105, 170},
{"Large", 5000, 200, 180, 220, 180, 320},
{"Small single", 50, 10, 107, 120, 105, 140},
{"Medium single", 500, 10, 105, 140, 105, 170},
{"Large 5", 5000, 15, 180, 220, 180, 320},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
UserID: "regular_user",
SetupKey: "",
ConnectionIP: net.IP{1, 1, 1, 1},
})
assert.NoError(b, err)
}
duration := time.Since(start)
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"slices"
"strconv" "strconv"
"sync" "sync"
@@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
@@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
}
if dnsSettingsToSave == nil { if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
} }
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups) if err != nil {
if err != nil {
return err
}
}
oldSettings := account.DNSSettings.Copy()
account.DNSSettings = dnsSettingsToSave.Copy()
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
for _, id := range addedGroups { if user.AccountID != accountID {
group := account.GetGroup(id) return status.NewUserNotPartOfAccountError()
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
} }
for _, id := range removedGroups { if !user.HasAdminPower() {
group := account.GetGroup(id) return status.NewAdminPermissionError()
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
} }
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { var updateAccountPeers bool
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
return err
}
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return err
}
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
return err
}
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
// prepareDNSSettingsEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil
}
for _, groupID := range addedGroups {
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
})
}
for _, groupID := range removedGroups {
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
})
}
return eventsToStore
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
}
// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {
return nil
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
if err != nil {
return err
}
return validateGroups(settings.DisabledManagementGroups, groups)
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{ protoUpdate := &proto.DNSConfig{

View File

@@ -566,8 +566,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI
return false, nil return false, nil
} }
// anyGroupHasPeers checks if any of the given groups in the account have peers. func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() { if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true return true
@@ -575,3 +574,19 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
} }
return false return false
} }
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
if err != nil {
return false, err
}
for _, group := range groups {
if group.HasPeers() {
return true, nil
}
}
return false, nil
}

View File

@@ -500,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}) })
// adding a group to policy // adding a group to policy
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -512,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
}, false) })
assert.NoError(t, err) assert.NoError(t, err)
// Saving a group linked to policy should update account peers and send peer update // Saving a group linked to policy should update account peers and send peer update

View File

@@ -6,10 +6,8 @@ import (
"strconv" "strconv"
"github.com/gorilla/mux" "github.com/gorilla/mux"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -122,21 +120,22 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return return
} }
isUpdate := policyID != "" policy := &server.Policy{
if policyID == "" {
policyID = xid.New().String()
}
policy := server.Policy{
ID: policyID, ID: policyID,
AccountID: accountID,
Name: req.Name, Name: req.Name,
Enabled: req.Enabled, Enabled: req.Enabled,
Description: req.Description, Description: req.Description,
} }
for _, rule := range req.Rules { for _, rule := range req.Rules {
var ruleID string
if rule.Id != nil {
ruleID = *rule.Id
}
pr := server.PolicyRule{ pr := server.PolicyRule{
ID: policyID, // TODO: when policy can contain multiple rules, need refactor ID: ruleID,
PolicyID: policyID,
Name: rule.Name, Name: rule.Name,
Destinations: rule.Destinations, Destinations: rule.Destinations,
Sources: rule.Sources, Sources: rule.Sources,
@@ -225,7 +224,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
policy.SourcePostureChecks = *req.SourcePostureChecks policy.SourcePostureChecks = *req.SourcePostureChecks
} }
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil { policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
@@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return return
} }
resp := toPolicyResponse(allGroups, &policy) resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 { if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return return

View File

@@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
} }
return policy, nil return policy, nil
}, },
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error { SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
if !strings.HasPrefix(policy.ID, "id-") { if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set" policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set" policy.Rules[0].ID = "id-was-set"
} }
return nil return policy, nil
}, },
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil

View File

@@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
return return
} }
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }

View File

@@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
} }
return p, nil return p, nil
}, },
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error { SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
postureChecks.ID = "postureCheck" postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil { if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
} }
return nil return postureChecks, nil
}, },
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error { DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
_, ok := testPostureChecks[postureChecksID] _, ok := testPostureChecks[postureChecksID]

View File

@@ -77,6 +77,8 @@ type JWTValidator struct {
options Options options Options
} }
var keyNotFound = errors.New("unable to find appropriate key")
// NewJWTValidator constructor // NewJWTValidator constructor
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) { func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
keys, err := getPemKeys(ctx, keysLocation) keys, err := getPemKeys(ctx, keysLocation)
@@ -124,12 +126,18 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
} }
publicKey, err := getPublicKey(ctx, token, keys) publicKey, err := getPublicKey(ctx, token, keys)
if err != nil { if err == nil {
log.WithContext(ctx).Errorf("getPublicKey error: %s", err) return publicKey, nil
return nil, err
} }
return publicKey, nil msg := fmt.Sprintf("getPublicKey error: %s", err)
if errors.Is(err, keyNotFound) && !idpSignkeyRefreshEnabled {
msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err)
}
log.WithContext(ctx).Error(msg)
return nil, err
}, },
EnableAuthOnOptions: false, EnableAuthOnOptions: false,
} }
@@ -229,7 +237,7 @@ func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{
log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty) log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty)
} }
return nil, errors.New("unable to find appropriate key") return nil, keyNotFound
} }
func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) { func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
@@ -310,4 +318,3 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
return 0 return 0
} }

View File

@@ -49,7 +49,7 @@ type MockAccountManager struct {
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error)
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
@@ -96,7 +96,7 @@ type MockAccountManager struct {
HasConnectedChannelFunc func(peerID string) bool HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager GetIdpManagerFunc func() idp.Manager
@@ -386,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
} }
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface // SavePolicy mock implementation of SavePolicy from server.AccountManager interface
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error { func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) {
if am.SavePolicyFunc != nil { if am.SavePolicyFunc != nil {
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate) return am.SavePolicyFunc(ctx, accountID, userID, policy)
} }
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
} }
// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface // DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
@@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
} }
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface // SavePostureChecks mocks SavePostureChecks of the AccountManager interface
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
if am.SavePostureChecksFunc != nil { if am.SavePostureChecksFunc != nil {
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
} }
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
} }
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface // DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface

View File

@@ -24,26 +24,34 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") return nil, status.NewUserNotPartOfAccountError()
} }
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
newNSGroup := &nbdns.NameServerGroup{ newNSGroup := &nbdns.NameServerGroup{
ID: xid.New().String(), ID: xid.New().String(),
AccountID: accountID,
Name: name, Name: name,
Description: description, Description: description,
NameServers: nameServerList, NameServers: nameServerList,
@@ -54,26 +62,33 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
SearchDomainsEnabled: searchDomainEnabled, SearchDomainsEnabled: searchDomainEnabled,
} }
err = validateNameServerGroup(false, newNSGroup, account) var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
return err
}
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if account.NameServerGroups == nil { am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
}
account.NameServerGroups[newNSGroup.ID] = newNSGroup if updateAccountPeers {
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
return newNSGroup.Copy(), nil return newNSGroup.Copy(), nil
} }
@@ -87,58 +102,95 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
} }
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
err = validateNameServerGroup(true, nsGroupToSave, account) if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
if err != nil {
return err
}
nsGroupToSave.AccountID = accountID
if err = validateNameServerGroup(ctx, transaction, accountID, nsGroupToSave); err != nil {
return err
}
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave)
})
if err != nil { if err != nil {
return err return err
} }
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID] am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
account.Network.IncSerial() if updateAccountPeers {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
return nil return nil
} }
// DeleteNameServerGroup deletes nameserver group with nsGroupID // DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
nsGroup := account.NameServerGroups[nsGroupID] if user.AccountID != accountID {
if nsGroup == nil { return status.NewUserNotPartOfAccountError()
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
} }
delete(account.NameServerGroups, nsGroupID)
account.Network.IncSerial() var nsGroup *nbdns.NameServerGroup
if err = am.Store.SaveAccount(ctx, account); err != nil { var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID)
if err != nil {
return err
}
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID)
})
if err != nil {
return err return err
} }
if anyGroupHasPeers(account, nsGroup.Groups) { am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
return nil return nil
} }
@@ -150,44 +202,62 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
} }
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
nsGroupID := ""
if existingGroup {
nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID]
if !found {
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
}
}
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
if err != nil { if err != nil {
return err return err
} }
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
if err != nil {
return err
}
err = validateNSList(nameserverGroup.NameServers) err = validateNSList(nameserverGroup.NameServers)
if err != nil { if err != nil {
return err return err
} }
err = validateGroups(nameserverGroup.Groups, account.Groups) nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
return nil err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups)
if err != nil {
return err
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups)
if err != nil {
return err
}
return validateGroups(nameserverGroup.Groups, groups)
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}
hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
} }
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
@@ -213,14 +283,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
return nil return nil
} }
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
} }
for _, nsGroup := range nsGroupMap { for _, nsGroup := range groups {
if name == nsGroup.Name && nsGroup.ID != nsGroupID { if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name) return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name)
} }
} }
@@ -228,8 +298,8 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
} }
func validateNSList(list []nbdns.NameServer) error { func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list) nsListLength := len(list)
if nsListLenght == 0 || nsListLenght > 3 { if nsListLength == 0 || nsListLength > 3 {
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list)) return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
} }
return nil return nil
@@ -244,14 +314,7 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
if id == "" { if id == "" {
return status.Errorf(status.InvalidArgument, "group ID should not be empty string") return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
} }
found := false if _, found := groups[id]; !found {
for groupID := range groups {
if id == groupID {
found = true
break
}
}
if !found {
return status.Errorf(status.InvalidArgument, "group id %s not found", id) return status.Errorf(status.InvalidArgument, "group id %s not found", id)
} }
} }
@@ -277,11 +340,3 @@ func validateDomain(domain string) error {
return nil return nil
} }
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false
}
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
}

View File

@@ -477,7 +477,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
setupKeyName = sk.Name setupKeyName = sk.Name
} }
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
if am.idpManager != nil { if am.idpManager != nil {
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
if err == nil && userdata != nil { if err == nil && userdata != nil {
@@ -617,7 +617,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err return nil, nil, nil, err
} }
postureChecks := am.getPeerPostureChecks(account, newPeer) postureChecks, err := am.getPeerPostureChecks(account, newPeer.ID)
if err != nil {
return nil, nil, nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil return newPeer, networkMap, postureChecks, nil
@@ -674,10 +678,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err) return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
} }
if sync.UpdateAccountPeers {
am.updateAccountPeers(ctx, account.Id)
}
} }
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
@@ -685,24 +685,26 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err) return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err)
} }
var postureChecks []*posture.Checks postureChecks, err := am.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, err
}
if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) {
am.updateAccountPeers(ctx, account.Id)
}
if peerNotValid { if peerNotValid {
emptyMap := &NetworkMap{ emptyMap := &NetworkMap{
Network: account.Network.Copy(), Network: account.Network.Copy(),
} }
return peer, emptyMap, postureChecks, nil return peer, emptyMap, []*posture.Checks{}, nil
}
if isStatusChanged {
am.updateAccountPeers(ctx, account.Id)
} }
validPeersMap, err := am.GetValidatedPeers(account) validPeersMap, err := am.GetValidatedPeers(account)
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err) return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
} }
postureChecks = am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
@@ -876,7 +878,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
postureChecks = am.getPeerPostureChecks(account, peer)
postureChecks, err = am.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
@@ -1030,7 +1036,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
defer wg.Done() defer wg.Done()
defer func() { <-semaphore }() defer func() { <-semaphore }()
postureChecks := am.getPeerPostureChecks(account, p) postureChecks, err := am.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err)
return
}
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})

View File

@@ -283,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
var ( var (
group1 nbgroup.Group group1 nbgroup.Group
group2 nbgroup.Group group2 nbgroup.Group
policy Policy
) )
group1.ID = xid.New().String() group1.ID = xid.New().String()
group2.ID = xid.New().String() group2.ID = xid.New().String()
group1.Name = "src" group1.Name = "src"
group2.Name = "dst" group2.Name = "dst"
policy.ID = xid.New().String()
group1.Peers = append(group1.Peers, peer1.ID) group1.Peers = append(group1.Peers, peer1.ID)
group2.Peers = append(group2.Peers, peer2.ID) group2.Peers = append(group2.Peers, peer2.ID)
@@ -305,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return return
} }
policy.Name = "test" policy := &Policy{
policy.Enabled = true Name: "test",
policy.Rules = []*PolicyRule{ Enabled: true,
{ Rules: []*PolicyRule{
Enabled: true, {
Sources: []string{group1.ID}, Enabled: true,
Destinations: []string{group2.ID}, Sources: []string{group1.ID},
Bidirectional: true, Destinations: []string{group2.ID},
Action: PolicyTrafficActionAccept, Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
}, },
} }
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
if err != nil { if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err) t.Errorf("expecting rule to be added, got failure %v", err)
return return
@@ -364,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
} }
policy.Enabled = false policy.Enabled = false
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
if err != nil { if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err) t.Errorf("expecting rule to be added, got failure %v", err)
return return
@@ -833,19 +833,23 @@ func BenchmarkGetPeers(b *testing.B) {
}) })
} }
} }
func BenchmarkUpdateAccountPeers(b *testing.B) { func BenchmarkUpdateAccountPeers(b *testing.B) {
benchCases := []struct { benchCases := []struct {
name string name string
peers int peers int
groups int groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{ }{
{"Small", 50, 5}, {"Small", 50, 5, 90, 120, 90, 120},
{"Medium", 500, 10}, {"Medium", 500, 100, 110, 140, 120, 200},
{"Large", 5000, 20}, {"Large", 5000, 200, 800, 1300, 2500, 3600},
{"Small single", 50, 1}, {"Small single", 50, 10, 90, 120, 90, 120},
{"Medium single", 500, 1}, {"Medium single", 500, 10, 110, 170, 120, 200},
{"Large 5", 5000, 5}, {"Large 5", 5000, 15, 1300, 1800, 5000, 6000},
} }
log.SetOutput(io.Discard) log.SetOutput(io.Discard)
@@ -881,8 +885,23 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
} }
duration := time.Since(start) duration := time.Since(start)
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op") msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(0, "ns/op") b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
}) })
} }
} }
@@ -1445,8 +1464,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
// Adding peer to group linked with policy should update account peers and send peer update // Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) { t.Run("adding peer to group linked with policy", func(t *testing.T) {
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -1457,7 +1475,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
}, false) })
require.NoError(t, err) require.NoError(t, err)
done := make(chan struct{}) done := make(chan struct{})

View File

@@ -3,13 +3,13 @@ package server
import ( import (
"context" "context"
_ "embed" _ "embed"
"slices"
"strconv" "strconv"
"strings" "strings"
"github.com/netbirdio/netbird/management/proto"
"github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -125,6 +125,7 @@ type PolicyRule struct {
func (pm *PolicyRule) Copy() *PolicyRule { func (pm *PolicyRule) Copy() *PolicyRule {
rule := &PolicyRule{ rule := &PolicyRule{
ID: pm.ID, ID: pm.ID,
PolicyID: pm.PolicyID,
Name: pm.Name, Name: pm.Name,
Description: pm.Description, Description: pm.Description,
Enabled: pm.Enabled, Enabled: pm.Enabled,
@@ -171,6 +172,7 @@ type Policy struct {
func (p *Policy) Copy() *Policy { func (p *Policy) Copy() *Policy {
c := &Policy{ c := &Policy{
ID: p.ID, ID: p.ID,
AccountID: p.AccountID,
Name: p.Name, Name: p.Name,
Description: p.Description, Description: p.Description,
Enabled: p.Enabled, Enabled: p.Enabled,
@@ -343,157 +345,207 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") return nil, status.NewUserNotPartOfAccountError()
} }
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
} }
// SavePolicy in the store // SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return nil, err
} }
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate) if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
var isUpdate = policy.ID != ""
var updateAccountPeers bool
var action = activity.PolicyAdded
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
return err
}
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
saveFunc := transaction.CreatePolicy
if isUpdate {
action = activity.PolicyUpdated
saveFunc = transaction.SavePolicy
}
return saveFunc(ctx, LockingStrengthUpdate, policy)
})
if err != nil { if err != nil {
return err return nil, err
} }
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
action := activity.PolicyAdded
if isUpdate {
action = activity.PolicyUpdated
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
return policy, nil
}
// DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var policy *Policy
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID)
if err != nil {
return err
}
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID)
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil return nil
} }
// DeletePolicy from the store // ListPolicies from the store.
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
policy, err := am.deletePolicy(account, policyID)
if err != nil {
return err
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
if anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
} }
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { // arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
policyIdx := -1 func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) {
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
if policyIdx < 0 {
return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID)
}
policy := account.Policies[policyIdx]
account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...)
return policy, nil
}
// savePolicy saves or updates a policy in the given account.
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
for index, rule := range policyToSave.Rules {
rule.Sources = filterValidGroupIDs(account, rule.Sources)
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
policyToSave.Rules[index] = rule
}
if policyToSave.SourcePostureChecks != nil {
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
}
if isUpdate { if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
if policyIdx < 0 { if err != nil {
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) return false, err
} }
oldPolicy := account.Policies[policyIdx] if !policy.Enabled && !existingPolicy.Enabled {
// Update the existing policy
account.Policies[policyIdx] = policyToSave
if !policyToSave.Enabled && !oldPolicy.Enabled {
return false, nil return false, nil
} }
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
return updateAccountPeers, nil hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups())
} if err != nil {
return false, err
}
// Add the new policy to the account if hasPeers {
account.Policies = append(account.Policies, policyToSave) return true, nil
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
}
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.FirewallRule{
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
} }
} }
return result
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
}
// validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error {
if policy.ID != "" {
_, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
if err != nil {
return err
}
} else {
policy.ID = xid.New().String()
policy.AccountID = accountID
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
if err != nil {
return err
}
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
if err != nil {
return err
}
for i, rule := range policy.Rules {
ruleCopy := rule.Copy()
if ruleCopy.ID == "" {
ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor
ruleCopy.PolicyID = policy.ID
}
ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources)
ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations)
policy.Rules[i] = ruleCopy
}
if policy.SourcePostureChecks != nil {
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
}
return nil
} }
// getAllPeersFromGroups for given peer ID and list of groups // getAllPeersFromGroups for given peer ID and list of groups
@@ -574,27 +626,42 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
return nil return nil
} }
// filterValidPostureChecks filters and returns the posture check IDs from the given list // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
// that are valid within the provided account. func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { validIDs := make([]string, 0, len(postureChecksIds))
result := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds { for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks { if _, exists := postureChecks[id]; exists {
if id == postureCheck.ID { validIDs = append(validIDs, id)
result = append(result, id)
continue
}
} }
} }
return result
return validIDs
} }
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. // getValidGroupIDs filters and returns only the valid group IDs from the provided list.
func filterValidGroupIDs(account *Account, groupIDs []string) []string { func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string {
result := make([]string, 0, len(groupIDs)) validIDs := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs { for _, id := range groupIDs {
if _, exists := account.Groups[groupID]; exists { if _, exists := groups[id]; exists {
result = append(result, groupID) validIDs = append(validIDs, id)
}
}
return validIDs
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.FirewallRule{
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
} }
} }
return result return result

View File

@@ -7,7 +7,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/rs/xid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
@@ -859,14 +858,23 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
}) })
var policyWithGroupRulesNoPeers *Policy
var policyWithDestinationPeersOnly *Policy
var policyWithSourceAndDestinationPeers *Policy
// Saving policy with rule groups with no peers should not update account's peers and not send peer update // Saving policy with rule groups with no peers should not update account's peers and not send peer update
t.Run("saving policy with rule groups with no peers", func(t *testing.T) { t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
policy := Policy{ done := make(chan struct{})
ID: "policy-rule-groups-no-peers", go func() {
Enabled: true, peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupB"}, Sources: []string{"groupB"},
Destinations: []string{"groupC"}, Destinations: []string{"groupC"},
@@ -874,15 +882,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -895,12 +895,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with source group containing peers, but destination group without peers should // Saving policy with source group containing peers, but destination group without peers should
// update account's peers and send peer update // update account's peers and send peer update
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) { t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
policy := Policy{ done := make(chan struct{})
ID: "policy-source-has-peers-destination-none", go func() {
Enabled: true, peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupB"}, Destinations: []string{"groupB"},
@@ -909,15 +914,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -930,13 +927,18 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with destination group containing peers, but source group without peers should // Saving policy with destination group containing peers, but source group without peers should
// update account's peers and send peer update // update account's peers and send peer update
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) { t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
policy := Policy{ done := make(chan struct{})
ID: "policy-destination-has-peers-source-none", go func() {
Enabled: true, peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(), Enabled: true,
Enabled: false,
Sources: []string{"groupC"}, Sources: []string{"groupC"},
Destinations: []string{"groupD"}, Destinations: []string{"groupD"},
Bidirectional: true, Bidirectional: true,
@@ -944,15 +946,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -965,12 +959,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with destination and source groups containing peers should update account's peers // Saving policy with destination and source groups containing peers should update account's peers
// and send peer update // and send peer update
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) { t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{ done := make(chan struct{})
ID: "policy-source-destination-peers", go func() {
Enabled: true, peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupD"}, Destinations: []string{"groupD"},
@@ -978,15 +977,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
} })
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -999,28 +990,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Disabling policy with destination and source groups containing peers should update account's peers // Disabling policy with destination and source groups containing peers should update account's peers
// and send peer update // and send peer update
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) { t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: false,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) policyWithSourceAndDestinationPeers.Enabled = false
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -1033,29 +1010,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Updating disabled policy with destination and source groups containing peers should not update account's peers // Updating disabled policy with destination and source groups containing peers should not update account's peers
// or send peer update // or send peer update
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) { t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Description: "updated description",
Enabled: false,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldNotReceiveUpdate(t, updMsg) peerShouldNotReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) policyWithSourceAndDestinationPeers.Description = "updated description"
policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"}
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -1068,28 +1031,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Enabling policy with destination and source groups containing peers should update account's peers // Enabling policy with destination and source groups containing peers should update account's peers
// and send peer update // and send peer update
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) { t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) policyWithSourceAndDestinationPeers.Enabled = true
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -1101,15 +1050,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy should trigger account peers update and send peer update // Deleting policy should trigger account peers update and send peer update
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
policyID := "policy-source-destination-peers"
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) err := manager.DeletePolicy(context.Background(), account.Id, policyWithSourceAndDestinationPeers.ID, userID)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -1123,14 +1070,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy with destination group containing peers, but source group without peers should // Deleting policy with destination group containing peers, but source group without peers should
// update account's peers and send peer update // update account's peers and send peer update
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) { t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
policyID := "policy-destination-has-peers-source-none"
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) err := manager.DeletePolicy(context.Background(), account.Id, policyWithDestinationPeersOnly.ID, userID)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -1142,14 +1088,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy with no peers in groups should not update account's peers and not send peer update // Deleting policy with no peers in groups should not update account's peers and not send peer update
t.Run("deleting policy with no peers in groups", func(t *testing.T) { t.Run("deleting policy with no peers in groups", func(t *testing.T) {
policyID := "policy-rule-groups-no-peers"
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldNotReceiveUpdate(t, updMsg) peerShouldNotReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) err := manager.DeletePolicy(context.Background(), account.Id, policyWithGroupRulesNoPeers.ID, userID)
assert.NoError(t, err) assert.NoError(t, err)
select { select {

View File

@@ -7,8 +7,6 @@ import (
"regexp" "regexp"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh
} }
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) { func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
if postureChecksID == "" {
postureChecksID = xid.New().String()
}
postureChecks := Checks{ postureChecks := Checks{
ID: postureChecksID, ID: postureChecksID,
Name: name, Name: name,

View File

@@ -2,237 +2,285 @@ package server
import ( import (
"context" "context"
"errors"
"fmt"
"slices" "slices"
"github.com/rs/xid"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
const (
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.HasAdminPower() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.NewUserNotPartOfAccountError()
}
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
}
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
} }
if !user.HasAdminPower() { if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.NewAdminPermissionError()
} }
if err := postureChecks.Validate(); err != nil { return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
return status.Errorf(status.InvalidArgument, err.Error()) //nolint }
// SavePostureChecks saves a posture check.
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
} }
exists, uniqName := am.savePostureChecks(account, postureChecks) if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
// we do not allow create new posture checks with non uniq name
if !exists && !uniqName {
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
} }
action := activity.PostureCheckCreated if !user.HasAdminPower() {
if exists { return nil, status.NewAdminPermissionError()
action = activity.PostureCheckUpdated
account.Network.IncSerial()
} }
if err = am.Store.SaveAccount(ctx, account); err != nil { var updateAccountPeers bool
return err var isUpdate = postureChecks.ID != ""
var action = activity.PostureCheckCreated
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
return err
}
if isUpdate {
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
action = activity.PostureCheckUpdated
}
postureChecks.AccountID = accountID
return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks)
})
if err != nil {
return nil, err
} }
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { if updateAccountPeers {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
return nil return postureChecks, nil
} }
// DeletePostureChecks deletes a posture check by ID.
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
user, err := account.FindUser(userID) if user.AccountID != accountID {
if err != nil { return status.NewUserNotPartOfAccountError()
return err
} }
if !user.HasAdminPower() { if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return status.NewAdminPermissionError()
} }
postureChecks, err := am.deletePostureChecks(account, postureChecksID) var postureChecks *posture.Checks
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
if err != nil {
return err
}
if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID)
})
if err != nil { if err != nil {
return err return err
} }
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta()) am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
return nil return nil
} }
// ListPostureChecks returns a list of posture checks.
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.HasAdminPower() || user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return nil, status.NewAdminPermissionError()
} }
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
} }
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
uniqName = true
for i, p := range account.PostureChecks {
if !exists && p.ID == postureChecks.ID {
account.PostureChecks[i] = postureChecks
exists = true
}
if p.Name == postureChecks.Name {
uniqName = false
}
}
if !exists {
account.PostureChecks = append(account.PostureChecks, postureChecks)
}
return
}
func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) {
postureChecksIdx := -1
for i, postureChecks := range account.PostureChecks {
if postureChecks.ID == postureChecksID {
postureChecksIdx = i
break
}
}
if postureChecksIdx < 0 {
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
}
// Check if posture check is linked to any policy
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
}
postureChecks := account.PostureChecks[postureChecksIdx]
account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
return postureChecks, nil
}
// getPeerPostureChecks returns the posture checks applied for a given peer. // getPeerPostureChecks returns the posture checks applied for a given peer.
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks { func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]posture.Checks) peerPostureChecks := make(map[string]*posture.Checks)
if len(account.PostureChecks) == 0 { if len(account.PostureChecks) == 0 {
return nil return nil, nil
} }
for _, policy := range account.Policies { for _, policy := range account.Policies {
if !policy.Enabled { if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
continue continue
} }
if isPeerInPolicySourceGroups(peer.ID, account, policy) { if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
addPolicyPostureChecks(account, policy, peerPostureChecks) return nil, err
} }
} }
postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks)) return maps.Values(peerPostureChecks), nil
for _, check := range peerPostureChecks { }
checkCopy := check
postureChecksList = append(postureChecksList, &checkCopy) // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
} }
return postureChecksList for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
}
}
return false, nil
}
// validatePostureChecks validates the posture checks.
func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error {
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
// If the posture check already has an ID, verify its existence in the store.
if postureChecks.ID != "" {
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
return err
}
return nil
}
// For new posture checks, ensure no duplicates by name.
checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
for _, check := range checks {
if check.Name == postureChecks.Name && check.ID != postureChecks.ID {
return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name)
}
}
postureChecks.ID = xid.New().String()
return nil
}
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
if err != nil {
return err
}
if !isInGroup {
return nil
}
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck := account.getPostureChecks(sourcePostureCheckID)
if postureCheck == nil {
return errors.New("failed to add policy posture checks: posture checks not found")
}
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
return nil
} }
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool { func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) {
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if !rule.Enabled { if !rule.Enabled {
continue continue
} }
for _, sourceGroup := range rule.Sources { for _, sourceGroup := range rule.Sources {
group, ok := account.Groups[sourceGroup] group := account.GetGroup(sourceGroup)
if ok && slices.Contains(group.Peers, peerID) { if group == nil {
return true return false, fmt.Errorf("failed to check peer in policy source group: group not found")
}
if slices.Contains(group.Peers, peerID) {
return true, nil
} }
} }
} }
return false
}
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
for _, postureCheck := range account.PostureChecks {
if postureCheck.ID == sourcePostureCheckID {
peerPostureChecks[sourcePostureCheckID] = *postureCheck
}
}
}
}
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
for _, policy := range account.Policies {
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return true, policy
}
}
return false, nil return false, nil
} }
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers. // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool { func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error {
if !exists { policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
return false if err != nil {
return err
} }
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID) for _, policy := range policies {
if !isLinked { if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return false return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
}
} }
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
return nil
} }

View File

@@ -5,8 +5,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/rs/xid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/group"
@@ -16,7 +16,6 @@ import (
const ( const (
adminUserID = "adminUserID" adminUserID = "adminUserID"
regularUserID = "regularUserID" regularUserID = "regularUserID"
postureCheckID = "existing-id"
postureCheckName = "Existing check" postureCheckName = "Existing check"
) )
@@ -33,7 +32,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
t.Run("Generic posture check flow", func(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) {
// regular users can not create checks // regular users can not create checks
err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
assert.Error(t, err) assert.Error(t, err)
// regular users cannot list check // regular users cannot list check
@@ -41,8 +40,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
// should be possible to create posture check with uniq name // should be possible to create posture check with uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
ID: postureCheckID,
Name: postureCheckName, Name: postureCheckName,
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
@@ -58,8 +56,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Len(t, checks, 1) assert.Len(t, checks, 1)
// should not be possible to create posture check with non uniq name // should not be possible to create posture check with non uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
ID: "new-id",
Name: postureCheckName, Name: postureCheckName,
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
GeoLocationCheck: &posture.GeoLocationCheck{ GeoLocationCheck: &posture.GeoLocationCheck{
@@ -74,23 +71,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
// admins can update posture checks // admins can update posture checks
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ postureCheck.Checks = posture.ChecksDefinition{
ID: postureCheckID, NBVersionCheck: &posture.NBVersionCheck{
Name: postureCheckName, MinVersion: "0.27.0",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.27.0",
},
}, },
}) }
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck)
assert.NoError(t, err) assert.NoError(t, err)
// users should not be able to delete posture checks // users should not be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID)
assert.Error(t, err) assert.Error(t, err)
// admin should be able to delete posture checks // admin should be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID)
assert.NoError(t, err) assert.NoError(t, err)
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID)
assert.NoError(t, err) assert.NoError(t, err)
@@ -150,9 +144,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
}) })
postureCheck := posture.Checks{ postureCheckA := &posture.Checks{
ID: "postureCheck", Name: "postureCheckA",
Name: "postureCheck", AccountID: account.Id,
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
},
},
},
}
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA)
require.NoError(t, err)
postureCheckB := &posture.Checks{
Name: "postureCheckB",
AccountID: account.Id, AccountID: account.Id,
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
@@ -169,7 +176,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -187,12 +194,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
postureCheck.Checks = posture.ChecksDefinition{ postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -202,12 +209,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
} }
}) })
policy := Policy{ policy := &Policy{
ID: "policyA",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@@ -215,7 +220,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
SourcePostureChecks: []string{postureCheck.ID}, SourcePostureChecks: []string{postureCheckB.ID},
} }
// Linking posture check to policy should trigger update account peers and send peer update // Linking posture check to policy should trigger update account peers and send peer update
@@ -226,7 +231,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -238,7 +243,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked posture checks should update account peers and send peer update // Updating linked posture checks should update account peers and send peer update
t.Run("updating linked to posture check with peers", func(t *testing.T) { t.Run("updating linked to posture check with peers", func(t *testing.T) {
postureCheck.Checks = posture.ChecksDefinition{ postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
@@ -255,7 +260,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -274,8 +279,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}() }()
policy.SourcePostureChecks = []string{} policy.SourcePostureChecks = []string{}
_, err := manager.SavePolicy(context.Background(), account.Id, userID, policy)
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -293,7 +297,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID) err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -303,17 +307,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
} }
}) })
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
policy = Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policyB",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupB"}, Sources: []string{"groupB"},
Destinations: []string{"groupC"}, Destinations: []string{"groupC"},
@@ -321,9 +323,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
SourcePostureChecks: []string{postureCheck.ID}, SourcePostureChecks: []string{postureCheckB.ID},
} })
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err) assert.NoError(t, err)
done := make(chan struct{}) done := make(chan struct{})
@@ -332,12 +333,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
postureCheck.Checks = posture.ChecksDefinition{ postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -354,12 +355,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
}) })
policy = Policy{
ID: "policyB", _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: xid.New().String(),
Enabled: true, Enabled: true,
Sources: []string{"groupB"}, Sources: []string{"groupB"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@@ -367,10 +367,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
SourcePostureChecks: []string{postureCheck.ID}, SourcePostureChecks: []string{postureCheckB.ID},
} })
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err) assert.NoError(t, err)
done := make(chan struct{}) done := make(chan struct{})
@@ -379,12 +377,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
postureCheck.Checks = posture.ChecksDefinition{ postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -397,8 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked client posture check to policy where source has peers but destination does not, // Updating linked client posture check to policy where source has peers but destination does not,
// should trigger account peers update and send peer update // should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
policy = Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policyB",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -409,9 +406,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
}, },
SourcePostureChecks: []string{postureCheck.ID}, SourcePostureChecks: []string{postureCheckB.ID},
} })
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err) assert.NoError(t, err)
done := make(chan struct{}) done := make(chan struct{})
@@ -420,7 +416,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
postureCheck.Checks = posture.ChecksDefinition{ postureCheckB.Checks = posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{ ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{ Processes: []posture.Process{
{ {
@@ -429,7 +425,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}, },
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@@ -440,80 +436,120 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}) })
} }
func TestArePostureCheckChangesAffectingPeers(t *testing.T) { func TestArePostureCheckChangesAffectPeers(t *testing.T) {
account := &Account{ manager, err := createManager(t)
Policies: []*Policy{ require.NoError(t, err, "failed to create account manager")
{
ID: "policyA", account, err := initTestPostureChecksAccount(manager)
Rules: []*PolicyRule{ require.NoError(t, err, "failed to init testing account")
{
Enabled: true, groupA := &group.Group{
Sources: []string{"groupA"}, ID: "groupA",
Destinations: []string{"groupA"}, AccountID: account.Id,
}, Peers: []string{"peer1"},
},
SourcePostureChecks: []string{"checkA"},
},
},
Groups: map[string]*group.Group{
"groupA": {
ID: "groupA",
Peers: []string{"peer1"},
},
"groupB": {
ID: "groupB",
Peers: []string{},
},
},
PostureChecks: []*posture.Checks{
{
ID: "checkA",
},
{
ID: "checkB",
},
},
} }
groupB := &group.Group{
ID: "groupB",
AccountID: account.Id,
Peers: []string{},
}
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
require.NoError(t, err, "failed to save groups")
postureCheckA := &posture.Checks{
Name: "checkA",
AccountID: account.Id,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA)
require.NoError(t, err, "failed to save postureCheckA")
postureCheckB := &posture.Checks{
Name: "checkB",
AccountID: account.Id,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB)
require.NoError(t, err, "failed to save postureCheckB")
policy := &Policy{
AccountID: account.Id,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
},
},
SourcePostureChecks: []string{postureCheckA.ID},
}
policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
require.NoError(t, err, "failed to save policy")
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result) assert.True(t, result)
}) })
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkB", true) result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
t.Run("posture check does not exist", func(t *testing.T) { t.Run("posture check does not exist", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "unknown", false) result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown")
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupB"} policy.Rules[0].Sources = []string{"groupB"}
account.Policies[0].Rules[0].Destinations = []string{"groupA"} policy.Rules[0].Destinations = []string{"groupA"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result) assert.True(t, result)
}) })
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupA"} policy.Rules[0].Sources = []string{"groupA"}
account.Policies[0].Rules[0].Destinations = []string{"groupB"} policy.Rules[0].Destinations = []string{"groupB"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result) assert.True(t, result)
}) })
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} groupA.Peers = []string{}
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) require.NoError(t, err, "failed to save groups")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
account.Groups["groupA"].Peers = []string{} policy.Rules[0].Sources = []string{"nonExistentGroup"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) policy.Rules[0].Destinations = []string{"nonExistentGroup"}
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
} }

View File

@@ -237,7 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err return nil, err
} }
if isRouteChangeAffectPeers(account, &newRoute) { if am.isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
@@ -323,7 +323,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err return err
} }
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
@@ -355,7 +355,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if isRouteChangeAffectPeers(account, routy) { if am.isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
@@ -417,27 +417,84 @@ func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string,
continue continue
} }
policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups) distributionPeers := a.getDistributionGroupsPeers(route)
for _, policy := range policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules { for _, accessGroup := range route.AccessControlGroups {
if !rule.Enabled { policies := getAllRoutePoliciesFromGroups(a, []string{accessGroup})
continue rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
} routesFirewallRules = append(routesFirewallRules, rules...)
distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap)
rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN)
routesFirewallRules = append(routesFirewallRules, rules...)
}
} }
} }
return routesFirewallRules return routesFirewallRules
} }
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
var fwRules []*RouteFirewallRule
for _, policy := range policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap)
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, firewallRuleDirectionIN)
fwRules = append(fwRules, rules...)
}
}
return fwRules
}
func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
distPeersWithPolicy := make(map[string]struct{})
for _, id := range rule.Sources {
group := a.Groups[id]
if group == nil {
continue
}
for _, pID := range group.Peers {
if pID == peerID {
continue
}
_, distPeer := distributionPeers[pID]
_, valid := validatedPeersMap[pID]
if distPeer && valid {
distPeersWithPolicy[pID] = struct{}{}
}
}
}
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
for pID := range distPeersWithPolicy {
peer := a.Peers[pID]
if peer == nil {
continue
}
distributionGroupPeers = append(distributionGroupPeers, peer)
}
return distributionGroupPeers
}
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
distPeers := make(map[string]struct{})
for _, id := range route.Groups {
group := a.Groups[id]
if group == nil {
continue
}
for _, pID := range group.Peers {
distPeers[pID] = struct{}{}
}
}
return distPeers
}
func getDefaultPermit(route *route.Route) []*RouteFirewallRule { func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
var rules []*RouteFirewallRule var rules []*RouteFirewallRule
@@ -651,6 +708,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
// isRouteChangeAffectPeers checks if a given route affects peers by determining // isRouteChangeAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers // if it has a routing peer, distribution, or peer groups that include peers
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
} }

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sort"
"testing" "testing"
"time" "time"
@@ -1214,12 +1215,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
defaultRule := rules[0] defaultRule := rules[0]
newPolicy := defaultRule.Copy() newPolicy := defaultRule.Copy()
newPolicy.ID = xid.New().String()
newPolicy.Name = "peer1 only" newPolicy.Name = "peer1 only"
newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID}
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
require.NoError(t, err) require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
@@ -1487,6 +1487,8 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
peerBIp = "100.65.80.39" peerBIp = "100.65.80.39"
peerCIp = "100.65.254.139" peerCIp = "100.65.254.139"
peerHIp = "100.65.29.55" peerHIp = "100.65.29.55"
peerJIp = "100.65.29.65"
peerKIp = "100.65.29.66"
) )
account := &Account{ account := &Account{
@@ -1542,6 +1544,16 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
IP: net.ParseIP(peerHIp), IP: net.ParseIP(peerHIp),
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
}, },
"peerJ": {
ID: "peerJ",
IP: net.ParseIP(peerJIp),
Status: &nbpeer.PeerStatus{},
},
"peerK": {
ID: "peerK",
IP: net.ParseIP(peerKIp),
Status: &nbpeer.PeerStatus{},
},
}, },
Groups: map[string]*nbgroup.Group{ Groups: map[string]*nbgroup.Group{
"routingPeer1": { "routingPeer1": {
@@ -1568,6 +1580,11 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
Name: "Route2", Name: "Route2",
Peers: []string{}, Peers: []string{},
}, },
"route4": {
ID: "route4",
Name: "route4",
Peers: []string{},
},
"finance": { "finance": {
ID: "finance", ID: "finance",
Name: "Finance", Name: "Finance",
@@ -1585,6 +1602,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
"peerB", "peerB",
}, },
}, },
"qa": {
ID: "qa",
Name: "QA",
Peers: []string{
"peerJ",
"peerK",
},
},
"restrictQA": {
ID: "restrictQA",
Name: "restrictQA",
Peers: []string{
"peerJ",
},
},
"unrestrictedQA": {
ID: "unrestrictedQA",
Name: "unrestrictedQA",
Peers: []string{
"peerK",
},
},
"contractors": { "contractors": {
ID: "contractors", ID: "contractors",
Name: "Contractors", Name: "Contractors",
@@ -1632,6 +1671,19 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
Groups: []string{"contractors"}, Groups: []string{"contractors"},
AccessControlGroups: []string{}, AccessControlGroups: []string{},
}, },
"route4": {
ID: "route4",
Network: netip.MustParsePrefix("192.168.10.0/16"),
NetID: "route4",
NetworkType: route.IPv4Network,
PeerGroups: []string{"routingPeer1"},
Description: "Route4",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{"qa"},
AccessControlGroups: []string{"route4"},
},
}, },
Policies: []*Policy{ Policies: []*Policy{
{ {
@@ -1686,6 +1738,49 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
}, },
}, },
}, },
{
ID: "RuleRoute4",
Name: "RuleRoute4",
Enabled: true,
Rules: []*PolicyRule{
{
ID: "RuleRoute4",
Name: "RuleRoute4",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolTCP,
Action: PolicyTrafficActionAccept,
Ports: []string{"80"},
Sources: []string{
"restrictQA",
},
Destinations: []string{
"route4",
},
},
},
},
{
ID: "RuleRoute5",
Name: "RuleRoute5",
Enabled: true,
Rules: []*PolicyRule{
{
ID: "RuleRoute5",
Name: "RuleRoute5",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
Sources: []string{
"unrestrictedQA",
},
Destinations: []string{
"route4",
},
},
},
},
}, },
} }
@@ -1710,7 +1805,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
t.Run("check peer routes firewall rules", func(t *testing.T) { t.Run("check peer routes firewall rules", func(t *testing.T) {
routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
assert.Len(t, routesFirewallRules, 2) assert.Len(t, routesFirewallRules, 4)
expectedRoutesFirewallRules := []*RouteFirewallRule{ expectedRoutesFirewallRules := []*RouteFirewallRule{
{ {
@@ -1736,12 +1831,32 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
Port: 320, Port: 320,
}, },
} }
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) additionalFirewallRule := []*RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(AllowedIPsFormat, peerJIp),
},
Action: "accept",
Destination: "192.168.10.0/16",
Protocol: "tcp",
Port: 80,
},
{
SourceRanges: []string{
fmt.Sprintf(AllowedIPsFormat, peerKIp),
},
Action: "accept",
Destination: "192.168.10.0/16",
Protocol: "all",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...)))
// peerD is also the routing peer for route1, should contain same routes firewall rules as peerA // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
assert.Len(t, routesFirewallRules, 2) assert.Len(t, routesFirewallRules, 2)
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
// peerE is a single routing peer for route 2 and route 3 // peerE is a single routing peer for route 2 and route 3
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
@@ -1770,7 +1885,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
IsDynamic: true, IsDynamic: true,
}, },
} }
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
// peerC is part of route1 distribution groups but should not receive the routes firewall rules // peerC is part of route1 distribution groups but should not receive the routes firewall rules
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
@@ -1779,6 +1894,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
} }
// orderList is a helper function to sort a list of strings
func orderRuleSourceRanges(ruleList []*RouteFirewallRule) []*RouteFirewallRule {
for _, rule := range ruleList {
sort.Strings(rule.SourceRanges)
}
return ruleList
}
func TestRouteAccountPeersUpdate(t *testing.T) { func TestRouteAccountPeersUpdate(t *testing.T) {
manager, err := createRouterManager(t) manager, err := createRouterManager(t)
require.NoError(t, err, "failed to create account manager") require.NoError(t, err, "failed to create account manager")

View File

@@ -406,8 +406,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
policy := Policy{ policy := &Policy{
ID: "policy",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -419,7 +418,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
}, },
}, },
} }
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
require.NoError(t, err) require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)

View File

@@ -1162,9 +1162,10 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
First(&accountDNSSettings, idQueryCondition, accountID) First(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "dns settings not found") return nil, status.NewAccountNotFoundError(accountID)
} }
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get dns settings from store")
} }
return &accountDNSSettings.DNSSettings, nil return &accountDNSSettings.DNSSettings, nil
} }
@@ -1243,8 +1244,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
var groups []*nbgroup.Group var groups []*nbgroup.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
} }
groupsMap := make(map[string]*nbgroup.Group) groupsMap := make(map[string]*nbgroup.Group)
@@ -1295,22 +1296,139 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID) var policies []*Policy
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get policies from store")
}
return policies, nil
} }
// GetPolicyByID retrieves a policy by its ID and account ID. // GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) {
return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID) var policy *Policy
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
First(&policy, accountAndIDQueryCondition, accountID, policyID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewPolicyNotFoundError(policyID)
}
log.WithContext(ctx).Errorf("failed to get policy from store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get policy from store")
}
return policy, nil
}
func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to create policy in store")
}
return nil
}
// SavePolicy saves a policy to the database.
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
return status.Errorf(status.Internal, "failed to save policy to store")
}
return nil
}
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err)
return status.Errorf(status.Internal, "failed to delete policy from store")
}
if result.RowsAffected == 0 {
return status.NewPolicyNotFoundError(policyID)
}
return nil
} }
// GetAccountPostureChecks retrieves posture checks for an account. // GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db, lockStrength, accountID) var postureChecks []*posture.Checks
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get posture checks from store")
}
return postureChecks, nil
} }
// GetPostureChecksByID retrieves posture checks by their ID and account ID. // GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID) var postureCheck *posture.Checks
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPostureChecksNotFoundError(postureChecksID)
}
log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get posture check from store")
}
return postureCheck, nil
}
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
var postureChecks []*posture.Checks
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store")
}
postureChecksMap := make(map[string]*posture.Checks)
for _, postureCheck := range postureChecks {
postureChecksMap[postureCheck.ID] = postureCheck
}
return postureChecksMap, nil
}
// SavePostureChecks saves a posture checks to the database.
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save posture checks to store")
}
return nil
}
// DeletePostureChecks deletes a posture checks from the database.
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete posture checks from store")
}
if result.RowsAffected == 0 {
return status.NewPostureChecksNotFoundError(postureChecksID)
}
return nil
} }
// GetAccountRoutes retrieves network routes for an account. // GetAccountRoutes retrieves network routes for an account.
@@ -1380,12 +1498,55 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
// GetAccountNameServerGroups retrieves name server groups for an account. // GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID) var nsGroups []*nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server groups from store")
}
return nsGroups, nil
} }
// GetNameServerGroupByID retrieves a name server group by its ID and account ID. // GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID) var nsGroup *nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewNameServerGroupNotFoundError(nsGroupID)
}
log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server group from store")
}
return nsGroup, nil
}
// SaveNameServerGroup saves a name server group to the database.
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
return status.Errorf(status.Internal, "failed to save name server group to store")
}
return nil
}
// DeleteNameServerGroup deletes a name server group from the database.
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete name server group from store")
}
if result.RowsAffected == 0 {
return status.NewNameServerGroupNotFoundError(nsGroupID)
}
return nil
} }
// getRecords retrieves records from the database based on the account ID. // getRecords retrieves records from the database based on the account ID.
@@ -1420,3 +1581,19 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
} }
return &record, nil return &record, nil
} }
// SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save dns settings to store")
}
if result.RowsAffected == 0 {
return status.NewAccountNotFoundError(accountID)
}
return nil
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -1564,3 +1565,489 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) {
}) })
} }
} }
func TestSqlStore_GetPostureChecksByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
postureChecksID string
expectError bool
}{
{
name: "retrieve existing posture checks",
postureChecksID: "csplshq7qv948l48f7t0",
expectError: false,
},
{
name: "retrieve non-existing posture checks",
postureChecksID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty posture checks ID",
postureChecksID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, postureChecks)
} else {
require.NoError(t, err)
require.NotNil(t, postureChecks)
require.Equal(t, tt.postureChecksID, postureChecks.ID)
}
})
}
}
func TestSqlStore_GetPostureChecksByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
postureCheckIDs []string
expectedCount int
}{
{
name: "retrieve existing posture checks by existing IDs",
postureCheckIDs: []string{"csplshq7qv948l48f7t0", "cspnllq7qv95uq1r4k90"},
expectedCount: 2,
},
{
name: "empty posture check IDs list",
postureCheckIDs: []string{},
expectedCount: 0,
},
{
name: "non-existing posture check IDs",
postureCheckIDs: []string{"nonexistent1", "nonexistent2"},
expectedCount: 0,
},
{
name: "mixed existing and non-existing posture check IDs",
postureCheckIDs: []string{"cspnllq7qv95uq1r4k90", "nonexistent"},
expectedCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthShare, accountID, tt.postureCheckIDs)
require.NoError(t, err)
require.Len(t, groups, tt.expectedCount)
})
}
}
func TestSqlStore_SavePostureChecks(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
postureChecks := &posture.Checks{
ID: "posture-checks-id",
AccountID: accountID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.31.0",
},
OSVersionCheck: &posture.OSVersionCheck{
Ios: &posture.MinVersionCheck{
MinVersion: "13.0.1",
},
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "5.3.3-dev",
},
},
GeoLocationCheck: &posture.GeoLocationCheck{
Locations: []posture.Location{
{
CountryCode: "DE",
CityName: "Berlin",
},
},
Action: posture.CheckActionAllow,
},
},
}
err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks)
require.NoError(t, err)
savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id")
require.NoError(t, err)
require.Equal(t, savePostureChecks, postureChecks)
}
func TestSqlStore_DeletePostureChecks(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
postureChecksID string
expectError bool
}{
{
name: "delete existing posture checks",
postureChecksID: "csplshq7qv948l48f7t0",
expectError: false,
},
{
name: "delete non-existing posture checks",
postureChecksID: "non-existing-posture-checks-id",
expectError: true,
},
{
name: "delete with empty posture checks ID",
postureChecksID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
} else {
require.NoError(t, err)
group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID)
require.Error(t, err)
require.Nil(t, group)
}
})
}
}
func TestSqlStore_GetPolicyByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
policyID string
expectError bool
}{
{
name: "retrieve existing policy",
policyID: "cs1tnh0hhcjnqoiuebf0",
expectError: false,
},
{
name: "retrieve non-existing policy checks",
policyID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty policy ID",
policyID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, tt.policyID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, policy)
} else {
require.NoError(t, err)
require.NotNil(t, policy)
require.Equal(t, tt.policyID, policy.ID)
}
})
}
}
func TestSqlStore_CreatePolicy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
policy := &Policy{
ID: "policy-id",
AccountID: accountID,
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupC"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
err = store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err)
savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID)
require.NoError(t, err)
require.Equal(t, savePolicy, policy)
}
func TestSqlStore_SavePolicy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
policyID := "cs1tnh0hhcjnqoiuebf0"
policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID)
require.NoError(t, err)
policy.Enabled = false
policy.Description = "policy"
policy.Rules[0].Sources = []string{"group"}
policy.Rules[0].Ports = []string{"80", "443"}
err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err)
savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID)
require.NoError(t, err)
require.Equal(t, savePolicy, policy)
}
func TestSqlStore_DeletePolicy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
policyID := "cs1tnh0hhcjnqoiuebf0"
err = store.DeletePolicy(context.Background(), LockingStrengthShare, accountID, policyID)
require.NoError(t, err)
policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID)
require.Error(t, err)
require.Nil(t, policy)
}
func TestSqlStore_GetDNSSettings(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectError bool
}{
{
name: "retrieve existing account dns settings",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectError: false,
},
{
name: "retrieve non-existing account dns settings",
accountID: "non-existing",
expectError: true,
},
{
name: "retrieve dns settings with empty account ID",
accountID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, dnsSettings)
} else {
require.NoError(t, err)
require.NotNil(t, dnsSettings)
}
})
}
}
func TestSqlStore_SaveDNSSettings(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"}
err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings)
require.NoError(t, err)
saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Equal(t, saveDNSSettings, dnsSettings)
}
func TestSqlStore_GetAccountNameServerGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectedCount int
}{
{
name: "retrieve name server groups by existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 1,
},
{
name: "non-existing account ID",
accountID: "nonexistent",
expectedCount: 0,
},
{
name: "empty account ID",
accountID: "",
expectedCount: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
peers, err := store.GetAccountNameServerGroups(context.Background(), LockingStrengthShare, tt.accountID)
require.NoError(t, err)
require.Len(t, peers, tt.expectedCount)
})
}
}
func TestSqlStore_GetNameServerByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
nsGroupID string
expectError bool
}{
{
name: "retrieve existing nameserver group",
nsGroupID: "csqdelq7qv97ncu7d9t0",
expectError: false,
},
{
name: "retrieve non-existing nameserver group",
nsGroupID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty nameserver group ID",
nsGroupID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, tt.nsGroupID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, nsGroup)
} else {
require.NoError(t, err)
require.NotNil(t, nsGroup)
require.Equal(t, tt.nsGroupID, nsGroup.ID)
}
})
}
}
func TestSqlStore_SaveNameServerGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
nsGroup := &nbdns.NameServerGroup{
ID: "ns-group-id",
AccountID: accountID,
Name: "NS Group",
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: 1,
Port: 53,
},
},
Groups: []string{"groupA"},
Primary: true,
Enabled: true,
SearchDomainsEnabled: false,
}
err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nsGroup)
require.NoError(t, err)
saveNSGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroup.ID)
require.NoError(t, err)
require.Equal(t, saveNSGroup, nsGroup)
}
func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
nsGroupID := "csqdelq7qv97ncu7d9t0"
err = store.DeleteNameServerGroup(context.Background(), LockingStrengthShare, accountID, nsGroupID)
require.NoError(t, err)
nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroupID)
require.Error(t, err)
require.Nil(t, nsGroup)
}

View File

@@ -139,3 +139,18 @@ func NewGetAccountError(err error) error {
func NewGroupNotFoundError(groupID string) error { func NewGroupNotFoundError(groupID string) error {
return Errorf(NotFound, "group: %s not found", groupID) return Errorf(NotFound, "group: %s not found", groupID)
} }
// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks
func NewPostureChecksNotFoundError(postureChecksID string) error {
return Errorf(NotFound, "posture checks: %s not found", postureChecksID)
}
// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy
func NewPolicyNotFoundError(policyID string) error {
return Errorf(NotFound, "policy: %s not found", policyID)
}
// NewNameServerGroupNotFoundError creates a new Error with NotFound type for a missing name server group
func NewNameServerGroupNotFoundError(nsGroupID string) error {
return Errorf(NotFound, "nameserver group: %s not found", nsGroupID)
}

View File

@@ -59,6 +59,7 @@ type Store interface {
SaveAccount(ctx context.Context, account *Account) error SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
@@ -80,11 +81,17 @@ type Store interface {
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
@@ -110,6 +117,8 @@ type Store interface {
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error

View File

@@ -34,4 +34,7 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003'
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');

View File

@@ -1279,8 +1279,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
policy := Policy{ policy := &Policy{
ID: "policy",
Enabled: true, Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
@@ -1292,7 +1291,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
}, },
}, },
} }
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
require.NoError(t, err) require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)

View File

@@ -140,7 +140,7 @@ type Client struct {
instanceURL *RelayAddr instanceURL *RelayAddr
muInstanceURL sync.Mutex muInstanceURL sync.Mutex
onDisconnectListener func() onDisconnectListener func(string)
onConnectedListener func() onConnectedListener func()
listenerMutex sync.Mutex listenerMutex sync.Mutex
} }
@@ -233,7 +233,7 @@ func (c *Client) ServerInstanceURL() (string, error) {
} }
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed. // SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
func (c *Client) SetOnDisconnectListener(fn func()) { func (c *Client) SetOnDisconnectListener(fn func(string)) {
c.listenerMutex.Lock() c.listenerMutex.Lock()
defer c.listenerMutex.Unlock() defer c.listenerMutex.Unlock()
c.onDisconnectListener = fn c.onDisconnectListener = fn
@@ -554,7 +554,7 @@ func (c *Client) notifyDisconnected() {
if c.onDisconnectListener == nil { if c.onDisconnectListener == nil {
return return
} }
go c.onDisconnectListener() go c.onDisconnectListener(c.connectionURL)
} }
func (c *Client) notifyConnected() { func (c *Client) notifyConnected() {

View File

@@ -551,7 +551,7 @@ func TestCloseByServer(t *testing.T) {
} }
disconnected := make(chan struct{}) disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func() { relayClient.SetOnDisconnectListener(func(_ string) {
log.Infof("client disconnected") log.Infof("client disconnected")
close(disconnected) close(disconnected)
}) })

View File

@@ -4,65 +4,120 @@ import (
"context" "context"
"time" "time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var ( var (
reconnectingTimeout = 5 * time.Second reconnectingTimeout = 60 * time.Second
) )
// Guard manage the reconnection tries to the Relay server in case of disconnection event. // Guard manage the reconnection tries to the Relay server in case of disconnection event.
type Guard struct { type Guard struct {
ctx context.Context // OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance.
relayClient *Client OnNewRelayClient chan *Client
serverPicker *ServerPicker
} }
// NewGuard creates a new guard for the relay client. // NewGuard creates a new guard for the relay client.
func NewGuard(context context.Context, relayClient *Client) *Guard { func NewGuard(sp *ServerPicker) *Guard {
g := &Guard{ g := &Guard{
ctx: context, OnNewRelayClient: make(chan *Client, 1),
relayClient: relayClient, serverPicker: sp,
} }
return g return g
} }
// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection // StartReconnectTrys is called when the relay client is disconnected from the relay server.
// It attempts to reconnect to the relay server. The function first tries a quick reconnect
// to the same server that was used before, if the server URL is still valid. If the quick
// reconnect fails, it starts a ticker to periodically attempt server picking until it
// succeeds or the context is done.
//
// Parameters:
// - ctx: The context to control the lifecycle of the reconnection attempts.
// - relayClient: The relay client instance that was disconnected.
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent // todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
func (g *Guard) OnDisconnected() { func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
if g.quickReconnect() { if relayClient == nil {
goto RETRY
}
if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) {
return return
} }
ticker := time.NewTicker(reconnectingTimeout) RETRY:
ticker := exponentTicker(ctx)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
err := g.relayClient.Connect() if err := g.retry(ctx); err != nil {
if err != nil { log.Errorf("failed to pick new Relay server: %s", err)
log.Errorf("failed to reconnect to relay server: %s", err)
continue continue
} }
return return
case <-g.ctx.Done(): case <-ctx.Done():
return return
} }
} }
} }
func (g *Guard) quickReconnect() bool { func (g *Guard) retry(ctx context.Context) error {
ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond) log.Infof("try to pick up a new Relay server")
relayClient, err := g.serverPicker.PickServer(ctx)
if err != nil {
return err
}
// prevent to work with a deprecated Relay client instance
g.drainRelayClientChan()
g.OnNewRelayClient <- relayClient
return nil
}
func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool {
ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond)
defer cancel() defer cancel()
<-ctx.Done() <-ctx.Done()
if g.ctx.Err() != nil { if parentCtx.Err() != nil {
return false return false
} }
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
if err := g.relayClient.Connect(); err != nil { if err := rc.Connect(); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err) log.Errorf("failed to reconnect to relay server: %s", err)
return false return false
} }
return true return true
} }
func (g *Guard) drainRelayClientChan() {
select {
case <-g.OnNewRelayClient:
default:
}
}
func (g *Guard) isServerURLStillValid(rc *Client) bool {
for _, url := range g.serverPicker.ServerURLs.Load().([]string) {
if url == rc.connectionURL {
return true
}
}
return false
}
func exponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 2 * time.Second,
Multiplier: 2,
MaxInterval: reconnectingTimeout,
Clock: backoff.SystemClock,
}, ctx)
return backoff.NewTicker(bo)
}

View File

@@ -57,12 +57,15 @@ type ManagerService interface {
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any // relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
// unused relay connection and close it. // unused relay connection and close it.
type Manager struct { type Manager struct {
ctx context.Context ctx context.Context
serverURLs []string peerID string
peerID string running bool
tokenStore *relayAuth.TokenStore tokenStore *relayAuth.TokenStore
serverPicker *ServerPicker
relayClient *Client relayClient *Client
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
relayClientMu sync.Mutex
reconnectGuard *Guard reconnectGuard *Guard
relayClients map[string]*RelayTrack relayClients map[string]*RelayTrack
@@ -76,48 +79,54 @@ type Manager struct {
// NewManager creates a new manager instance. // NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve. // The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager { func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
return &Manager{ tokenStore := &relayAuth.TokenStore{}
ctx: ctx,
serverURLs: serverURLs, m := &Manager{
peerID: peerID, ctx: ctx,
tokenStore: &relayAuth.TokenStore{}, peerID: peerID,
tokenStore: tokenStore,
serverPicker: &ServerPicker{
TokenStore: tokenStore,
PeerID: peerID,
},
relayClients: make(map[string]*RelayTrack), relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List), onDisconnectedListeners: make(map[string]*list.List),
} }
m.serverPicker.ServerURLs.Store(serverURLs)
m.reconnectGuard = NewGuard(m.serverPicker)
return m
} }
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for // Serve starts the manager, attempting to establish a connection with the relay server.
// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection. // If the connection fails, it will keep trying to reconnect in the background.
// Additionally, it starts a cleanup loop to remove unused relay connections.
// The manager will automatically reconnect to the relay server in case of disconnection.
func (m *Manager) Serve() error { func (m *Manager) Serve() error {
if m.relayClient != nil { if m.running {
return fmt.Errorf("manager already serving") return fmt.Errorf("manager already serving")
} }
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs) m.running = true
log.Debugf("starting relay client manager with %v relay servers", m.serverPicker.ServerURLs.Load())
sp := ServerPicker{ client, err := m.serverPicker.PickServer(m.ctx)
TokenStore: m.tokenStore,
PeerID: m.peerID,
}
client, err := sp.PickServer(m.ctx, m.serverURLs)
if err != nil { if err != nil {
return err go m.reconnectGuard.StartReconnectTrys(m.ctx, nil)
} else {
m.storeClient(client)
} }
m.relayClient = client
m.reconnectGuard = NewGuard(m.ctx, m.relayClient) go m.listenGuardEvent(m.ctx)
m.relayClient.SetOnConnectedListener(m.onServerConnected) go m.startCleanupLoop()
m.relayClient.SetOnDisconnectListener(func() { return err
m.onServerDisconnected(client.connectionURL)
})
m.startCleanupLoop()
return nil
} }
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new // established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
m.relayClientMu.Lock()
defer m.relayClientMu.Unlock()
if m.relayClient == nil { if m.relayClient == nil {
return nil, ErrRelayClientNotConnected return nil, ErrRelayClientNotConnected
} }
@@ -146,6 +155,9 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
// Ready returns true if the home Relay client is connected to the relay server. // Ready returns true if the home Relay client is connected to the relay server.
func (m *Manager) Ready() bool { func (m *Manager) Ready() bool {
m.relayClientMu.Lock()
defer m.relayClientMu.Unlock()
if m.relayClient == nil { if m.relayClient == nil {
return false return false
} }
@@ -159,6 +171,13 @@ func (m *Manager) SetOnReconnectedListener(f func()) {
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
// closed. // closed.
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
m.relayClientMu.Lock()
defer m.relayClientMu.Unlock()
if m.relayClient == nil {
return ErrRelayClientNotConnected
}
foreign, err := m.isForeignServer(serverAddress) foreign, err := m.isForeignServer(serverAddress)
if err != nil { if err != nil {
return err return err
@@ -177,6 +196,9 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is // RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
// lost. This address will be sent to the target peer to choose the common relay server for the communication. // lost. This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayInstanceAddress() (string, error) { func (m *Manager) RelayInstanceAddress() (string, error) {
m.relayClientMu.Lock()
defer m.relayClientMu.Unlock()
if m.relayClient == nil { if m.relayClient == nil {
return "", ErrRelayClientNotConnected return "", ErrRelayClientNotConnected
} }
@@ -185,13 +207,18 @@ func (m *Manager) RelayInstanceAddress() (string, error) {
// ServerURLs returns the addresses of the relay servers. // ServerURLs returns the addresses of the relay servers.
func (m *Manager) ServerURLs() []string { func (m *Manager) ServerURLs() []string {
return m.serverURLs return m.serverPicker.ServerURLs.Load().([]string)
} }
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with // HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
// Relay service. // Relay service.
func (m *Manager) HasRelayAddress() bool { func (m *Manager) HasRelayAddress() bool {
return len(m.serverURLs) > 0 return len(m.serverPicker.ServerURLs.Load().([]string)) > 0
}
func (m *Manager) UpdateServerURLs(serverURLs []string) {
log.Infof("update relay server URLs: %v", serverURLs)
m.serverPicker.ServerURLs.Store(serverURLs)
} }
// UpdateToken updates the token in the token store. // UpdateToken updates the token in the token store.
@@ -245,9 +272,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
return nil, err return nil, err
} }
// if connection closed then delete the relay client from the list // if connection closed then delete the relay client from the list
relayClient.SetOnDisconnectListener(func() { relayClient.SetOnDisconnectListener(m.onServerDisconnected)
m.onServerDisconnected(serverAddress)
})
rt.relayClient = relayClient rt.relayClient = relayClient
rt.Unlock() rt.Unlock()
@@ -265,14 +290,37 @@ func (m *Manager) onServerConnected() {
go m.onReconnectedListenerFn() go m.onReconnectedListenerFn()
} }
// onServerDisconnected start to reconnection for home server only
func (m *Manager) onServerDisconnected(serverAddress string) { func (m *Manager) onServerDisconnected(serverAddress string) {
m.relayClientMu.Lock()
if serverAddress == m.relayClient.connectionURL { if serverAddress == m.relayClient.connectionURL {
go m.reconnectGuard.OnDisconnected() go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient)
} }
m.relayClientMu.Unlock()
m.notifyOnDisconnectListeners(serverAddress) m.notifyOnDisconnectListeners(serverAddress)
} }
func (m *Manager) listenGuardEvent(ctx context.Context) {
for {
select {
case rc := <-m.reconnectGuard.OnNewRelayClient:
m.storeClient(rc)
case <-ctx.Done():
return
}
}
}
func (m *Manager) storeClient(client *Client) {
m.relayClientMu.Lock()
defer m.relayClientMu.Unlock()
m.relayClient = client
m.relayClient.SetOnConnectedListener(m.onServerConnected)
m.relayClient.SetOnDisconnectListener(m.onServerDisconnected)
}
func (m *Manager) isForeignServer(address string) (bool, error) { func (m *Manager) isForeignServer(address string) (bool, error) {
rAddr, err := m.relayClient.ServerInstanceURL() rAddr, err := m.relayClient.ServerInstanceURL()
if err != nil { if err != nil {
@@ -282,22 +330,16 @@ func (m *Manager) isForeignServer(address string) (bool, error) {
} }
func (m *Manager) startCleanupLoop() { func (m *Manager) startCleanupLoop() {
if m.ctx.Err() != nil {
return
}
ticker := time.NewTicker(relayCleanupInterval) ticker := time.NewTicker(relayCleanupInterval)
go func() { defer ticker.Stop()
defer ticker.Stop() for {
for { select {
select { case <-m.ctx.Done():
case <-m.ctx.Done(): return
return case <-ticker.C:
case <-ticker.C: m.cleanUpUnusedRelays()
m.cleanUpUnusedRelays()
}
} }
}() }
} }
func (m *Manager) cleanUpUnusedRelays() { func (m *Manager) cleanUpUnusedRelays() {

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"sync/atomic"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -12,10 +13,13 @@ import (
) )
const ( const (
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7 maxConcurrentServers = 7
) )
var (
connectionTimeout = 30 * time.Second
)
type connResult struct { type connResult struct {
RelayClient *Client RelayClient *Client
Url string Url string
@@ -24,20 +28,22 @@ type connResult struct {
type ServerPicker struct { type ServerPicker struct {
TokenStore *auth.TokenStore TokenStore *auth.TokenStore
ServerURLs atomic.Value
PeerID string PeerID string
} }
func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) { func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
defer cancel() defer cancel()
totalServers := len(urls) totalServers := len(sp.ServerURLs.Load().([]string))
connResultChan := make(chan connResult, totalServers) connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1) successChan := make(chan connResult, 1)
concurrentLimiter := make(chan struct{}, maxConcurrentServers) concurrentLimiter := make(chan struct{}, maxConcurrentServers)
for _, url := range urls { log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string))
for _, url := range sp.ServerURLs.Load().([]string) {
// todo check if we have a successful connection so we do not need to connect to other servers // todo check if we have a successful connection so we do not need to connect to other servers
concurrentLimiter <- struct{}{} concurrentLimiter <- struct{}{}
go func(url string) { go func(url string) {
@@ -78,7 +84,7 @@ func (sp *ServerPicker) processConnResults(resultChan chan connResult, successCh
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ { for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
cr := <-resultChan cr := <-resultChan
if cr.Err != nil { if cr.Err != nil {
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
continue continue
} }
log.Infof("connected to Relay server: %s", cr.Url) log.Infof("connected to Relay server: %s", cr.Url)

View File

@@ -4,19 +4,23 @@ import (
"context" "context"
"errors" "errors"
"testing" "testing"
"time"
) )
func TestServerPicker_UnavailableServers(t *testing.T) { func TestServerPicker_UnavailableServers(t *testing.T) {
connectionTimeout = 5 * time.Second
sp := ServerPicker{ sp := ServerPicker{
TokenStore: nil, TokenStore: nil,
PeerID: "test", PeerID: "test",
} }
sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
defer cancel() defer cancel()
go func() { go func() {
_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"}) _, err := sp.PickServer(ctx)
if err == nil { if err == nil {
t.Error(err) t.Error(err)
} }

31
util/net/conn.go Normal file
View File

@@ -0,0 +1,31 @@
//go:build !ios
package net
import (
"net"
log "github.com/sirupsen/logrus"
)
// Conn wraps a net.Conn to override the Close method
type Conn struct {
net.Conn
ID ConnectionID
}
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
func (c *Conn) Close() error {
err := c.Conn.Close()
dialerCloseHooksMutex.RLock()
defer dialerCloseHooksMutex.RUnlock()
for _, hook := range dialerCloseHooks {
if err := hook(c.ID, &c.Conn); err != nil {
log.Errorf("Error executing dialer close hook: %v", err)
}
}
return err
}

58
util/net/dial.go Normal file
View File

@@ -0,0 +1,58 @@
//go:build !ios
package net
import (
"fmt"
"net"
log "github.com/sirupsen/logrus"
)
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
}
return tcpConn, nil
}

View File

@@ -1,25 +0,0 @@
package net
import (
"syscall"
log "github.com/sirupsen/logrus"
)
func (d *Dialer) init() {
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
err := c.Control(func(fd uintptr) {
androidProtectSocketLock.Lock()
f := androidProtectSocket
androidProtectSocketLock.Unlock()
if f == nil {
return
}
ok := f(int32(fd))
if !ok {
log.Errorf("failed to protect socket: %d", fd)
}
})
return err
}
}

View File

@@ -81,28 +81,6 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address) return d.DialContext(context.Background(), network, address)
} }
// Conn wraps a net.Conn to override the Close method
type Conn struct {
net.Conn
ID ConnectionID
}
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
func (c *Conn) Close() error {
err := c.Conn.Close()
dialerCloseHooksMutex.RLock()
defer dialerCloseHooksMutex.RUnlock()
for _, hook := range dialerCloseHooks {
if err := hook(c.ID, &c.Conn); err != nil {
log.Errorf("Error executing dialer close hook: %v", err)
}
}
return err
}
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
host, _, err := net.SplitHostPort(address) host, _, err := net.SplitHostPort(address)
if err != nil { if err != nil {
@@ -127,51 +105,3 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
return result.ErrorOrNil() return result.ErrorOrNil()
} }
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
}
return tcpConn, nil
}

View File

@@ -0,0 +1,5 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = ControlProtectSocket
}

View File

@@ -7,6 +7,6 @@ import "syscall"
// init configures the net.Dialer Control function to set the fwmark on the socket // init configures the net.Dialer Control function to set the fwmark on the socket
func (d *Dialer) init() { func (d *Dialer) init() {
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
return SetRawSocketMark(c) return setRawSocketMark(c)
} }
} }

View File

@@ -3,4 +3,5 @@
package net package net
func (d *Dialer) init() { func (d *Dialer) init() {
// implemented on Linux and Android only
} }

29
util/net/env.go Normal file
View File

@@ -0,0 +1,29 @@
package net
import (
"os"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
)
const (
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
)
func CustomRoutingDisabled() bool {
if netstack.IsEnabled() {
return true
}
return os.Getenv(envDisableCustomRouting) == "true"
}
func SkipSocketMark() bool {
if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" {
log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark)
return true
}
return false
}

37
util/net/listen.go Normal file
View File

@@ -0,0 +1,37 @@
//go:build !ios
package net
import (
"context"
"fmt"
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)
}
packetConn := conn.(*PacketConn)
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
if !ok {
if err := packetConn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
}
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
}

View File

@@ -1,26 +0,0 @@
package net
import (
"syscall"
log "github.com/sirupsen/logrus"
)
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() {
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
err := c.Control(func(fd uintptr) {
androidProtectSocketLock.Lock()
f := androidProtectSocket
androidProtectSocketLock.Unlock()
if f == nil {
return
}
ok := f(int32(fd))
if !ok {
log.Errorf("failed to protect listener socket: %d", fd)
}
})
return err
}
}

View File

@@ -0,0 +1,6 @@
package net
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() {
l.ListenConfig.Control = ControlProtectSocket
}

View File

@@ -9,6 +9,6 @@ import (
// init configures the net.ListenerConfig Control function to set the fwmark on the socket // init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() { func (l *ListenerConfig) init() {
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
return SetRawSocketMark(c) return setRawSocketMark(c)
} }
} }

View File

@@ -3,4 +3,5 @@
package net package net
func (l *ListenerConfig) init() { func (l *ListenerConfig) init() {
// implemented on Linux and Android only
} }

View File

@@ -8,7 +8,6 @@ import (
"net" "net"
"sync" "sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -146,27 +145,3 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
return err return err
} }
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)
}
packetConn := conn.(*PacketConn)
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
if !ok {
if err := packetConn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
}
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
}

View File

@@ -2,9 +2,6 @@ package net
import ( import (
"net" "net"
"os"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -16,8 +13,6 @@ const (
PreroutingFwmarkRedirected = 0x1BD01 PreroutingFwmarkRedirected = 0x1BD01
PreroutingFwmarkMasquerade = 0x1BD11 PreroutingFwmarkMasquerade = 0x1BD11
PreroutingFwmarkMasqueradeReturn = 0x1BD12 PreroutingFwmarkMasqueradeReturn = 0x1BD12
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
) )
// ConnectionID provides a globally unique identifier for network connections. // ConnectionID provides a globally unique identifier for network connections.
@@ -31,10 +26,3 @@ type RemoveHookFunc func(connID ConnectionID) error
func GenerateConnID() ConnectionID { func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString()) return ConnectionID(uuid.NewString())
} }
func CustomRoutingDisabled() bool {
if netstack.IsEnabled() {
return true
}
return os.Getenv(envDisableCustomRouting) == "true"
}

View File

@@ -4,29 +4,42 @@ package net
import ( import (
"fmt" "fmt"
"os"
"syscall" "syscall"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
// SetSocketMark sets the SO_MARK option on the given socket connection // SetSocketMark sets the SO_MARK option on the given socket connection
func SetSocketMark(conn syscall.Conn) error { func SetSocketMark(conn syscall.Conn) error {
if isSocketMarkDisabled() {
return nil
}
sysconn, err := conn.SyscallConn() sysconn, err := conn.SyscallConn()
if err != nil { if err != nil {
return fmt.Errorf("get raw conn: %w", err) return fmt.Errorf("get raw conn: %w", err)
} }
return SetRawSocketMark(sysconn) return setRawSocketMark(sysconn)
} }
func SetRawSocketMark(conn syscall.RawConn) error { // SetSocketOpt sets the SO_MARK option on the given file descriptor
func SetSocketOpt(fd int) error {
if isSocketMarkDisabled() {
return nil
}
return setSocketOptInt(fd)
}
func setRawSocketMark(conn syscall.RawConn) error {
var setErr error var setErr error
err := conn.Control(func(fd uintptr) { err := conn.Control(func(fd uintptr) {
setErr = SetSocketOpt(int(fd)) if isSocketMarkDisabled() {
return
}
setErr = setSocketOptInt(int(fd))
}) })
if err != nil { if err != nil {
return fmt.Errorf("control: %w", err) return fmt.Errorf("control: %w", err)
@@ -39,17 +52,18 @@ func SetRawSocketMark(conn syscall.RawConn) error {
return nil return nil
} }
func SetSocketOpt(fd int) error { func setSocketOptInt(fd int) error {
if CustomRoutingDisabled() {
log.Infof("Custom routing is disabled, skipping SO_MARK")
return nil
}
// Check for the new environment variable
if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
return nil
}
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
} }
func isSocketMarkDisabled() bool {
if CustomRoutingDisabled() {
log.Infof("Custom routing is disabled, skipping SO_MARK")
return true
}
if SkipSocketMark() {
return true
}
return false
}

View File

@@ -1,14 +1,42 @@
package net package net
import "sync" import (
"fmt"
"sync"
"syscall"
)
var ( var (
androidProtectSocketLock sync.Mutex androidProtectSocketLock sync.Mutex
androidProtectSocket func(fd int32) bool androidProtectSocket func(fd int32) bool
) )
func SetAndroidProtectSocketFn(f func(fd int32) bool) { func SetAndroidProtectSocketFn(fn func(fd int32) bool) {
androidProtectSocketLock.Lock() androidProtectSocketLock.Lock()
androidProtectSocket = f androidProtectSocket = fn
androidProtectSocketLock.Unlock() androidProtectSocketLock.Unlock()
} }
// ControlProtectSocket is a Control function that sets the fwmark on the socket
func ControlProtectSocket(_, _ string, c syscall.RawConn) error {
var aErr error
err := c.Control(func(fd uintptr) {
androidProtectSocketLock.Lock()
defer androidProtectSocketLock.Unlock()
if androidProtectSocket == nil {
aErr = fmt.Errorf("socket protection function not set")
return
}
if !androidProtectSocket(int32(fd)) {
aErr = fmt.Errorf("failed to protect socket via Android")
}
})
if err != nil {
return err
}
return aErr
}