Compare commits

...

50 Commits

Author SHA1 Message Date
bcmmbaga
8c44187900 Merge branch 'refs/heads/test/add-stopwatch' into deploy/peer-performance 2024-05-06 14:29:22 +03:00
Pascal Fischer
0dc050b354 update sqlite operations 2024-05-06 13:25:46 +02:00
Pascal Fischer
69eb6f17c6 remove stacktrace 2024-05-03 17:30:26 +02:00
bcmmbaga
ccbe2e0100 Merge branch 'refs/heads/migrate-with-json-serializer' into deploy/peer-performance 2024-05-03 14:11:23 +03:00
bcmmbaga
5b042eee74 migrate null blob values 2024-05-03 14:07:54 +03:00
Pascal Fischer
125c418d85 add debug to GetAccount sqlite store 2024-04-30 13:52:19 +02:00
Pascal Fischer
9f607029c0 add debug to GetAccount sqlite store 2024-04-30 13:35:04 +02:00
bcmmbaga
4df1d556c8 fix tests 2024-04-30 13:35:49 +03:00
bcmmbaga
acaf315152 Merge branch 'refs/heads/migrate-with-json-serializer' into deploy/peer-performance 2024-04-30 13:32:48 +03:00
bcmmbaga
81b6f7dc56 Merge remote-tracking branch 'refs/remotes/origin/test/add-stopwatch' into deploy/peer-performance 2024-04-30 13:32:37 +03:00
bcmmbaga
cb9a8e9fa4 Add tests 2024-04-30 13:27:14 +03:00
Pascal Fischer
f68d9ce042 fix linter 2024-04-30 12:03:59 +02:00
Pascal Fischer
1f182411c6 update store 2024-04-30 11:43:48 +02:00
Pascal Fischer
68e971cb82 Merge remote-tracking branch 'origin/test/add-stopwatch' into test/add-stopwatch 2024-04-30 11:29:08 +02:00
Pascal Fischer
f4a464fb69 update account manager mock 2024-04-30 11:28:57 +02:00
bcmmbaga
7aff0e828b Refactor 2024-04-30 10:07:40 +03:00
bcmmbaga
90ebf0017b remove duplicate index 2024-04-29 23:21:00 +03:00
bcmmbaga
bf8c6b95de run net ip migration 2024-04-29 23:20:43 +03:00
bcmmbaga
8415064758 migrate net ip field from blob to json 2024-04-29 23:20:22 +03:00
Pascal Fischer
809f6cdbc7 add stacktrace for get network map 2024-04-29 18:56:58 +02:00
Viktor Liu
e435e39158 Fix route selection IDs (#1890) 2024-04-29 18:43:14 +02:00
Maycon Santos
fd26e989e3 Check if channel exist before sending network map (#1894)
Check for connection channel before calculating and sending the network map
2024-04-29 18:31:52 +02:00
Pascal Fischer
7474876d6a add stacktrace for get network map 2024-04-29 18:12:22 +02:00
Pascal Fischer
2095e35217 remove stopwatches + logs 2024-04-29 16:26:44 +02:00
Pascal Fischer
02537e5536 remove stopwatches 2024-04-29 14:57:57 +02:00
bcmmbaga
f99477d3db serialize net.IP as json 2024-04-29 15:45:46 +03:00
Pascal Fischer
3085383729 introduce RWLocks 2024-04-29 14:03:33 +02:00
Pascal Fischer
07bdf977d0 update db access 2024-04-26 18:05:41 +02:00
Pascal Fischer
4a147006d2 add table 2024-04-26 17:45:10 +02:00
Viktor Liu
4424162bce Add client debug features (#1884)
* Add status anonymization
* Add OS/arch to the status command
* Use human-friendly last-update status messages
* Add debug bundle command to collect (anonymized) logs
* Add debug log level command
* And debug for a certain time span command
2024-04-26 17:20:10 +02:00
Pascal Fischer
480192028e try improving Sync method 2024-04-26 17:18:24 +02:00
Viktor Liu
54b045d9ca Replaces powershell with the route command and cache route lookups on windows (#1880) 2024-04-26 16:37:27 +02:00
Bethuel Mmbaga
71c6437bab add content type before writing header (#1887) 2024-04-25 21:20:24 +02:00
Pascal Fischer
027c0e5c63 update 2024-04-25 15:47:32 +02:00
Pascal Fischer
fbb1181b64 update 2024-04-25 15:44:06 +02:00
Pascal Fischer
e50ed290d9 update 2024-04-25 15:17:57 +02:00
Pascal Fischer
8f8cfcbc20 add method timing logs 2024-04-25 14:15:36 +02:00
pascal-fischer
7b254cb966 add methods to manage rosenpass settings for iOS (#1879) 2024-04-23 19:26:03 +02:00
pascal-fischer
8f3a0f2c38 Add retry to IdP cache lookup (#1882) 2024-04-23 19:23:43 +02:00
pascal-fischer
1f33e2e003 Support exit nodes on iOS (#1878) 2024-04-23 19:12:16 +02:00
pascal-fischer
1e6addaa65 Add account locks to getAccountWithAuthorizationClaims method (#1847) 2024-04-23 19:09:58 +02:00
Viktor Liu
f51dc13f8c Add route selection functionality for CLI and GUI (#1865) 2024-04-23 14:42:53 +02:00
dependabot[bot]
3477108ce7 Bump golang.org/x/net from 0.20.0 to 0.23.0 (#1867)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.20.0 to 0.23.0.
- [Commits](https://github.com/golang/net/compare/v0.20.0...v0.23.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-23 12:48:25 +02:00
Maycon Santos
012e624296 Fix DNS not found query response (#1877)
for local queries, we should return NXDOMAIN instead of NOERROR

Also, updated gomobile for Android and iOS builds
2024-04-23 10:20:09 +02:00
Maycon Santos
4c5e987e02 Add support for GUI app to display error (#1844) 2024-04-22 11:57:38 +02:00
Maycon Santos
a80c8b0176 Redeem invite only when incoming user was invited (#1861)
checks for users with pending invite status in the cache that already logged in and refresh the cache
2024-04-22 11:10:27 +02:00
Misha Bragin
9e01155d2e Add new intro image 2024-04-22 11:00:52 +02:00
Maycon Santos
3c3111ad01 Copy client binary to a directory in path (#1842) 2024-04-22 10:14:07 +02:00
Misha Bragin
b74078fd95 Use a better way to insert data in batches (#1874) 2024-04-20 22:04:20 +02:00
Viktor Liu
77488ad11a Migrate serializer:gob fields to serializer:json (#1855) 2024-04-18 18:14:21 +02:00
75 changed files with 4499 additions and 612 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta ignore_words_list: erro,clienta,hastable,
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1 only_warn: 1
golangci: golangci:

View File

@@ -38,7 +38,7 @@ jobs:
- name: Setup NDK - name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620" run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile - name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init - name: gomobile init
run: gomobile init run: gomobile init
- name: build android netbird lib - name: build android netbird lib
@@ -56,10 +56,10 @@ jobs:
with: with:
go-version: "1.21.x" go-version: "1.21.x"
- name: install gomobile - name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init - name: gomobile init
run: gomobile init run: gomobile init
- name: build iOS netbird lib - name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
env: env:
CGO_ENABLED: 0 CGO_ENABLED: 0

View File

@@ -44,7 +44,8 @@
### Open-Source Network Security in a Single Platform ### Open-Source Network Security in a Single Platform
![image](https://github.com/netbirdio/netbird/assets/700848/c0d7bae4-3301-499a-bb4e-5e4a225bf35f)
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### Key features ### Key features

View File

@@ -1,5 +1,5 @@
FROM alpine:3.18.5 FROM alpine:3.18.5
RUN apk add --no-cache ca-certificates iptables ip6tables RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/go/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /go/bin/netbird COPY netbird /usr/local/bin/netbird

View File

@@ -0,0 +1,212 @@
package anonymize
import (
"crypto/rand"
"fmt"
"math/big"
"net"
"net/netip"
"net/url"
"regexp"
"slices"
"strings"
)
type Anonymizer struct {
ipAnonymizer map[netip.Addr]netip.Addr
domainAnonymizer map[string]string
currentAnonIPv4 netip.Addr
currentAnonIPv6 netip.Addr
startAnonIPv4 netip.Addr
startAnonIPv6 netip.Addr
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
// 192.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
}
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
return &Anonymizer{
ipAnonymizer: map[netip.Addr]netip.Addr{},
domainAnonymizer: map[string]string{},
currentAnonIPv4: startIPv4,
currentAnonIPv6: startIPv6,
startAnonIPv4: startIPv4,
startAnonIPv6: startIPv6,
}
}
func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
if ip.IsLoopback() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsInterfaceLocalMulticast() ||
ip.IsPrivate() ||
ip.IsUnspecified() ||
ip.IsMulticast() ||
isWellKnown(ip) ||
a.isInAnonymizedRange(ip) {
return ip
}
if _, ok := a.ipAnonymizer[ip]; !ok {
if ip.Is4() {
a.ipAnonymizer[ip] = a.currentAnonIPv4
a.currentAnonIPv4 = a.currentAnonIPv4.Next()
} else {
a.ipAnonymizer[ip] = a.currentAnonIPv6
a.currentAnonIPv6 = a.currentAnonIPv6.Next()
}
}
return a.ipAnonymizer[ip]
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
return true
} else if !ip.Is4() && ip.Compare(a.startAnonIPv6) >= 0 && ip.Compare(a.currentAnonIPv6) <= 0 {
return true
}
return false
}
func (a *Anonymizer) AnonymizeIPString(ip string) string {
addr, err := netip.ParseAddr(ip)
if err != nil {
return ip
}
return a.AnonymizeIP(addr).String()
}
func (a *Anonymizer) AnonymizeDomain(domain string) string {
if strings.HasSuffix(domain, "netbird.io") ||
strings.HasSuffix(domain, "netbird.selfhosted") ||
strings.HasSuffix(domain, "netbird.cloud") ||
strings.HasSuffix(domain, "netbird.stage") ||
strings.HasSuffix(domain, ".domain") {
return domain
}
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
anonymized, ok := a.domainAnonymizer[baseDomain]
if !ok {
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
a.domainAnonymizer[baseDomain] = anonymizedBase
anonymized = anonymizedBase
}
return strings.Replace(domain, baseDomain, anonymized, 1)
}
func (a *Anonymizer) AnonymizeURI(uri string) string {
u, err := url.Parse(uri)
if err != nil {
return uri
}
var anonymizedHost string
if u.Opaque != "" {
host, port, err := net.SplitHostPort(u.Opaque)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Opaque)
}
u.Opaque = anonymizedHost
} else if u.Host != "" {
host, port, err := net.SplitHostPort(u.Host)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Host)
}
u.Host = anonymizedHost
}
return u.String()
}
func (a *Anonymizer) AnonymizeString(str string) string {
ipv4Regex := regexp.MustCompile(`\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b`)
ipv6Regex := regexp.MustCompile(`\b([0-9a-fA-F:]+:+[0-9a-fA-F]{0,4})(?:%[0-9a-zA-Z]+)?(?:\/[0-9]{1,3})?(?::[0-9]{1,5})?\b`)
str = ipv4Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
str = ipv6Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
for domain, anonDomain := range a.domainAnonymizer {
str = strings.ReplaceAll(str, domain, anonDomain)
}
str = a.AnonymizeSchemeURI(str)
str = a.AnonymizeDNSLogLine(str)
return str
}
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
}
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
domainPattern := `dns\.Question{Name:"([^"]+)",`
domainRegex := regexp.MustCompile(domainPattern)
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.Split(match, `"`)
if len(parts) >= 2 {
domain := parts[1]
if strings.HasSuffix(domain, ".domain") {
return match
}
randomDomain := generateRandomString(10) + ".domain"
return strings.Replace(match, domain, randomDomain, 1)
}
return match
})
}
func isWellKnown(addr netip.Addr) bool {
wellKnown := []string{
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
"2001:4860:4860::8888", "2001:4860:4860::8844", // Google DNS IPv6
"1.1.1.1", "1.0.0.1", // Cloudflare DNS IPv4
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
}
if slices.Contains(wellKnown, addr.String()) {
return true
}
cgnatRangeStart := netip.AddrFrom4([4]byte{100, 64, 0, 0})
cgnatRange := netip.PrefixFrom(cgnatRangeStart, 10)
return cgnatRange.Contains(addr)
}
func generateRandomString(length int) string {
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := range result {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
continue
}
result[i] = letters[num.Int64()]
}
return string(result)
}

View File

@@ -0,0 +1,223 @@
package anonymize_test
import (
"net/netip"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
)
func TestAnonymizeIP(t *testing.T) {
startIPv4 := netip.MustParseAddr("198.51.100.0")
startIPv6 := netip.MustParseAddr("100::")
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
tests := []struct {
name string
ip string
expect string
}{
{"Well known", "8.8.8.8", "8.8.8.8"},
{"First Public IPv4", "1.2.3.4", "198.51.100.0"},
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Second Public IPv6", "a::b", "100::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Private IPv6", "fe80::1", "fe80::1"},
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ip := netip.MustParseAddr(tc.ip)
anonymizedIP := anonymizer.AnonymizeIP(ip)
if anonymizedIP.String() != tc.expect {
t.Errorf("%s: expected %s, got %s", tc.name, tc.expect, anonymizedIP)
}
})
}
}
func TestAnonymizeDNSLogLine(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
result := anonymizer.AnonymizeDNSLogLine(testLog)
require.NotEqual(t, testLog, result)
assert.NotContains(t, result, "example.com")
}
func TestAnonymizeDomain(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
domain string
expectPattern string
shouldAnonymize bool
}{
{
"General Domain",
"example.com",
`^anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Subdomain",
"sub.example.com",
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Protected Domain",
"netbird.io",
`^netbird\.io$`,
false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeDomain(tc.domain)
if tc.shouldAnonymize {
assert.Regexp(t, tc.expectPattern, result, "The anonymized domain should match the expected pattern")
assert.NotContains(t, result, tc.domain, "The original domain should not be present in the result")
} else {
assert.Equal(t, tc.domain, result, "Protected domains should not be anonymized")
}
})
}
}
func TestAnonymizeURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
uri string
regex string
}{
{
"HTTP URI with Port",
"http://example.com:80/path",
`^http://anon-[a-zA-Z0-9]+\.domain:80/path$`,
},
{
"HTTP URI without Port",
"http://example.com/path",
`^http://anon-[a-zA-Z0-9]+\.domain/path$`,
},
{
"Opaque URI with Port",
"stun:example.com:80?transport=udp",
`^stun:anon-[a-zA-Z0-9]+\.domain:80\?transport=udp$`,
},
{
"Opaque URI without Port",
"stun:example.com?transport=udp",
`^stun:anon-[a-zA-Z0-9]+\.domain\?transport=udp$`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeURI(tc.uri)
assert.Regexp(t, regexp.MustCompile(tc.regex), result, "URI should match expected pattern")
require.NotContains(t, result, "example.com", "Original domain should not be present")
})
}
}
func TestAnonymizeSchemeURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
input string
expect string
}{
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeSchemeURI(tc.input)
assert.Regexp(t, tc.expect, result, "The anonymized output should match expected pattern")
require.NotContains(t, result, "example.com", "Original domain should not be present")
})
}
}
func TestAnonymizString_MemorizedDomain(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
domain := "example.com"
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
sampleString := "This is a test string including the domain example.com which should be anonymized."
firstPassResult := anonymizer.AnonymizeString(sampleString)
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
assert.Contains(t, firstPassResult, anonymizedDomain, "The domain should be anonymized in the first pass")
assert.NotContains(t, firstPassResult, domain, "The original domain should not appear in the first pass output")
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the string")
}
func TestAnonymizeString_DoubleURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
domain := "example.com"
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
sampleString := "Check out our site at https://example.com for more info."
firstPassResult := anonymizer.AnonymizeString(sampleString)
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
assert.Contains(t, firstPassResult, "https://"+anonymizedDomain, "The URI should be anonymized in the first pass")
assert.NotContains(t, firstPassResult, "https://example.com", "The original URI should not appear in the first pass output")
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the URI")
}
func TestAnonymizeString_IPAddresses(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
tests := []struct {
name string
input string
expect string
}{
{
name: "IPv4 Address",
input: "Error occurred at IP 122.138.1.1",
expect: "Error occurred at IP 198.51.100.0",
},
{
name: "IPv6 Address",
input: "Access attempted from 2001:db8::ff00:42",
expect: "Access attempted from 100::",
},
{
name: "IPv6 Address with Port",
input: "Access attempted from [2001:db8::ff00:42]:8080",
expect: "Access attempted from [100::]:8080",
},
{
name: "Both IPv4 and IPv6",
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeString(tc.input)
assert.Equal(t, tc.expect, result, "IP addresses should be anonymized correctly")
})
}
}

248
client/cmd/debug.go Normal file
View File

@@ -0,0 +1,248 @@
package cmd
import (
"context"
"fmt"
"strings"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
Long: "Provides commands for debugging and logging control within the Netbird daemon.",
}
var debugBundleCmd = &cobra.Command{
Use: "bundle",
Example: " netbird debug bundle",
Short: "Create a debug bundle",
Long: "Generates a compressed archive of the daemon's logs and status for debugging purposes.",
RunE: debugBundle,
}
var logCmd = &cobra.Command{
Use: "log",
Short: "Manage logging for the Netbird daemon",
Long: `Commands to manage logging settings for the Netbird daemon, including ICE, gRPC, and general log levels.`,
}
var logLevelCmd = &cobra.Command{
Use: "level <level>",
Short: "Set the logging level for this session",
Long: `Sets the logging level for the current session. This setting is temporary and will revert to the default on daemon restart.
Available log levels are:
panic: for panic level, highest level of severity
fatal: for fatal level errors that cause the program to exit
error: for error conditions
warn: for warning conditions
info: for informational messages
debug: for debug-level messages
trace: for trace-level messages, which include more fine-grained information than debug`,
Args: cobra.ExactArgs(1),
RunE: setLogLevel,
}
var forCmd = &cobra.Command{
Use: "for <time>",
Short: "Run debug logs for a specified duration and create a debug bundle",
Long: `Sets the logging level to trace, runs for the specified duration, and then generates a debug bundle.`,
Example: " netbird debug for 5m",
Args: cobra.ExactArgs(1),
RunE: runForDuration,
}
func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd),
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil
}
func setLogLevel(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
level := parseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN {
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
}
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: level,
})
if err != nil {
return fmt.Errorf("failed to set log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level set successfully to", args[0])
return nil
}
func parseLogLevel(level string) proto.LogLevel {
switch strings.ToLower(level) {
case "panic":
return proto.LogLevel_PANIC
case "fatal":
return proto.LogLevel_FATAL
case "error":
return proto.LogLevel_ERROR
case "warn":
return proto.LogLevel_WARN
case "info":
return proto.LogLevel_INFO
case "debug":
return proto.LogLevel_DEBUG
case "trace":
return proto.LogLevel_TRACE
default:
return proto.LogLevel_UNKNOWN
}
}
func runForDuration(cmd *cobra.Command, args []string) error {
duration, err := time.ParseDuration(args[0])
if err != nil {
return fmt.Errorf("invalid duration format: %v", err)
}
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: proto.LogLevel_TRACE,
})
if err != nil {
return fmt.Errorf("failed to set log level to trace: %v", status.Convert(err).Message())
}
cmd.Println("Log level set to trace.")
time.Sleep(1 * time.Second)
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
cmd.Println("\nDuration completed")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
// TODO reset log level
time.Sleep(1 * time.Second)
cmd.Println("Creating debug bundle...")
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil
}
func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
startTime := time.Now()
done := make(chan struct{})
go func() {
defer close(done)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
elapsed := time.Since(startTime)
if elapsed >= duration {
return
}
remaining := duration - elapsed
cmd.Printf("\rRemaining time: %s", formatDuration(remaining))
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}
func formatDuration(d time.Duration) string {
d = d.Round(time.Second)
h := d / time.Hour
d %= time.Hour
m := d / time.Minute
d %= time.Minute
s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}

View File

@@ -65,6 +65,7 @@ var (
serviceName string serviceName string
autoConnectDisabled bool autoConnectDisabled bool
extraIFaceBlackList []string extraIFaceBlackList []string
anonymizeFlag bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
Short: "", Short: "",
@@ -119,6 +120,8 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.AddCommand(serviceCmd) rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd) rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd) rootCmd.AddCommand(downCmd)
@@ -126,8 +129,20 @@ func init() {
rootCmd.AddCommand(loginCmd) rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd) rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(routesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
routesCmd.AddCommand(routesListCmd)
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil, upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+ `Sets external IPs maps between local addresses and interfaces.`+
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+ `You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
@@ -335,3 +350,14 @@ func migrateToNetbird(oldPath, newPath string) bool {
return true return true
} }
func getClient(ctx context.Context) (*grpc.ClientConn, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
return conn, nil
}

131
client/cmd/route.go Normal file
View File

@@ -0,0 +1,131 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var routesCmd = &cobra.Command{
Use: "routes",
Short: "Manage network routes",
Long: `Commands to list, select, or deselect network routes.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List routes",
Example: " netbird routes list",
Long: "List all available network routes.",
RunE: routesList,
}
var routesSelectCmd = &cobra.Command{
Use: "select route...|all",
Short: "Select routes",
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: routesSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect route...|all",
Short: "Deselect routes",
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: routesDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
}
func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
if err != nil {
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No routes available.")
return nil
}
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
selectedStatus := "Not Selected"
if route.GetSelected() {
selectedStatus = "Selected"
}
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
return nil
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes selected successfully.")
return nil
}
func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes deselected successfully.")
return nil
}

View File

@@ -24,7 +24,7 @@ var (
) )
var sshCmd = &cobra.Command{ var sshCmd = &cobra.Command{
Use: "ssh", Use: "ssh [user@]host",
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 { if len(args) < 1 {
return errors.New("requires a host argument") return errors.New("requires a host argument")
@@ -94,7 +94,7 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
if err != nil { if err != nil {
cmd.Printf("Error: %v\n", err) cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" + cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"You can verify the connection by running:\n\n" + "\nYou can verify the connection by running:\n\n" +
" netbird status\n\n") " netbird status\n\n")
return err return err
} }

View File

@@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"runtime"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -14,6 +16,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
@@ -144,9 +147,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed initializing log %v", err) return fmt.Errorf("failed initializing log %v", err)
} }
ctx := internal.CtxInitState(context.Background()) ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx, cmd) resp, err := getStatus(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -191,7 +194,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse, error) { func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -200,7 +203,7 @@ func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse,
} }
defer conn.Close() defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true}) resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil { if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
} }
@@ -283,6 +286,11 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
} }
if anonymizeFlag {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
anonymizeOverview(anonymizer, &overview)
}
return overview return overview
} }
@@ -525,8 +533,16 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
goarm := ""
if goarch == "arm" {
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
summary := fmt.Sprintf( summary := fmt.Sprintf(
"Daemon version: %s\n"+ "OS: %s\n"+
"Daemon version: %s\n"+
"CLI version: %s\n"+ "CLI version: %s\n"+
"Management: %s\n"+ "Management: %s\n"+
"Signal: %s\n"+ "Signal: %s\n"+
@@ -538,6 +554,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
"Quantum resistance: %s\n"+ "Quantum resistance: %s\n"+
"Routes: %s\n"+ "Routes: %s\n"+
"Peers count: %s\n", "Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion, overview.DaemonVersion,
version.NetbirdVersion(), version.NetbirdVersion(),
managementConnString, managementConnString,
@@ -593,15 +610,6 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
if peerState.IceCandidateEndpoint.Remote != "" { if peerState.IceCandidateEndpoint.Remote != "" {
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
} }
lastStatusUpdate := "-"
if !peerState.LastStatusUpdate.IsZero() {
lastStatusUpdate = peerState.LastStatusUpdate.Format("2006-01-02 15:04:05")
}
lastWireGuardHandshake := "-"
if !peerState.LastWireguardHandshake.IsZero() && peerState.LastWireguardHandshake != time.Unix(0, 0) {
lastWireGuardHandshake = peerState.LastWireguardHandshake.Format("2006-01-02 15:04:05")
}
rosenpassEnabledStatus := "false" rosenpassEnabledStatus := "false"
if rosenpassEnabled { if rosenpassEnabled {
@@ -652,8 +660,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
remoteICE, remoteICE,
localICEEndpoint, localICEEndpoint,
remoteICEEndpoint, remoteICEEndpoint,
lastStatusUpdate, timeAgo(peerState.LastStatusUpdate),
lastWireGuardHandshake, timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived), toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent), toIEC(peerState.TransferSent),
rosenpassEnabledStatus, rosenpassEnabledStatus,
@@ -722,3 +730,129 @@ func countEnabled(dnsServers []nsServerGroupStateOutput) int {
} }
return count return count
} }
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
func timeAgo(t time.Time) string {
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
return "-"
}
duration := time.Since(t)
switch {
case duration < time.Second:
return "Now"
case duration < time.Minute:
seconds := int(duration.Seconds())
if seconds == 1 {
return "1 second ago"
}
return fmt.Sprintf("%d seconds ago", seconds)
case duration < time.Hour:
minutes := int(duration.Minutes())
seconds := int(duration.Seconds()) % 60
if minutes == 1 {
if seconds == 1 {
return "1 minute, 1 second ago"
} else if seconds > 0 {
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
}
return "1 minute ago"
}
if seconds > 0 {
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
}
return fmt.Sprintf("%d minutes ago", minutes)
case duration < 24*time.Hour:
hours := int(duration.Hours())
minutes := int(duration.Minutes()) % 60
if hours == 1 {
if minutes == 1 {
return "1 hour, 1 minute ago"
} else if minutes > 0 {
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
}
return "1 hour ago"
}
if minutes > 0 {
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
}
return fmt.Sprintf("%d hours ago", hours)
}
days := int(duration.Hours()) / 24
hours := int(duration.Hours()) % 24
if days == 1 {
if hours == 1 {
return "1 day, 1 hour ago"
} else if hours > 0 {
return fmt.Sprintf("1 day, %d hours ago", hours)
}
return "1 day ago"
}
if hours > 0 {
return fmt.Sprintf("%d days, %d hours ago", days, hours)
}
return fmt.Sprintf("%d days ago", days)
}
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
}
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Routes {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
}
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
anonymizePeerDetail(a, &peer)
overview.Peers.Details[i] = peer
}
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
overview.IP = a.AnonymizeIPString(overview.IP)
for i, detail := range overview.Relays.Details {
detail.URI = a.AnonymizeURI(detail.URI)
detail.Error = a.AnonymizeString(detail.Error)
overview.Relays.Details[i] = detail
}
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
}
for j, ns := range nsGroup.Servers {
host, port, err := net.SplitHostPort(ns)
if err == nil {
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
}
}
for i, route := range overview.Routes {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}

View File

@@ -3,6 +3,8 @@ package cmd
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"runtime"
"testing" "testing"
"time" "time"
@@ -487,9 +489,15 @@ dnsServers:
} }
func TestParsingToDetail(t *testing.T) { func TestParsingToDetail(t *testing.T) {
// Calculate time ago based on the fixture dates
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := parseToFullDetailSummary(overview) detail := parseToFullDetailSummary(overview)
expectedDetail := expectedDetail := fmt.Sprintf(
`Peers detail: `Peers detail:
peer-1.awesome-domain.com: peer-1.awesome-domain.com:
NetBird IP: 192.168.178.101 NetBird IP: 192.168.178.101
@@ -500,8 +508,8 @@ func TestParsingToDetail(t *testing.T) {
Direct: true Direct: true
ICE candidate (Local/Remote): -/- ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/- ICE candidate endpoints (Local/Remote): -/-
Last connection update: 2001-01-01 01:01:01 Last connection update: %s
Last WireGuard handshake: 2001-01-01 01:01:02 Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B Transfer status (received/sent) 200 B/100 B
Quantum resistance: false Quantum resistance: false
Routes: 10.1.0.0/24 Routes: 10.1.0.0/24
@@ -516,15 +524,16 @@ func TestParsingToDetail(t *testing.T) {
Direct: false Direct: false
ICE candidate (Local/Remote): relay/prflx ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002 ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Last connection update: 2002-02-02 02:02:02 Last connection update: %s
Last WireGuard handshake: 2002-02-02 02:02:03 Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false Quantum resistance: false
Routes: - Routes: -
Latency: 10ms Latency: 10ms
OS: %s/%s
Daemon version: 0.14.1 Daemon version: 0.14.1
CLI version: development CLI version: %s
Management: Connected to my-awesome-management.com:443 Management: Connected to my-awesome-management.com:443
Signal: Connected to my-awesome-signal.com:443 Signal: Connected to my-awesome-signal.com:443
Relays: Relays:
@@ -539,7 +548,7 @@ Interface type: Kernel
Quantum resistance: false Quantum resistance: false
Routes: 10.10.0.0/24 Routes: 10.10.0.0/24
Peers count: 2/2 Connected Peers count: 2/2 Connected
` `, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
assert.Equal(t, expectedDetail, detail) assert.Equal(t, expectedDetail, detail)
} }
@@ -547,8 +556,8 @@ Peers count: 2/2 Connected
func TestParsingToShortVersion(t *testing.T) { func TestParsingToShortVersion(t *testing.T) {
shortVersion := parseGeneralSummary(overview, false, false, false) shortVersion := parseGeneralSummary(overview, false, false, false)
expectedString := expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
`Daemon version: 0.14.1 Daemon version: 0.14.1
CLI version: development CLI version: development
Management: Connected Management: Connected
Signal: Connected Signal: Connected
@@ -572,3 +581,31 @@ func TestParsingOfIP(t *testing.T) {
assert.Equal(t, "192.168.178.123\n", parsedIP) assert.Equal(t, "192.168.178.123\n", parsedIP)
} }
func TestTimeAgo(t *testing.T) {
now := time.Now()
cases := []struct {
name string
input time.Time
expected string
}{
{"Now", now, "Now"},
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
{"Zero time", time.Time{}, "-"},
{"Unix zero time", time.Unix(0, 0), "-"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := timeAgo(tc.input)
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
})
}
}

View File

@@ -31,7 +31,7 @@ import (
// RunClient with main logic. // RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error { func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil) return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil, nil)
} }
// RunClientWithProbes runs the client's main logic with probes attached // RunClientWithProbes runs the client's main logic with probes attached
@@ -43,8 +43,9 @@ func RunClientWithProbes(
signalProbe *Probe, signalProbe *Probe,
relayProbe *Probe, relayProbe *Probe,
wgProbe *Probe, wgProbe *Probe,
engineChan chan<- *Engine,
) error { ) error {
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe) return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
} }
// RunClientMobile with main logic on mobile system // RunClientMobile with main logic on mobile system
@@ -66,7 +67,7 @@ func RunClientMobile(
HostDNSAddresses: dnsAddresses, HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener, DnsReadyListener: dnsReadyListener,
} }
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil) return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
} }
func RunClientiOS( func RunClientiOS(
@@ -82,7 +83,7 @@ func RunClientiOS(
NetworkChangeListener: networkChangeListener, NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager, DnsManager: dnsManager,
} }
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil) return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
} }
func runClient( func runClient(
@@ -94,6 +95,7 @@ func runClient(
signalProbe *Probe, signalProbe *Probe,
relayProbe *Probe, relayProbe *Probe,
wgProbe *Probe, wgProbe *Probe,
engineChan chan<- *Engine,
) error { ) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -243,6 +245,9 @@ func runClient(
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err) return wrapErr(err)
} }
if engineChan != nil {
engineChan <- engine
}
log.Print("Netbird engine started, my IP is: ", peerConfig.Address) log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected) state.Set(StatusConnected)
@@ -252,6 +257,10 @@ func runClient(
backOff.Reset() backOff.Reset()
if engineChan != nil {
engineChan <- nil
}
err = engine.Stop() err = engine.Stop()
if err != nil { if err != nil {
log.Errorf("failed stopping engine %v", err) log.Errorf("failed stopping engine %v", err)

View File

@@ -31,6 +31,8 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
response := d.lookupRecord(r) response := d.lookupRecord(r)
if response != nil { if response != nil {
replyMessage.Answer = append(replyMessage.Answer, response) replyMessage.Answer = append(replyMessage.Answer, response)
} else {
replyMessage.Rcode = dns.RcodeNameError
} }
err := w.WriteMsg(replyMessage) err := w.WriteMsg(replyMessage)

View File

@@ -111,6 +111,9 @@ type Engine struct {
// TURNs is a list of STUN servers used by ICE // TURNs is a list of STUN servers used by ICE
TURNs []*stun.URI TURNs []*stun.URI
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes map[string][]*route.Route
cancel context.CancelFunc cancel context.CancelFunc
ctx context.Context ctx context.Context
@@ -216,6 +219,8 @@ func (e *Engine) Stop() error {
return err return err
} }
e.clientRoutes = nil
// very ugly but we want to remove peers from the WireGuard interface first before removing interface. // very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.CLose() asynchronously // Removing peers happens in the conn.CLose() asynchronously
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
@@ -695,11 +700,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
if protoRoutes == nil { if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{} protoRoutes = []*mgmProto.Route{}
} }
err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil { if err != nil {
log.Errorf("failed to update routes, err: %v", err) log.Errorf("failed to update clientRoutes, err: %v", err)
} }
e.clientRoutes = clientRoutes
protoDNSConfig := networkMap.GetDNSConfig() protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil { if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{} protoDNSConfig = &mgmProto.DNSConfig{}
@@ -1229,6 +1237,28 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
} }
} }
// GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() map[string][]*route.Route {
return e.clientRoutes
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route {
routes := make(map[string][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes {
if i := strings.LastIndex(id, "-"); i != -1 {
id = id[:i]
}
routes[id] = v
}
return routes
}
// GetRouteManager returns the route manager
func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) { func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName) iface, err := net.InterfaceByName(ifaceName)
if err != nil { if err != nil {

View File

@@ -22,6 +22,7 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
@@ -577,10 +578,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}{} }{}
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
input.inputSerial = updateSerial input.inputSerial = updateSerial
input.inputRoutes = newRoutes input.inputRoutes = newRoutes
return testCase.inputErr return nil, nil, testCase.inputErr
}, },
} }
@@ -597,8 +598,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
err = engine.updateNetworkMap(testCase.networkMap) err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match") assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match") assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match") assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match")
}) })
} }
} }
@@ -742,8 +743,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
return nil return nil, nil, nil
}, },
} }

View File

@@ -0,0 +1 @@
package internal

View File

@@ -3,6 +3,7 @@ package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"time" "time"
@@ -155,7 +156,11 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro
if currScore != 0 && currScore < chosenScore+0.1 { if currScore != 0 && currScore < chosenScore+0.1 {
return currID return currID
} else { } else {
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network) var peer string
if route := c.routes[chosen]; route != nil {
peer = route.Peer
}
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, peer, chosenScore, c.network)
} }
} }
@@ -215,7 +220,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
func (c *clientNetwork) removeRouteFromPeerAndSystem() error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.chosenRoute != nil { if c.chosenRoute != nil {
if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil { if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("remove route %s from system, err: %v", c.network, err) return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
} }
@@ -256,7 +261,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
} }
} else { } else {
// otherwise add the route to the system // otherwise add the route to the system
if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil { if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.Address().IP.String(), err) c.network.String(), c.wgInterface.Address().IP.String(), err)
} }
@@ -344,3 +349,15 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
} }
} }
} }
func (c *clientNetwork) getAsInterface() *net.Interface {
intf, err := net.InterfaceByName(c.wgInterface.Name())
if err != nil {
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
intf = &net.Interface{
Name: c.wgInterface.Name(),
}
}
return intf
}

View File

@@ -14,6 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -28,7 +29,9 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
TriggerSelection(map[string][]*route.Route)
GetRouteSelector() *routeselector.RouteSelector
SetRouteChangeListener(listener listener.NetworkChangeListener) SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error EnableServerRouter(firewall firewall.Manager) error
@@ -41,6 +44,7 @@ type DefaultManager struct {
stop context.CancelFunc stop context.CancelFunc
mux sync.Mutex mux sync.Mutex
clientNetworks map[string]*clientNetwork clientNetworks map[string]*clientNetwork
routeSelector *routeselector.RouteSelector
serverRouter serverRouter serverRouter serverRouter
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface wgInterface *iface.WGIface
@@ -54,6 +58,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
clientNetworks: make(map[string]*clientNetwork), clientNetworks: make(map[string]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(),
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
pubKey: pubKey, pubKey: pubKey,
@@ -117,28 +122,29 @@ func (m *DefaultManager) Stop() {
} }
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
log.Infof("not updating routes as context is closed") log.Infof("not updating routes as context is closed")
return m.ctx.Err() return nil, nil, m.ctx.Err()
default: default:
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes) newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
m.updateClientNetworks(updateSerial, newClientRoutesIDMap) filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.notifier.onNewRoutes(newClientRoutesIDMap) m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.onNewRoutes(filteredClientRoutes)
if m.serverRouter != nil { if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap) err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil { if err != nil {
return fmt.Errorf("update routes: %w", err) return nil, nil, fmt.Errorf("update routes: %w", err)
} }
} }
return nil return newServerRoutesMap, newClientRoutesIDMap, nil
} }
} }
@@ -152,16 +158,51 @@ func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.initialRouteRanges() return m.notifier.initialRouteRanges()
} }
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { // GetRouteSelector returns the route selector
// removing routes that do not exist as per the update from the Management service. func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
return m.routeSelector
}
// GetClientRoutes returns the client routes
func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork {
return m.clientNetworks
}
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
m.mux.Lock()
defer m.mux.Unlock()
networks = m.routeSelector.FilterSelected(networks)
m.stopObsoleteClients(networks)
for id, routes := range networks {
if _, found := m.clientNetworks[id]; found {
// don't touch existing client network watchers
continue
}
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
}
}
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) {
for id, client := range m.clientNetworks { for id, client := range m.clientNetworks {
_, found := networks[id] if _, ok := networks[id]; !ok {
if !found { log.Debugf("Stopping client network watcher, %s", id)
log.Debugf("stopping client network watcher, %s", id)
client.stop() client.stop()
delete(m.clientNetworks, id) delete(m.clientNetworks, id)
} }
} }
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
m.stopObsoleteClients(networks)
for id, routes := range networks { for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id] clientNetworkWatcher, found := m.clientNetworks[id]
@@ -178,7 +219,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
} }
} }
func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) { func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
newClientRoutesIDMap := make(map[string][]*route.Route) newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route) newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool) ownNetworkIDs := make(map[string]bool)
@@ -210,7 +251,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
} }
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifiesRoutes(initialRoutes) _, crMap := m.classifyRoutes(initialRoutes)
rs := make([]*route.Route, 0) rs := make([]*route.Route, 0)
for _, routes := range crMap { for _, routes := range crMap {
rs = append(rs, routes...) rs = append(rs, routes...)
@@ -221,7 +262,7 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
func isPrefixSupported(prefix netip.Prefix) bool { func isPrefixSupported(prefix netip.Prefix) bool {
if !nbnet.CustomRoutingDisabled() { if !nbnet.CustomRoutingDisabled() {
switch runtime.GOOS { switch runtime.GOOS {
case "linux", "windows", "darwin": case "linux", "windows", "darwin", "ios":
return true return true
} }
} }

View File

@@ -428,11 +428,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
} }
if len(testCase.inputInitRoutes) > 0 { if len(testCase.inputInitRoutes) > 0 {
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) _, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes") require.NoError(t, err, "should update routes with init routes")
} }
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) _, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
require.NoError(t, err, "should update routes") require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected expectedWatchers := testCase.clientNetworkWatchersExpected

View File

@@ -7,14 +7,17 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
type MockManager struct { type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
StopFunc func() TriggerSelectionFunc func(map[string][]*route.Route)
GetRouteSelectorFunc func() *routeselector.RouteSelector
StopFunc func()
} }
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
@@ -27,11 +30,25 @@ func (m *MockManager) InitialRouteRange() []string {
} }
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface // UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
if m.UpdateRoutesFunc != nil { if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes) return m.UpdateRoutesFunc(updateSerial, newRoutes)
} }
return fmt.Errorf("method UpdateRoutes is not implemented") return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
}
func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) {
if m.TriggerSelectionFunc != nil {
m.TriggerSelectionFunc(networks)
}
}
// GetRouteSelector mock implementation of GetRouteSelector from Manager interface
func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
if m.GetRouteSelectorFunc != nil {
return m.GetRouteSelectorFunc()
}
return nil
} }
// Start mock implementation of Start from Manager interface // Start mock implementation of Start from Manager interface

View File

@@ -5,6 +5,7 @@ package routemanager
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"sync" "sync"
@@ -17,7 +18,7 @@ import (
type ref struct { type ref struct {
count int count int
nexthop netip.Addr nexthop netip.Addr
intf string intf *net.Interface
} }
type RouteManager struct { type RouteManager struct {
@@ -30,8 +31,8 @@ type RouteManager struct {
mutex sync.Mutex mutex sync.Mutex
} }
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error) type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
// TODO: read initial routing table into refCountMap // TODO: read initial routing table into refCountMap

View File

@@ -60,17 +60,13 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
return nil return nil
} }
var exitIntf string
gatewayHop, intf, err := getNextHop(defaultGateway) gatewayHop, intf, err := getNextHop(defaultGateway)
if err != nil && !errors.Is(err, ErrRouteNotFound) { if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
} }
if intf != nil {
exitIntf = intf.Name
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) return addToRouteTable(gatewayPrefix, gatewayHop, intf)
} }
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
@@ -84,7 +80,7 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
return netip.Addr{}, nil, ErrRouteNotFound return netip.Addr{}, nil, ErrRouteNotFound
} }
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil { if gateway == nil {
if preferredSrc == nil { if preferredSrc == nil {
return netip.Addr{}, nil, ErrRouteNotFound return netip.Addr{}, nil, ErrRouteNotFound
@@ -153,12 +149,7 @@ func isSubRange(prefix netip.Prefix) (bool, error) {
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. // addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values. // If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func addRouteToNonVPNIntf( func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
prefix netip.Prefix,
vpnIntf *iface.WGIface,
initialNextHop netip.Addr,
initialIntf *net.Interface,
) (netip.Addr, string, error) {
addr := prefix.Addr() addr := prefix.Addr()
switch { switch {
case addr.IsLoopback(), case addr.IsLoopback(),
@@ -168,39 +159,34 @@ func addRouteToNonVPNIntf(
addr.IsUnspecified(), addr.IsUnspecified(),
addr.IsMulticast(): addr.IsMulticast():
return netip.Addr{}, "", ErrRouteNotAllowed return netip.Addr{}, nil, ErrRouteNotAllowed
} }
// Determine the exit interface and next hop for the prefix, so we can add a specific route // Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, intf, err := getNextHop(addr) nexthop, intf, err := getNextHop(addr)
if err != nil { if err != nil {
return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err) return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
} }
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
exitNextHop := nexthop exitNextHop := nexthop
var exitIntf string exitIntf := intf
if intf != nil {
exitIntf = intf.Name
}
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok { if !ok {
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
} }
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop exitNextHop = initialNextHop
if initialIntf != nil { exitIntf = initialIntf
exitIntf = initialIntf.Name
}
} }
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
} }
return exitNextHop, exitIntf, nil return exitNextHop, exitIntf, nil
@@ -208,7 +194,7 @@ func addRouteToNonVPNIntf(
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route // in two /1 prefixes to avoid replacing the existing default route
func genericAddVPNRoute(prefix netip.Prefix, intf string) error { func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 { if prefix == defaultv4 {
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
return err return err
@@ -250,7 +236,7 @@ func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
} }
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table // addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func addNonExistingRoute(prefix netip.Prefix, intf string) error { func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix) ok, err := existsInRouteTable(prefix)
if err != nil { if err != nil {
return fmt.Errorf("exists in route table: %w", err) return fmt.Errorf("exists in route table: %w", err)
@@ -277,7 +263,7 @@ func addNonExistingRoute(prefix netip.Prefix, intf string) error {
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, // genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes // it will remove the split /1 prefixes
func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error { func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 { if prefix == defaultv4 {
var result *multierror.Error var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
@@ -343,7 +329,7 @@ func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []n
} }
*routeManager = NewRouteManager( *routeManager = NewRouteManager(
func(prefix netip.Prefix) (netip.Addr, string, error) { func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr() addr := prefix.Addr()
nexthop, intf := initialNextHopV4, initialIntfV4 nexthop, intf := initialNextHopV4, initialIntfV4
if addr.Is6() { if addr.Is6() {

View File

@@ -24,10 +24,10 @@ func enableIPForwarding() error {
return nil return nil
} }
func addVPNRoute(netip.Prefix, string) error { func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil return nil
} }
func removeVPNRoute(netip.Prefix, string) error { func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil return nil
} }

View File

@@ -27,15 +27,15 @@ func cleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager) return cleanupRoutingWithRouteManager(routeManager)
} }
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return routeCmd("add", prefix, nexthop, intf) return routeCmd("add", prefix, nexthop, intf)
} }
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return routeCmd("delete", prefix, nexthop, intf) return routeCmd("delete", prefix, nexthop, intf)
} }
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
inet := "-inet" inet := "-inet"
network := prefix.String() network := prefix.String()
if prefix.IsSingleIP() { if prefix.IsSingleIP() {
@@ -46,15 +46,15 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf strin
// Special case for IPv6 split default route, pointing to the wg interface fails // Special case for IPv6 split default route, pointing to the wg interface fails
// TODO: Remove once we have IPv6 support on the interface // TODO: Remove once we have IPv6 support on the interface
if prefix.Bits() == 1 { if prefix.Bits() == 1 {
intf = "lo0" intf = &net.Interface{Name: "lo0"}
} }
} }
args := []string{"-n", action, inet, network} args := []string{"-n", action, inet, network}
if nexthop.IsValid() { if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String()) args = append(args, nexthop.Unmap().String())
} else if intf != "" { } else if intf != nil {
args = append(args, "-interface", intf) args = append(args, "-interface", intf.Name)
} }
if err := retryRouteCmd(args); err != nil { if err := retryRouteCmd(args); err != nil {

View File

@@ -33,7 +33,7 @@ func init() {
func TestConcurrentRoutes(t *testing.T) { func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0") baseIP := netip.MustParseAddr("192.0.2.0")
intf := "lo0" intf := &net.Interface{Name: "lo0"}
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 1024; i++ { for i := 0; i < 1024; i++ {

View File

@@ -24,10 +24,10 @@ func enableIPForwarding() error {
return nil return nil
} }
func addVPNRoute(netip.Prefix, string) error { func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil return nil
} }
func removeVPNRoute(netip.Prefix, string) error { func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil return nil
} }

View File

@@ -46,9 +46,6 @@ var routeManager = &RouteManager{}
// originalSysctl stores the original sysctl values before they are modified // originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int var originalSysctl map[string]int
// determines whether to use the legacy routing setup
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
// sysctlFailed is used as an indicator to emit a warning when default routes are configured // sysctlFailed is used as an indicator to emit a warning when default routes are configured
var sysctlFailed bool var sysctlFailed bool
@@ -62,6 +59,20 @@ type ruleParams struct {
description string description string
} }
// isLegacy determines whether to use the legacy routing setup
func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
}
// setIsLegacy sets the legacy routing setup
func setIsLegacy(b bool) {
if b {
os.Setenv("NB_USE_LEGACY_ROUTING", "true")
} else {
os.Unsetenv("NB_USE_LEGACY_ROUTING")
}
}
func getSetupRules() []ruleParams { func getSetupRules() []ruleParams {
return []ruleParams{ return []ruleParams{
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
@@ -82,7 +93,7 @@ func getSetupRules() []ruleParams {
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy { if isLegacy() {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
} }
@@ -111,7 +122,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
if err := addRule(rule); err != nil { if err := addRule(rule); err != nil {
if errors.Is(err, syscall.EOPNOTSUPP) { if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
isLegacy = true setIsLegacy(true)
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
} }
return nil, nil, fmt.Errorf("%s: %w", rule.description, err) return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
@@ -125,7 +136,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
// It systematically removes the three rules and any associated routing table entries to ensure a clean state. // It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process. // The function uses error aggregation to report any errors encountered during the cleanup process.
func cleanupRouting() error { func cleanupRouting() error {
if isLegacy { if isLegacy() {
return cleanupRoutingWithRouteManager(routeManager) return cleanupRoutingWithRouteManager(routeManager)
} }
@@ -154,16 +165,16 @@ func cleanupRouting() error {
return result.ErrorOrNil() return result.ErrorOrNil()
} }
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
} }
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
} }
func addVPNRoute(prefix netip.Prefix, intf string) error { func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy { if isLegacy() {
return genericAddVPNRoute(prefix, intf) return genericAddVPNRoute(prefix, intf)
} }
@@ -185,8 +196,8 @@ func addVPNRoute(prefix netip.Prefix, intf string) error {
return nil return nil
} }
func removeVPNRoute(prefix netip.Prefix, intf string) error { func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy { if isLegacy() {
return genericRemoveVPNRoute(prefix, intf) return genericRemoveVPNRoute(prefix, intf)
} }
@@ -244,7 +255,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
} }
// addRoute adds a route to a specific routing table identified by tableID. // addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
route := &netlink.Route{ route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE, Scope: netlink.SCOPE_UNIVERSE,
Table: tableID, Table: tableID,
@@ -304,7 +315,10 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
Dst: ipNet, Dst: ipNet,
} }
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { if err := netlink.RouteDel(route); err != nil &&
!errors.Is(err, syscall.ESRCH) &&
!errors.Is(err, syscall.ENOENT) &&
!errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("netlink remove unreachable route: %w", err) return fmt.Errorf("netlink remove unreachable route: %w", err)
} }
@@ -313,7 +327,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
} }
// removeRoute removes a route from a specific routing table identified by tableID. // removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String()) _, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil { if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err) return fmt.Errorf("parse prefix %s: %w", prefix, err)
@@ -467,20 +481,22 @@ func removeRule(params ruleParams) error {
} }
// addNextHop adds the gateway and device to the route. // addNextHop adds the gateway and device to the route.
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
if addr.IsValid() { if intf != nil {
route.Gw = addr.AsSlice() route.LinkIndex = intf.Index
if intf == "" {
intf = addr.Zone()
}
} }
if intf != "" { if addr.IsValid() {
link, err := netlink.LinkByName(intf) route.Gw = addr.AsSlice()
if err != nil {
return fmt.Errorf("set interface %s: %w", intf, err) // if zone is set, it means the gateway is a link-local address, so we set the link index
if addr.Zone() != "" && intf == nil {
link, err := netlink.LinkByName(addr.Zone())
if err != nil {
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
}
route.LinkIndex = link.Attrs().Index
} }
route.LinkIndex = link.Attrs().Index
} }
return nil return nil

View File

@@ -3,6 +3,7 @@
package routemanager package routemanager
import ( import (
"net"
"net/netip" "net/netip"
"runtime" "runtime"
@@ -14,10 +15,10 @@ func enableIPForwarding() error {
return nil return nil
} }
func addVPNRoute(prefix netip.Prefix, intf string) error { func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericAddVPNRoute(prefix, intf) return genericAddVPNRoute(prefix, intf)
} }
func removeVPNRoute(prefix netip.Prefix, intf string) error { func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericRemoveVPNRoute(prefix, intf) return genericRemoveVPNRoute(prefix, intf)
} }

View File

@@ -50,6 +50,8 @@ func TestAddRemoveRoutes(t *testing.T) {
for n, testCase := range testCases { for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey() peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
@@ -67,7 +69,11 @@ func TestAddRemoveRoutes(t *testing.T) {
assert.NoError(t, cleanupRouting()) assert.NoError(t, cleanupRouting())
}) })
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
err = addVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err") require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.shouldRouteToWireguard { if testCase.shouldRouteToWireguard {
@@ -78,7 +84,7 @@ func TestAddRemoveRoutes(t *testing.T) {
exists, err := existsInRouteTable(testCase.prefix) exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err") require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard { if exists && testCase.shouldRouteToWireguard {
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) err = removeVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err") require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
@@ -182,12 +188,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
} }
for n, testCase := range testCases { for n, testCase := range testCases {
var buf bytes.Buffer var buf bytes.Buffer
log.SetOutput(&buf) log.SetOutput(&buf)
defer func() { defer func() {
log.SetOutput(os.Stderr) log.SetOutput(os.Stderr)
}() }()
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey() peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
@@ -200,14 +210,18 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
err = wgInterface.Create() err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface") require.NoError(t, err, "should create testing wireguard interface")
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
// Prepare the environment // Prepare the environment
if testCase.preExistingPrefix.IsValid() { if testCase.preExistingPrefix.IsValid() {
err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) err := addVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route") require.NoError(t, err, "should not return err when adding pre-existing route")
} }
// Add the route // Add the route
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) err = addVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route") require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute { if testCase.shouldAddRoute {
@@ -217,7 +231,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.True(t, ok, "route should exist") require.True(t, ok, "route should exist")
// remove route again if added // remove route again if added
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) err = removeVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err") require.NoError(t, err, "should not return err")
} }
@@ -345,43 +359,47 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, cleanupRouting()) assert.NoError(t, cleanupRouting())
}) })
index, err := net.InterfaceByName(wgIface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
// default route exists in main table and vpn table // default route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
require.NoError(t, err, "addVPNRoute should not return err") require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() { t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err") assert.NoError(t, err, "removeVPNRoute should not return err")
}) })
// 10.0.0.0/8 route exists in main table and vpn table // 10.0.0.0/8 route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
require.NoError(t, err, "addVPNRoute should not return err") require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() { t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err") assert.NoError(t, err, "removeVPNRoute should not return err")
}) })
// 10.10.0.0/24 more specific route exists in vpn table // 10.10.0.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err") require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() { t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err") assert.NoError(t, err, "removeVPNRoute should not return err")
}) })
// 127.0.10.0/24 more specific route exists in vpn table // 127.0.10.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err") require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() { t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err") assert.NoError(t, err, "removeVPNRoute should not return err")
}) })
// unique route in vpn table // unique route in vpn table
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
require.NoError(t, err, "addVPNRoute should not return err") require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() { t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err") assert.NoError(t, err, "removeVPNRoute should not return err")
}) })
} }

View File

@@ -6,8 +6,12 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"os/exec" "os/exec"
"strconv"
"strings" "strings"
"sync"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi" "github.com/yusufpapurcu/wmi"
@@ -21,6 +25,10 @@ type Win32_IP4RouteTable struct {
Mask string Mask string
} }
var prefixList []netip.Prefix
var lastUpdate time.Time
var mux = sync.Mutex{}
var routeManager *RouteManager var routeManager *RouteManager
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
@@ -32,15 +40,23 @@ func cleanupRouting() error {
} }
func getRoutesFromTable() ([]netip.Prefix, error) { func getRoutesFromTable() ([]netip.Prefix, error) {
var routes []Win32_IP4RouteTable mux.Lock()
defer mux.Unlock()
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
return prefixList, nil
}
var routes []Win32_IP4RouteTable
err := wmi.Query(query, &routes) err := wmi.Query(query, &routes)
if err != nil { if err != nil {
return nil, fmt.Errorf("get routes: %w", err) return nil, fmt.Errorf("get routes: %w", err)
} }
var prefixList []netip.Prefix prefixList = nil
for _, route := range routes { for _, route := range routes {
addr, err := netip.ParseAddr(route.Destination) addr, err := netip.ParseAddr(route.Destination)
if err != nil { if err != nil {
@@ -60,54 +76,29 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
prefixList = append(prefixList, routePrefix) prefixList = append(prefixList, routePrefix)
} }
} }
lastUpdate = time.Now()
return prefixList, nil return prefixList, nil
} }
func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error { func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
destinationPrefix := prefix.String() args := []string{"add", prefix.String()}
psCmd := "New-NetRoute"
addressFamily := "IPv4"
if prefix.Addr().Is6() {
addressFamily = "IPv6"
}
script := fmt.Sprintf(
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
psCmd, addressFamily, destinationPrefix,
)
if intfIdx != "" {
script = fmt.Sprintf(
`%s -InterfaceIndex %s`, script, intfIdx,
)
} else {
script = fmt.Sprintf(
`%s -InterfaceAlias "%s"`, script, intf,
)
}
if nexthop.IsValid() { if nexthop.IsValid() {
script = fmt.Sprintf( args = append(args, nexthop.Unmap().String())
`%s -NextHop "%s"`, script, nexthop, } else {
) addr := "0.0.0.0"
if prefix.Addr().Is6() {
addr = "::"
}
args = append(args, addr)
} }
out, err := exec.Command("powershell", "-Command", script).CombinedOutput() if intf != nil {
log.Tracef("PowerShell %s: %s", script, string(out)) args = append(args, "if", strconv.Itoa(intf.Index))
if err != nil {
return fmt.Errorf("PowerShell add route: %w", err)
} }
return nil
}
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
args := []string{"add", prefix.String(), nexthop.Unmap().String()}
out, err := exec.Command("route", args...).CombinedOutput() out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out) log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil { if err != nil {
return fmt.Errorf("route add: %w", err) return fmt.Errorf("route add: %w", err)
@@ -116,21 +107,20 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
return nil return nil
} }
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
var intfIdx string if nexthop.Zone() != "" && intf == nil {
if nexthop.Zone() != "" { zone, err := strconv.Atoi(nexthop.Zone())
intfIdx = nexthop.Zone() if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
intf = &net.Interface{Index: zone}
nexthop.WithZone("") nexthop.WithZone("")
} }
// Powershell doesn't support adding routes without an interface but allows to add interface by name
if intf != "" || intfIdx != "" {
return addRoutePowershell(prefix, nexthop, intf, intfIdx)
}
return addRouteCmd(prefix, nexthop, intf) return addRouteCmd(prefix, nexthop, intf)
} }
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
args := []string{"delete", prefix.String()} args := []string{"delete", prefix.String()}
if nexthop.IsValid() { if nexthop.IsValid() {
nexthop.WithZone("") nexthop.WithZone("")
@@ -145,3 +135,7 @@ func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) err
} }
return nil return nil
} }
func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
}

View File

@@ -0,0 +1,132 @@
package routeselector
import (
"fmt"
"slices"
"strings"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
route "github.com/netbirdio/netbird/route"
)
type RouteSelector struct {
selectedRoutes map[string]struct{}
selectAll bool
}
func NewRouteSelector() *RouteSelector {
return &RouteSelector{
selectedRoutes: map[string]struct{}{},
// default selects all routes
selectAll: true,
}
}
// SelectRoutes updates the selected routes based on the provided route IDs.
func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error {
if !appendRoute {
rs.selectedRoutes = map[string]struct{}{}
}
var multiErr *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue
}
rs.selectedRoutes[route] = struct{}{}
}
rs.selectAll = false
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
}
// SelectAllRoutes sets the selector to select all routes.
func (rs *RouteSelector) SelectAllRoutes() {
rs.selectAll = true
rs.selectedRoutes = map[string]struct{}{}
}
// DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode.
func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error {
if rs.selectAll {
rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{}
for _, route := range allRoutes {
rs.selectedRoutes[route] = struct{}{}
}
}
var multiErr *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue
}
delete(rs.selectedRoutes, route)
}
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
}
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
func (rs *RouteSelector) DeselectAllRoutes() {
rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{}
}
// IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID string) bool {
if rs.selectAll {
return true
}
_, selected := rs.selectedRoutes[routeID]
return selected
}
// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route {
if rs.selectAll {
return maps.Clone(routes)
}
filtered := map[string][]*route.Route{}
for id, rt := range routes {
netID := id
if i := strings.LastIndex(id, "-"); i != -1 {
netID = id[:i]
}
if rs.IsSelected(netID) {
filtered[id] = rt
}
}
return filtered
}
func formatError(es []error) string {
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}

View File

@@ -0,0 +1,275 @@
package routeselector_test
import (
"slices"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
func TestRouteSelector_SelectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
selectRoutes []string
append bool
wantSelected []string
wantError bool
}{
{
name: "Select specific routes, initial all selected",
selectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route1", "route2"},
},
{
name: "Select specific routes, initial all deselected",
initialSelected: []string{},
selectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route1", "route2"},
},
{
name: "Select specific routes with initial selection",
initialSelected: []string{"route1"},
selectRoutes: []string{"route2", "route3"},
wantSelected: []string{"route2", "route3"},
},
{
name: "Select non-existing route",
selectRoutes: []string{"route1", "route4"},
wantSelected: []string{"route1"},
wantError: true,
},
{
name: "Append route with initial selection",
initialSelected: []string{"route1"},
selectRoutes: []string{"route2"},
append: true,
wantSelected: []string{"route1", "route2"},
},
{
name: "Append route without initial selection",
selectRoutes: []string{"route2"},
append: true,
wantSelected: []string{"route2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_SelectAllRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
wantSelected []string
}{
{
name: "Initial all selected",
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial all deselected",
initialSelected: []string{},
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial some selected",
initialSelected: []string{"route1"},
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"},
wantSelected: []string{"route1", "route2", "route3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
rs.SelectAllRoutes()
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_DeselectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
deselectRoutes []string
wantSelected []string
wantError bool
}{
{
name: "Deselect specific routes, initial all selected",
deselectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route3"},
},
{
name: "Deselect specific routes, initial all deselected",
initialSelected: []string{},
deselectRoutes: []string{"route1", "route2"},
wantSelected: []string{},
},
{
name: "Deselect specific routes with initial selection",
initialSelected: []string{"route1", "route2"},
deselectRoutes: []string{"route1", "route3"},
wantSelected: []string{"route2"},
},
{
name: "Deselect non-existing route",
initialSelected: []string{"route1", "route2"},
deselectRoutes: []string{"route1", "route4"},
wantSelected: []string{"route2"},
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
err := rs.DeselectRoutes(tt.deselectRoutes, allRoutes)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_DeselectAll(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
wantSelected []string
}{
{
name: "Initial all selected",
wantSelected: []string{},
},
{
name: "Initial all deselected",
initialSelected: []string{},
wantSelected: []string{},
},
{
name: "Initial some selected",
initialSelected: []string{"route1", "route2"},
wantSelected: []string{},
},
{
name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"},
wantSelected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
rs.DeselectAllRoutes()
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_IsSelected(t *testing.T) {
rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
require.NoError(t, err)
assert.True(t, rs.IsSelected("route1"))
assert.True(t, rs.IsSelected("route2"))
assert.False(t, rs.IsSelected("route3"))
assert.False(t, rs.IsSelected("route4"))
}
func TestRouteSelector_FilterSelected(t *testing.T) {
rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
require.NoError(t, err)
routes := map[string][]*route.Route{
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
"route3-172.16.0.0/12": {},
}
filtered := rs.FilterSelected(routes)
assert.Equal(t, map[string][]*route.Route{
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
}, filtered)
}

View File

@@ -71,6 +71,42 @@ func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key p.configInput.PreSharedKey = &key
} }
// SetRosenpassEnabled store if rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled
}
// GetRosenpassEnabled read rosenpass enabled from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassEnabled, err
}
// SetRosenpassPermissive store the given permissive and wait for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive
}
// GetRosenpassPermissive read rosenpass permissive from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassPermissive, err
}
// Commit write out the changes into config file // Commit write out the changes into config file
func (p *Preferences) Commit() error { func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput) _, err := internal.UpdateOrCreateConfig(p.configInput)

View File

@@ -1,17 +1,17 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.26.0 // protoc-gen-go v1.26.0
// protoc v4.24.3 // protoc v3.12.4
// source: daemon.proto // source: daemon.proto
package proto package proto
import ( import (
_ "github.com/golang/protobuf/protoc-gen-go/descriptor"
duration "github.com/golang/protobuf/ptypes/duration"
timestamp "github.com/golang/protobuf/ptypes/timestamp"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
_ "google.golang.org/protobuf/types/descriptorpb"
durationpb "google.golang.org/protobuf/types/known/durationpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect" reflect "reflect"
sync "sync" sync "sync"
) )
@@ -23,6 +23,70 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
type LogLevel int32
const (
LogLevel_UNKNOWN LogLevel = 0
LogLevel_PANIC LogLevel = 1
LogLevel_FATAL LogLevel = 2
LogLevel_ERROR LogLevel = 3
LogLevel_WARN LogLevel = 4
LogLevel_INFO LogLevel = 5
LogLevel_DEBUG LogLevel = 6
LogLevel_TRACE LogLevel = 7
)
// Enum value maps for LogLevel.
var (
LogLevel_name = map[int32]string{
0: "UNKNOWN",
1: "PANIC",
2: "FATAL",
3: "ERROR",
4: "WARN",
5: "INFO",
6: "DEBUG",
7: "TRACE",
}
LogLevel_value = map[string]int32{
"UNKNOWN": 0,
"PANIC": 1,
"FATAL": 2,
"ERROR": 3,
"WARN": 4,
"INFO": 5,
"DEBUG": 6,
"TRACE": 7,
}
)
func (x LogLevel) Enum() *LogLevel {
p := new(LogLevel)
*p = x
return p
}
func (x LogLevel) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (LogLevel) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[0].Descriptor()
}
func (LogLevel) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[0]
}
func (x LogLevel) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use LogLevel.Descriptor instead.
func (LogLevel) EnumDescriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{0}
}
type LoginRequest struct { type LoginRequest struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@@ -766,23 +830,23 @@ type PeerState struct {
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"`
PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"` PubKey string `protobuf:"bytes,2,opt,name=pubKey,proto3" json:"pubKey,omitempty"`
ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"` ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"`
ConnStatusUpdate *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"` ConnStatusUpdate *timestamp.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"`
Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"` Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"`
Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"` Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"`
LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"` LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"`
RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"` RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"`
Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"` Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"` LocalIceCandidateEndpoint string `protobuf:"bytes,10,opt,name=localIceCandidateEndpoint,proto3" json:"localIceCandidateEndpoint,omitempty"`
RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"` RemoteIceCandidateEndpoint string `protobuf:"bytes,11,opt,name=remoteIceCandidateEndpoint,proto3" json:"remoteIceCandidateEndpoint,omitempty"`
LastWireguardHandshake *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"` LastWireguardHandshake *timestamp.Timestamp `protobuf:"bytes,12,opt,name=lastWireguardHandshake,proto3" json:"lastWireguardHandshake,omitempty"`
BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"` BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"`
BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"` BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"`
RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"`
Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` Latency *duration.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"`
} }
func (x *PeerState) Reset() { func (x *PeerState) Reset() {
@@ -838,7 +902,7 @@ func (x *PeerState) GetConnStatus() string {
return "" return ""
} }
func (x *PeerState) GetConnStatusUpdate() *timestamppb.Timestamp { func (x *PeerState) GetConnStatusUpdate() *timestamp.Timestamp {
if x != nil { if x != nil {
return x.ConnStatusUpdate return x.ConnStatusUpdate
} }
@@ -894,7 +958,7 @@ func (x *PeerState) GetRemoteIceCandidateEndpoint() string {
return "" return ""
} }
func (x *PeerState) GetLastWireguardHandshake() *timestamppb.Timestamp { func (x *PeerState) GetLastWireguardHandshake() *timestamp.Timestamp {
if x != nil { if x != nil {
return x.LastWireguardHandshake return x.LastWireguardHandshake
} }
@@ -929,7 +993,7 @@ func (x *PeerState) GetRoutes() []string {
return nil return nil
} }
func (x *PeerState) GetLatency() *durationpb.Duration { func (x *PeerState) GetLatency() *duration.Duration {
if x != nil { if x != nil {
return x.Latency return x.Latency
} }
@@ -1383,6 +1447,442 @@ func (x *FullStatus) GetDnsServers() []*NSGroupState {
return nil return nil
} }
type ListRoutesRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *ListRoutesRequest) Reset() {
*x = ListRoutesRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[19]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ListRoutesRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListRoutesRequest) ProtoMessage() {}
func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[19]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListRoutesRequest.ProtoReflect.Descriptor instead.
func (*ListRoutesRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{19}
}
type ListRoutesResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Routes []*Route `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"`
}
func (x *ListRoutesResponse) Reset() {
*x = ListRoutesResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[20]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ListRoutesResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListRoutesResponse) ProtoMessage() {}
func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[20]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListRoutesResponse.ProtoReflect.Descriptor instead.
func (*ListRoutesResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{20}
}
func (x *ListRoutesResponse) GetRoutes() []*Route {
if x != nil {
return x.Routes
}
return nil
}
type SelectRoutesRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
RouteIDs []string `protobuf:"bytes,1,rep,name=routeIDs,proto3" json:"routeIDs,omitempty"`
Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"`
All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"`
}
func (x *SelectRoutesRequest) Reset() {
*x = SelectRoutesRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[21]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SelectRoutesRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SelectRoutesRequest) ProtoMessage() {}
func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[21]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SelectRoutesRequest.ProtoReflect.Descriptor instead.
func (*SelectRoutesRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{21}
}
func (x *SelectRoutesRequest) GetRouteIDs() []string {
if x != nil {
return x.RouteIDs
}
return nil
}
func (x *SelectRoutesRequest) GetAppend() bool {
if x != nil {
return x.Append
}
return false
}
func (x *SelectRoutesRequest) GetAll() bool {
if x != nil {
return x.All
}
return false
}
type SelectRoutesResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *SelectRoutesResponse) Reset() {
*x = SelectRoutesResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[22]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SelectRoutesResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SelectRoutesResponse) ProtoMessage() {}
func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[22]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SelectRoutesResponse.ProtoReflect.Descriptor instead.
func (*SelectRoutesResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{22}
}
type Route struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"`
Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"`
Selected bool `protobuf:"varint,3,opt,name=selected,proto3" json:"selected,omitempty"`
}
func (x *Route) Reset() {
*x = Route{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[23]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Route) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Route) ProtoMessage() {}
func (x *Route) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[23]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Route.ProtoReflect.Descriptor instead.
func (*Route) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{23}
}
func (x *Route) GetID() string {
if x != nil {
return x.ID
}
return ""
}
func (x *Route) GetNetwork() string {
if x != nil {
return x.Network
}
return ""
}
func (x *Route) GetSelected() bool {
if x != nil {
return x.Selected
}
return false
}
type DebugBundleRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
}
func (x *DebugBundleRequest) Reset() {
*x = DebugBundleRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[24]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DebugBundleRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DebugBundleRequest) ProtoMessage() {}
func (x *DebugBundleRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[24]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DebugBundleRequest.ProtoReflect.Descriptor instead.
func (*DebugBundleRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{24}
}
func (x *DebugBundleRequest) GetAnonymize() bool {
if x != nil {
return x.Anonymize
}
return false
}
func (x *DebugBundleRequest) GetStatus() string {
if x != nil {
return x.Status
}
return ""
}
type DebugBundleResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
}
func (x *DebugBundleResponse) Reset() {
*x = DebugBundleResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[25]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DebugBundleResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DebugBundleResponse) ProtoMessage() {}
func (x *DebugBundleResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[25]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DebugBundleResponse.ProtoReflect.Descriptor instead.
func (*DebugBundleResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{25}
}
func (x *DebugBundleResponse) GetPath() string {
if x != nil {
return x.Path
}
return ""
}
type SetLogLevelRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Level LogLevel `protobuf:"varint,1,opt,name=level,proto3,enum=daemon.LogLevel" json:"level,omitempty"`
}
func (x *SetLogLevelRequest) Reset() {
*x = SetLogLevelRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[26]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetLogLevelRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetLogLevelRequest) ProtoMessage() {}
func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[26]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead.
func (*SetLogLevelRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{26}
}
func (x *SetLogLevelRequest) GetLevel() LogLevel {
if x != nil {
return x.Level
}
return LogLevel_UNKNOWN
}
type SetLogLevelResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *SetLogLevelResponse) Reset() {
*x = SetLogLevelResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[27]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetLogLevelResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetLogLevelResponse) ProtoMessage() {}
func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[27]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead.
func (*SetLogLevelResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{27}
}
var File_daemon_proto protoreflect.FileDescriptor var File_daemon_proto protoreflect.FileDescriptor
var file_daemon_proto_rawDesc = []byte{ var file_daemon_proto_rawDesc = []byte{
@@ -1601,32 +2101,92 @@ var file_daemon_proto_rawDesc = []byte{
0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65,
0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74,
0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x22, 0x13, 0x0a,
0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22,
0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49,
0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01,
0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c,
0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x16, 0x0a, 0x14,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70,
0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x4d, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63,
0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63,
0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x74, 0x65, 0x64, 0x22, 0x4a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f,
0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e,
0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75,
0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22,
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01,
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32,
0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c,
0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a,
0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55,
0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49,
0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09,
0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52,
0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a,
0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43,
0x45, 0x10, 0x07, 0x32, 0xee, 0x05, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f,
0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a,
0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f,
0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70,
0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a,
0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f,
0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63,
0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75,
0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c,
0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74,
0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12,
0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75,
0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (
@@ -1641,58 +2201,81 @@ func file_daemon_proto_rawDescGZIP() []byte {
return file_daemon_proto_rawDescData return file_daemon_proto_rawDescData
} }
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 19) var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 28)
var file_daemon_proto_goTypes = []interface{}{ var file_daemon_proto_goTypes = []interface{}{
(*LoginRequest)(nil), // 0: daemon.LoginRequest (LogLevel)(0), // 0: daemon.LogLevel
(*LoginResponse)(nil), // 1: daemon.LoginResponse (*LoginRequest)(nil), // 1: daemon.LoginRequest
(*WaitSSOLoginRequest)(nil), // 2: daemon.WaitSSOLoginRequest (*LoginResponse)(nil), // 2: daemon.LoginResponse
(*WaitSSOLoginResponse)(nil), // 3: daemon.WaitSSOLoginResponse (*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest
(*UpRequest)(nil), // 4: daemon.UpRequest (*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse
(*UpResponse)(nil), // 5: daemon.UpResponse (*UpRequest)(nil), // 5: daemon.UpRequest
(*StatusRequest)(nil), // 6: daemon.StatusRequest (*UpResponse)(nil), // 6: daemon.UpResponse
(*StatusResponse)(nil), // 7: daemon.StatusResponse (*StatusRequest)(nil), // 7: daemon.StatusRequest
(*DownRequest)(nil), // 8: daemon.DownRequest (*StatusResponse)(nil), // 8: daemon.StatusResponse
(*DownResponse)(nil), // 9: daemon.DownResponse (*DownRequest)(nil), // 9: daemon.DownRequest
(*GetConfigRequest)(nil), // 10: daemon.GetConfigRequest (*DownResponse)(nil), // 10: daemon.DownResponse
(*GetConfigResponse)(nil), // 11: daemon.GetConfigResponse (*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest
(*PeerState)(nil), // 12: daemon.PeerState (*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse
(*LocalPeerState)(nil), // 13: daemon.LocalPeerState (*PeerState)(nil), // 13: daemon.PeerState
(*SignalState)(nil), // 14: daemon.SignalState (*LocalPeerState)(nil), // 14: daemon.LocalPeerState
(*ManagementState)(nil), // 15: daemon.ManagementState (*SignalState)(nil), // 15: daemon.SignalState
(*RelayState)(nil), // 16: daemon.RelayState (*ManagementState)(nil), // 16: daemon.ManagementState
(*NSGroupState)(nil), // 17: daemon.NSGroupState (*RelayState)(nil), // 17: daemon.RelayState
(*FullStatus)(nil), // 18: daemon.FullStatus (*NSGroupState)(nil), // 18: daemon.NSGroupState
(*timestamppb.Timestamp)(nil), // 19: google.protobuf.Timestamp (*FullStatus)(nil), // 19: daemon.FullStatus
(*durationpb.Duration)(nil), // 20: google.protobuf.Duration (*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest
(*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse
(*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest
(*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse
(*Route)(nil), // 24: daemon.Route
(*DebugBundleRequest)(nil), // 25: daemon.DebugBundleRequest
(*DebugBundleResponse)(nil), // 26: daemon.DebugBundleResponse
(*SetLogLevelRequest)(nil), // 27: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 28: daemon.SetLogLevelResponse
(*timestamp.Timestamp)(nil), // 29: google.protobuf.Timestamp
(*duration.Duration)(nil), // 30: google.protobuf.Duration
} }
var file_daemon_proto_depIdxs = []int32{ var file_daemon_proto_depIdxs = []int32{
18, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus 19, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
19, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp 29, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
19, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp 29, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
20, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration 30, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration
15, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 16, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
14, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState 15, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState
13, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState 14, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
12, // 7: daemon.FullStatus.peers:type_name -> daemon.PeerState 13, // 7: daemon.FullStatus.peers:type_name -> daemon.PeerState
16, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState 17, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState
17, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 18, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
0, // 10: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 24, // 10: daemon.ListRoutesResponse.routes:type_name -> daemon.Route
2, // 11: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest 0, // 11: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
4, // 12: daemon.DaemonService.Up:input_type -> daemon.UpRequest 1, // 12: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
6, // 13: daemon.DaemonService.Status:input_type -> daemon.StatusRequest 3, // 13: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
8, // 14: daemon.DaemonService.Down:input_type -> daemon.DownRequest 5, // 14: daemon.DaemonService.Up:input_type -> daemon.UpRequest
10, // 15: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest 7, // 15: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
1, // 16: daemon.DaemonService.Login:output_type -> daemon.LoginResponse 9, // 16: daemon.DaemonService.Down:input_type -> daemon.DownRequest
3, // 17: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse 11, // 17: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
5, // 18: daemon.DaemonService.Up:output_type -> daemon.UpResponse 20, // 18: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
7, // 19: daemon.DaemonService.Status:output_type -> daemon.StatusResponse 22, // 19: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
9, // 20: daemon.DaemonService.Down:output_type -> daemon.DownResponse 22, // 20: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
11, // 21: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse 25, // 21: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
16, // [16:22] is the sub-list for method output_type 27, // 22: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
10, // [10:16] is the sub-list for method input_type 2, // 23: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // [10:10] is the sub-list for extension type_name 4, // 24: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
10, // [10:10] is the sub-list for extension extendee 6, // 25: daemon.DaemonService.Up:output_type -> daemon.UpResponse
0, // [0:10] is the sub-list for field type_name 8, // 26: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
10, // 27: daemon.DaemonService.Down:output_type -> daemon.DownResponse
12, // 28: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
21, // 29: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
23, // 30: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
23, // 31: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
26, // 32: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
28, // 33: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
23, // [23:34] is the sub-list for method output_type
12, // [12:23] is the sub-list for method input_type
12, // [12:12] is the sub-list for extension type_name
12, // [12:12] is the sub-list for extension extendee
0, // [0:12] is the sub-list for field type_name
} }
func init() { file_daemon_proto_init() } func init() { file_daemon_proto_init() }
@@ -1929,6 +2512,114 @@ func file_daemon_proto_init() {
return nil return nil
} }
} }
file_daemon_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ListRoutesRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ListRoutesResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SelectRoutesRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SelectRoutesResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Route); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DebugBundleRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DebugBundleResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetLogLevelRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetLogLevelResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{} file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{}
type x struct{} type x struct{}
@@ -1936,13 +2627,14 @@ func file_daemon_proto_init() {
File: protoimpl.DescBuilder{ File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_daemon_proto_rawDesc, RawDescriptor: file_daemon_proto_rawDesc,
NumEnums: 0, NumEnums: 1,
NumMessages: 19, NumMessages: 28,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },
GoTypes: file_daemon_proto_goTypes, GoTypes: file_daemon_proto_goTypes,
DependencyIndexes: file_daemon_proto_depIdxs, DependencyIndexes: file_daemon_proto_depIdxs,
EnumInfos: file_daemon_proto_enumTypes,
MessageInfos: file_daemon_proto_msgTypes, MessageInfos: file_daemon_proto_msgTypes,
}.Build() }.Build()
File_daemon_proto = out.File File_daemon_proto = out.File

View File

@@ -27,6 +27,21 @@ service DaemonService {
// GetConfig of the daemon. // GetConfig of the daemon.
rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {} rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {}
// List available network routes
rpc ListRoutes(ListRoutesRequest) returns (ListRoutesResponse) {}
// Select specific routes
rpc SelectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {}
// Deselect specific routes
rpc DeselectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {}
// DebugBundle creates a debug bundle
rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {}
// SetLogLevel sets the log level of the daemon
rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {}
}; };
message LoginRequest { message LoginRequest {
@@ -195,4 +210,53 @@ message FullStatus {
repeated PeerState peers = 4; repeated PeerState peers = 4;
repeated RelayState relays = 5; repeated RelayState relays = 5;
repeated NSGroupState dns_servers = 6; repeated NSGroupState dns_servers = 6;
}
message ListRoutesRequest {
}
message ListRoutesResponse {
repeated Route routes = 1;
}
message SelectRoutesRequest {
repeated string routeIDs = 1;
bool append = 2;
bool all = 3;
}
message SelectRoutesResponse {
}
message Route {
string ID = 1;
string network = 2;
bool selected = 3;
}
message DebugBundleRequest {
bool anonymize = 1;
string status = 2;
}
message DebugBundleResponse {
string path = 1;
}
enum LogLevel {
UNKNOWN = 0;
PANIC = 1;
FATAL = 2;
ERROR = 3;
WARN = 4;
INFO = 5;
DEBUG = 6;
TRACE = 7;
}
message SetLogLevelRequest {
LogLevel level = 1;
}
message SetLogLevelResponse {
} }

View File

@@ -31,6 +31,16 @@ type DaemonServiceClient interface {
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
// GetConfig of the daemon. // GetConfig of the daemon.
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
// List available network routes
ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error)
// Select specific routes
SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
// Deselect specific routes
DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
// DebugBundle creates a debug bundle
DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error)
// SetLogLevel sets the log level of the daemon
SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error)
} }
type daemonServiceClient struct { type daemonServiceClient struct {
@@ -95,6 +105,51 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques
return out, nil return out, nil
} }
func (c *daemonServiceClient) ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) {
out := new(ListRoutesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListRoutes", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) {
out := new(SelectRoutesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectRoutes", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) {
out := new(SelectRoutesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectRoutes", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) {
out := new(DebugBundleResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) {
out := new(SetLogLevelResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service. // DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer // All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility // for forward compatibility
@@ -112,6 +167,16 @@ type DaemonServiceServer interface {
Down(context.Context, *DownRequest) (*DownResponse, error) Down(context.Context, *DownRequest) (*DownResponse, error)
// GetConfig of the daemon. // GetConfig of the daemon.
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
// List available network routes
ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error)
// Select specific routes
SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
// Deselect specific routes
DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
// DebugBundle creates a debug bundle
DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error)
// SetLogLevel sets the log level of the daemon
SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error)
mustEmbedUnimplementedDaemonServiceServer() mustEmbedUnimplementedDaemonServiceServer()
} }
@@ -137,6 +202,21 @@ func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*Do
func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) { func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented")
} }
func (UnimplementedDaemonServiceServer) ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes not implemented")
}
func (UnimplementedDaemonServiceServer) SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SelectRoutes not implemented")
}
func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DeselectRoutes not implemented")
}
func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented")
}
func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -258,6 +338,96 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _DaemonService_ListRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListRoutesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).ListRoutes(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/ListRoutes",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ListRoutes(ctx, req.(*ListRoutesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SelectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SelectRoutesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).SelectRoutes(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SelectRoutes",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SelectRoutes(ctx, req.(*SelectRoutesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_DeselectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SelectRoutesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).DeselectRoutes(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DeselectRoutes",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DeselectRoutes(ctx, req.(*SelectRoutesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DebugBundleRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).DebugBundle(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DebugBundle",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SetLogLevelRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).SetLogLevel(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SetLogLevel",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@@ -289,6 +459,26 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetConfig", MethodName: "GetConfig",
Handler: _DaemonService_GetConfig_Handler, Handler: _DaemonService_GetConfig_Handler,
}, },
{
MethodName: "ListRoutes",
Handler: _DaemonService_ListRoutes_Handler,
},
{
MethodName: "SelectRoutes",
Handler: _DaemonService_SelectRoutes_Handler,
},
{
MethodName: "DeselectRoutes",
Handler: _DaemonService_DeselectRoutes_Handler,
},
{
MethodName: "DebugBundle",
Handler: _DaemonService_DebugBundle_Handler,
},
{
MethodName: "SetLogLevel",
Handler: _DaemonService_SetLogLevel_Handler,
},
}, },
Streams: []grpc.StreamDesc{}, Streams: []grpc.StreamDesc{},
Metadata: "daemon.proto", Metadata: "daemon.proto",

175
client/server/debug.go Normal file
View File

@@ -0,0 +1,175 @@
package server
import (
"archive/zip"
"bufio"
"context"
"fmt"
"io"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.logFile == "console" {
return nil, fmt.Errorf("log file is set to console, cannot create debug bundle")
}
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
if err != nil {
return nil, fmt.Errorf("create zip file: %w", err)
}
defer func() {
if err := bundlePath.Close(); err != nil {
log.Errorf("failed to close zip file: %v", err)
}
if err != nil {
if err2 := os.Remove(bundlePath.Name()); err2 != nil {
log.Errorf("Failed to remove zip file: %v", err2)
}
}
}()
archive := zip.NewWriter(bundlePath)
defer func() {
if err := archive.Close(); err != nil {
log.Errorf("failed to close archive writer: %v", err)
}
}()
if status := req.GetStatus(); status != "" {
filename := "status.txt"
if req.GetAnonymize() {
filename = "status.anon.txt"
}
statusReader := strings.NewReader(status)
if err := addFileToZip(archive, statusReader, filename); err != nil {
return nil, fmt.Errorf("add status file to zip: %w", err)
}
}
logFile, err := os.Open(s.logFile)
if err != nil {
return nil, fmt.Errorf("open log file: %w", err)
}
defer func() {
if err := logFile.Close(); err != nil {
log.Errorf("failed to close original log file: %v", err)
}
}()
filename := "client.log.txt"
var logReader io.Reader
errChan := make(chan error, 1)
if req.GetAnonymize() {
filename = "client.anon.log.txt"
var writer io.WriteCloser
logReader, writer = io.Pipe()
go s.anonymize(logFile, writer, errChan)
} else {
logReader = logFile
}
if err := addFileToZip(archive, logReader, filename); err != nil {
return nil, fmt.Errorf("add log file to zip: %w", err)
}
select {
case err := <-errChan:
if err != nil {
return nil, err
}
default:
}
return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil
}
func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan<- error) {
scanner := bufio.NewScanner(reader)
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
status := s.statusRecorder.GetFullStatus()
seedFromStatus(anonymizer, &status)
defer func() {
if err := writer.Close(); err != nil {
log.Errorf("Failed to close writer: %v", err)
}
}()
for scanner.Scan() {
line := anonymizer.AnonymizeString(scanner.Text())
if _, err := writer.Write([]byte(line + "\n")); err != nil {
errChan <- fmt.Errorf("write line to writer: %w", err)
return
}
}
if err := scanner.Err(); err != nil {
errChan <- fmt.Errorf("read line from scanner: %w", err)
return
}
}
// SetLogLevel sets the logging level for the server.
func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) {
level, err := log.ParseLevel(req.Level.String())
if err != nil {
return nil, fmt.Errorf("invalid log level: %w", err)
}
log.SetLevel(level)
log.Infof("Log level set to %s", level.String())
return &proto.SetLogLevelResponse{}, nil
}
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
header := &zip.FileHeader{
Name: filename,
Method: zip.Deflate,
}
writer, err := archive.CreateHeader(header)
if err != nil {
return fmt.Errorf("create zip file header: %w", err)
}
if _, err := io.Copy(writer, reader); err != nil {
return fmt.Errorf("write file to zip: %w", err)
}
return nil
}
func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
status.ManagementState.URL = a.AnonymizeURI(status.ManagementState.URL)
status.SignalState.URL = a.AnonymizeURI(status.SignalState.URL)
status.LocalPeerState.FQDN = a.AnonymizeDomain(status.LocalPeerState.FQDN)
for _, peer := range status.Peers {
a.AnonymizeDomain(peer.FQDN)
}
for _, nsGroup := range status.NSGroupStates {
for _, domain := range nsGroup.Domains {
a.AnonymizeDomain(domain)
}
}
for _, relay := range status.Relays {
if relay.URI != nil {
a.AnonymizeURI(relay.URI.String())
}
}
}

110
client/server/route.go Normal file
View File

@@ -0,0 +1,110 @@
package server
import (
"context"
"fmt"
"net/netip"
"sort"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto"
)
type selectRoute struct {
NetID string
Network netip.Prefix
Selected bool
}
// ListRoutes returns a list of all available routes.
func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.engine == nil {
return nil, fmt.Errorf("not connected")
}
routesMap := s.engine.GetClientRoutesWithNetID()
routeSelector := s.engine.GetRouteManager().GetRouteSelector()
var routes []*selectRoute
for id, rt := range routesMap {
if len(rt) == 0 {
continue
}
route := &selectRoute{
NetID: id,
Network: rt[0].Network,
Selected: routeSelector.IsSelected(id),
}
routes = append(routes, route)
}
sort.Slice(routes, func(i, j int) bool {
iPrefix := routes[i].Network.Bits()
jPrefix := routes[j].Network.Bits()
if iPrefix == jPrefix {
iAddr := routes[i].Network.Addr()
jAddr := routes[j].Network.Addr()
if iAddr == jAddr {
return routes[i].NetID < routes[j].NetID
}
return iAddr.String() < jAddr.String()
}
return iPrefix < jPrefix
})
var pbRoutes []*proto.Route
for _, route := range routes {
pbRoutes = append(pbRoutes, &proto.Route{
ID: route.NetID,
Network: route.Network.String(),
Selected: route.Selected,
})
}
return &proto.ListRoutesResponse{
Routes: pbRoutes,
}, nil
}
// SelectRoutes selects specific routes based on the client request.
func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
routeManager := s.engine.GetRouteManager()
routeSelector := routeManager.GetRouteSelector()
if req.GetAll() {
routeSelector.SelectAllRoutes()
} else {
if err := routeSelector.SelectRoutes(req.GetRouteIDs(), req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("select routes: %w", err)
}
}
routeManager.TriggerSelection(s.engine.GetClientRoutes())
return &proto.SelectRoutesResponse{}, nil
}
// DeselectRoutes deselects specific routes based on the client request.
func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
routeManager := s.engine.GetRouteManager()
routeSelector := routeManager.GetRouteSelector()
if req.GetAll() {
routeSelector.DeselectAllRoutes()
} else {
if err := routeSelector.DeselectRoutes(req.GetRouteIDs(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("deselect routes: %w", err)
}
}
routeManager.TriggerSelection(s.engine.GetClientRoutes())
return &proto.SelectRoutesResponse{}, nil
}

View File

@@ -15,15 +15,15 @@ import (
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
@@ -57,6 +57,8 @@ type Server struct {
config *internal.Config config *internal.Config
proto.UnimplementedDaemonServiceServer proto.UnimplementedDaemonServiceServer
engine *internal.Engine
statusRecorder *peer.Status statusRecorder *peer.Status
sessionWatcher *internal.SessionWatcher sessionWatcher *internal.SessionWatcher
@@ -141,8 +143,11 @@ func (s *Server) Start() error {
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire) s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
} }
engineChan := make(chan *internal.Engine, 1)
go s.watchEngine(ctx, engineChan)
if !config.DisableAutoConnect { if !config.DisableAutoConnect {
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
} }
return nil return nil
@@ -153,6 +158,7 @@ func (s *Server) Start() error {
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status, func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe, mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe,
engineChan chan<- *internal.Engine,
) { ) {
backOff := getConnectWithBackoff(ctx) backOff := getConnectWithBackoff(ctx)
retryStarted := false retryStarted := false
@@ -182,7 +188,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
runOperation := func() error { runOperation := func() error {
log.Tracef("running client connection") log.Tracef("running client connection")
err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe) err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
if err != nil { if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err) log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
} }
@@ -562,7 +568,10 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) engineChan := make(chan *internal.Engine, 1)
go s.watchEngine(ctx, engineChan)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
return &proto.UpResponse{}, nil return &proto.UpResponse{}, nil
} }
@@ -579,6 +588,8 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
state := internal.CtxGetState(s.rootCtx) state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle) state.Set(internal.StatusIdle)
s.engine = nil
return &proto.DownResponse{}, nil return &proto.DownResponse{}, nil
} }
@@ -661,7 +672,6 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
PreSharedKey: preSharedKey, PreSharedKey: preSharedKey,
}, nil }, nil
} }
func (s *Server) onSessionExpire() { func (s *Server) onSessionExpire() {
if runtime.GOOS != "windows" { if runtime.GOOS != "windows" {
isUIActive := internal.CheckUIApp() isUIActive := internal.CheckUIApp()
@@ -673,6 +683,22 @@ func (s *Server) onSessionExpire() {
} }
} }
// watchEngine watches the engine channel and updates the engine state
func (s *Server) watchEngine(ctx context.Context, engineChan chan *internal.Engine) {
log.Tracef("Started watching engine")
for {
select {
case <-ctx.Done():
s.engine = nil
log.Tracef("Stopped watching engine")
return
case engine := <-engineChan:
log.Tracef("Received engine from watcher")
s.engine = engine
}
}
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{ pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{}, ManagementState: &proto.ManagementState{},

View File

@@ -2,11 +2,12 @@ package server
import ( import (
"context" "context"
"github.com/netbirdio/management-integrations/integrations"
"net" "net"
"testing" "testing"
"time" "time"
"github.com/netbirdio/management-integrations/integrations"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
@@ -69,7 +70,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s") t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1") t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, nil)
if counter < 3 { if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter) t.Fatalf("expected counter > 2, got %d", counter)
} }

View File

@@ -1,7 +1,5 @@
//go:build !(linux && 386) //go:build !(linux && 386)
// +build !linux !386
// skipping linux 32 bits build and tests
package main package main
import ( import (
@@ -58,14 +56,23 @@ func main() {
var showSettings bool var showSettings bool
flag.BoolVar(&showSettings, "settings", false, "run settings windows") flag.BoolVar(&showSettings, "settings", false, "run settings windows")
var showRoutes bool
flag.BoolVar(&showRoutes, "routes", false, "run routes windows")
var errorMSG string
flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window")
flag.Parse() flag.Parse()
a := app.NewWithID("NetBird") a := app.NewWithID("NetBird")
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG)) a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
client := newServiceClient(daemonAddr, a, showSettings) if errorMSG != "" {
if showSettings { showErrorMSG(errorMSG)
return
}
client := newServiceClient(daemonAddr, a, showSettings, showRoutes)
if showSettings || showRoutes {
a.Run() a.Run()
} else { } else {
if err := checkPIDFile(); err != nil { if err := checkPIDFile(); err != nil {
@@ -128,6 +135,7 @@ type serviceClient struct {
mVersionDaemon *systray.MenuItem mVersionDaemon *systray.MenuItem
mUpdate *systray.MenuItem mUpdate *systray.MenuItem
mQuit *systray.MenuItem mQuit *systray.MenuItem
mRoutes *systray.MenuItem
// application with main windows. // application with main windows.
app fyne.App app fyne.App
@@ -152,12 +160,15 @@ type serviceClient struct {
daemonVersion string daemonVersion string
updateIndicationLock sync.Mutex updateIndicationLock sync.Mutex
isUpdateIconActive bool isUpdateIconActive bool
showRoutes bool
wRoutes fyne.Window
} }
// newServiceClient instance constructor // newServiceClient instance constructor
// //
// This constructor also builds the UI elements for the settings window. // This constructor also builds the UI elements for the settings window.
func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient { func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes bool) *serviceClient {
s := &serviceClient{ s := &serviceClient{
ctx: context.Background(), ctx: context.Background(),
addr: addr, addr: addr,
@@ -165,6 +176,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient
sendNotification: false, sendNotification: false,
showSettings: showSettings, showSettings: showSettings,
showRoutes: showRoutes,
update: version.NewUpdate(), update: version.NewUpdate(),
} }
@@ -184,14 +196,16 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient
} }
if showSettings { if showSettings {
s.showUIElements() s.showSettingsUI()
return s return s
} else if showRoutes {
s.showRoutesUI()
} }
return s return s
} }
func (s *serviceClient) showUIElements() { func (s *serviceClient) showSettingsUI() {
// add settings window UI elements. // add settings window UI elements.
s.wSettings = s.app.NewWindow("NetBird Settings") s.wSettings = s.app.NewWindow("NetBird Settings")
s.iMngURL = widget.NewEntry() s.iMngURL = widget.NewEntry()
@@ -209,6 +223,18 @@ func (s *serviceClient) showUIElements() {
s.wSettings.Show() s.wSettings.Show()
} }
// showErrorMSG opens a fyne app window to display the supplied message
func showErrorMSG(msg string) {
app := app.New()
w := app.NewWindow("NetBird Error")
content := widget.NewLabel(msg)
content.Wrapping = fyne.TextWrapWord
w.SetContent(content)
w.Resize(fyne.NewSize(400, 100))
w.Show()
app.Run()
}
// getSettingsForm to embed it into settings window. // getSettingsForm to embed it into settings window.
func (s *serviceClient) getSettingsForm() *widget.Form { func (s *serviceClient) getSettingsForm() *widget.Form {
return &widget.Form{ return &widget.Form{
@@ -397,6 +423,7 @@ func (s *serviceClient) updateStatus() error {
s.mStatus.SetTitle("Connected") s.mStatus.SetTitle("Connected")
s.mUp.Disable() s.mUp.Disable()
s.mDown.Enable() s.mDown.Enable()
s.mRoutes.Enable()
systrayIconState = true systrayIconState = true
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() { } else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
s.connected = false s.connected = false
@@ -409,6 +436,7 @@ func (s *serviceClient) updateStatus() error {
s.mStatus.SetTitle("Disconnected") s.mStatus.SetTitle("Disconnected")
s.mDown.Disable() s.mDown.Disable()
s.mUp.Enable() s.mUp.Enable()
s.mRoutes.Disable()
systrayIconState = false systrayIconState = false
} }
@@ -464,9 +492,11 @@ func (s *serviceClient) onTrayReady() {
s.mUp = systray.AddMenuItem("Connect", "Connect") s.mUp = systray.AddMenuItem("Connect", "Connect")
s.mDown = systray.AddMenuItem("Disconnect", "Disconnect") s.mDown = systray.AddMenuItem("Disconnect", "Disconnect")
s.mDown.Disable() s.mDown.Disable()
s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Wiretrustee Admin Panel") s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Netbird Admin Panel")
systray.AddSeparator() systray.AddSeparator()
s.mSettings = systray.AddMenuItem("Settings", "Settings of the application") s.mSettings = systray.AddMenuItem("Settings", "Settings of the application")
s.mRoutes = systray.AddMenuItem("Network Routes", "Open the routes management window")
s.mRoutes.Disable()
systray.AddSeparator() systray.AddSeparator()
s.mAbout = systray.AddMenuItem("About", "About") s.mAbout = systray.AddMenuItem("About", "About")
@@ -504,16 +534,22 @@ func (s *serviceClient) onTrayReady() {
case <-s.mAdminPanel.ClickedCh: case <-s.mAdminPanel.ClickedCh:
err = open.Run(s.adminURL) err = open.Run(s.adminURL)
case <-s.mUp.ClickedCh: case <-s.mUp.ClickedCh:
s.mUp.Disabled()
go func() { go func() {
defer s.mUp.Enable()
err := s.menuUpClick() err := s.menuUpClick()
if err != nil { if err != nil {
s.runSelfCommand("error-msg", err.Error())
return return
} }
}() }()
case <-s.mDown.ClickedCh: case <-s.mDown.ClickedCh:
s.mDown.Disable()
go func() { go func() {
defer s.mDown.Enable()
err := s.menuDownClick() err := s.menuDownClick()
if err != nil { if err != nil {
s.runSelfCommand("error-msg", err.Error())
return return
} }
}() }()
@@ -521,24 +557,8 @@ func (s *serviceClient) onTrayReady() {
s.mSettings.Disable() s.mSettings.Disable()
go func() { go func() {
defer s.mSettings.Enable() defer s.mSettings.Enable()
proc, err := os.Executable() defer s.getSrvConfig()
if err != nil { s.runSelfCommand("settings", "true")
log.Errorf("show settings: %v", err)
return
}
cmd := exec.Command(proc, "--settings=true")
out, err := cmd.CombinedOutput()
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
log.Errorf("start settings UI: %v, %s", err, string(out))
return
}
if len(out) != 0 {
log.Info("settings change:", string(out))
}
// update config in systray when settings windows closed
s.getSrvConfig()
}() }()
case <-s.mQuit.ClickedCh: case <-s.mQuit.ClickedCh:
systray.Quit() systray.Quit()
@@ -548,6 +568,12 @@ func (s *serviceClient) onTrayReady() {
if err != nil { if err != nil {
log.Errorf("%s", err) log.Errorf("%s", err)
} }
case <-s.mRoutes.ClickedCh:
s.mRoutes.Disable()
go func() {
defer s.mRoutes.Enable()
s.runSelfCommand("routes", "true")
}()
} }
if err != nil { if err != nil {
log.Errorf("process connection: %v", err) log.Errorf("process connection: %v", err)
@@ -556,6 +582,24 @@ func (s *serviceClient) onTrayReady() {
}() }()
} }
func (s *serviceClient) runSelfCommand(command, arg string) {
proc, err := os.Executable()
if err != nil {
log.Errorf("show %s failed with error: %v", command, err)
return
}
cmd := exec.Command(proc, fmt.Sprintf("--%s=%s", command, arg))
out, err := cmd.CombinedOutput()
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
log.Errorf("start %s UI: %v, %s", command, err, string(out))
return
}
if len(out) != 0 {
log.Infof("command %s executed: %s", command, string(out))
}
}
func normalizedVersion(version string) string { func normalizedVersion(version string) string {
versionString := version versionString := version
if unicode.IsDigit(rune(versionString[0])) { if unicode.IsDigit(rune(versionString[0])) {

203
client/ui/route.go Normal file
View File

@@ -0,0 +1,203 @@
//go:build !(linux && 386)
package main
import (
"fmt"
"strings"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/layout"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
)
func (s *serviceClient) showRoutesUI() {
s.wRoutes = s.app.NewWindow("NetBird Routes")
grid := container.New(layout.NewGridLayout(2))
go s.updateRoutes(grid)
routeCheckContainer := container.NewVBox()
routeCheckContainer.Add(grid)
scrollContainer := container.NewVScroll(routeCheckContainer)
scrollContainer.SetMinSize(fyne.NewSize(200, 300))
buttonBox := container.NewHBox(
layout.NewSpacer(),
widget.NewButton("Refresh", func() {
s.updateRoutes(grid)
}),
widget.NewButton("Select all", func() {
s.selectAllRoutes()
s.updateRoutes(grid)
}),
widget.NewButton("Deselect All", func() {
s.deselectAllRoutes()
s.updateRoutes(grid)
}),
layout.NewSpacer(),
)
content := container.NewBorder(nil, buttonBox, nil, nil, scrollContainer)
s.wRoutes.SetContent(content)
s.wRoutes.Show()
s.startAutoRefresh(5*time.Second, grid)
}
func (s *serviceClient) updateRoutes(grid *fyne.Container) {
routes, err := s.fetchRoutes()
if err != nil {
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
return
}
grid.Objects = nil
idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
networkHeader := widget.NewLabelWithStyle("Network", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
grid.Add(idHeader)
grid.Add(networkHeader)
for _, route := range routes {
r := route
checkBox := widget.NewCheck(r.ID, func(checked bool) {
s.selectRoute(r.ID, checked)
})
checkBox.Checked = route.Selected
checkBox.Resize(fyne.NewSize(20, 20))
checkBox.Refresh()
grid.Add(checkBox)
grid.Add(widget.NewLabel(r.Network))
}
s.wRoutes.Content().Refresh()
}
func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
return nil, fmt.Errorf("get client: %v", err)
}
resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{})
if err != nil {
return nil, fmt.Errorf("failed to list routes: %v", err)
}
return resp.Routes, nil
}
func (s *serviceClient) selectRoute(id string, checked bool) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
return
}
req := &proto.SelectRoutesRequest{
RouteIDs: []string{id},
Append: checked,
}
if checked {
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to select route: %v", err)
s.showError(fmt.Errorf("failed to select route: %v", err))
return
}
log.Infof("Route %s selected", id)
} else {
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to deselect route: %v", err)
s.showError(fmt.Errorf("failed to deselect route: %v", err))
return
}
log.Infof("Route %s deselected", id)
}
}
func (s *serviceClient) selectAllRoutes() {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
return
}
req := &proto.SelectRoutesRequest{
All: true,
}
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to select all routes: %v", err)
s.showError(fmt.Errorf("failed to select all routes: %v", err))
return
}
log.Debug("All routes selected")
}
func (s *serviceClient) deselectAllRoutes() {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
return
}
req := &proto.SelectRoutesRequest{
All: true,
}
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to deselect all routes: %v", err)
s.showError(fmt.Errorf("failed to deselect all routes: %v", err))
return
}
log.Debug("All routes deselected")
}
func (s *serviceClient) showError(err error) {
wrappedMessage := wrapText(err.Error(), 50)
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes)
}
func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Container) {
ticker := time.NewTicker(interval)
go func() {
for range ticker.C {
s.updateRoutes(grid)
}
}()
s.wRoutes.SetOnClosed(func() {
ticker.Stop()
})
}
// wrapText inserts newlines into the text to ensure that each line is
// no longer than 'lineLength' runes.
func wrapText(text string, lineLength int) string {
var sb strings.Builder
var currentLineLength int
for _, runeValue := range text {
sb.WriteRune(runeValue)
currentLineLength++
if currentLineLength >= lineLength || runeValue == '\n' {
sb.WriteRune('\n')
currentLineLength = 0
}
}
return sb.String()
}

8
go.mod
View File

@@ -21,8 +21,8 @@ require (
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
golang.org/x/crypto v0.18.0 golang.org/x/crypto v0.21.0
golang.org/x/sys v0.16.0 golang.org/x/sys v0.18.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
@@ -82,10 +82,10 @@ require (
goauthentik.io/api/v3 v3.2023051.3 goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
golang.org/x/net v0.20.0 golang.org/x/net v0.23.0
golang.org/x/oauth2 v0.8.0 golang.org/x/oauth2 v0.8.0
golang.org/x/sync v0.3.0 golang.org/x/sync v0.3.0
golang.org/x/term v0.16.0 golang.org/x/term v0.18.0
google.golang.org/api v0.126.0 google.golang.org/api v0.126.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/sqlite v1.5.3 gorm.io/driver/sqlite v1.5.3

12
go.sum
View File

@@ -581,8 +581,9 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220314234659-1baeb1ce4c0b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -669,8 +670,9 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -761,16 +763,18 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@@ -11,6 +11,7 @@ import (
"net/netip" "net/netip"
"reflect" "reflect"
"regexp" "regexp"
"runtime/debug"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -46,6 +47,8 @@ const (
DefaultPeerLoginExpiration = 24 * time.Hour DefaultPeerLoginExpiration = 24 * time.Hour
) )
type userLoggedInOnce bool
type ExternalCacheManager cache.CacheInterface[*idp.UserData] type ExternalCacheManager cache.CacheInterface[*idp.UserData]
func cacheEntryExpiration() time.Duration { func cacheEntryExpiration() time.Duration {
@@ -74,7 +77,7 @@ type AccountManager interface {
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(accountID string) ([]*User, error) ListUsers(accountID string) ([]*User, error)
GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *Account) error
DeletePeer(accountID, peerID, userID string) error DeletePeer(accountID, peerID, userID string) error
UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
GetNetworkMap(peerID string) (*NetworkMap, error) GetNetworkMap(peerID string) (*NetworkMap, error)
@@ -115,8 +118,8 @@ type AccountManager interface {
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error) GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager GetExternalCacheManager() ExternalCacheManager
@@ -128,6 +131,8 @@ type AccountManager interface {
UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error
GroupValidation(accountId string, groups []string) (bool, error) GroupValidation(accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error) GetValidatedPeers(account *Account) (map[string]struct{}, error)
SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
CancelPeerRoutines(peer *nbpeer.Peer) error
} }
type DefaultAccountManager struct { type DefaultAccountManager struct {
@@ -384,6 +389,8 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise // GetPeerNetworkMap returns a group by ID if exists, nil otherwise
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
log.Debugf("GetNetworkMap with trace: %s", string(debug.Stack()))
peer := a.Peers[peerID] peer := a.Peers[peerID]
if peer == nil { if peer == nil {
return &NetworkMap{ return &NetworkMap{
@@ -956,7 +963,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
} }
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -1007,7 +1014,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) { func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) { return func() (time.Duration, bool) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -1092,19 +1099,21 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
} }
delete(userData, idp.UnsetAccountID) delete(userData, idp.UnsetAccountID)
rcvdUsers := 0
for accountID, users := range userData { for accountID, users := range userData {
rcvdUsers += len(users)
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration())) err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
if err != nil { if err != nil {
return err return err
} }
} }
log.Infof("warmed up IDP cache with %d entries", len(userData)) log.Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData))
return nil return nil
} }
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error { func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -1263,7 +1272,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) { func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
users := make(map[string]struct{}, len(account.Users)) users := make(map[string]userLoggedInOnce, len(account.Users))
// ignore service users and users provisioned by integrations than are never logged in // ignore service users and users provisioned by integrations than are never logged in
for _, user := range account.Users { for _, user := range account.Users {
if user.IsServiceUser { if user.IsServiceUser {
@@ -1272,7 +1281,7 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou
if user.Issued == UserIssuedIntegration { if user.Issued == UserIssuedIntegration {
continue continue
} }
users[user.Id] = struct{}{} users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
} }
log.Debugf("looking up user %s of account %s in cache", userID, account.Id) log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(users, account.Id) userData, err := am.lookupCache(users, account.Id)
@@ -1345,22 +1354,57 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo
} }
} }
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) { func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) {
data, err := am.getAccountFromCache(accountID, false) var data []*idp.UserData
var err error
maxAttempts := 2
data, err = am.getAccountFromCache(accountID, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userDataMap := make(map[string]struct{}) for attempt := 1; attempt <= maxAttempts; attempt++ {
if am.isCacheFresh(accountUsers, data) {
return data, nil
}
if attempt > 1 {
time.Sleep(200 * time.Millisecond)
}
log.Infof("refreshing cache for account %s", accountID)
data, err = am.refreshCache(accountID)
if err != nil {
return nil, err
}
if attempt == maxAttempts {
log.Warnf("cache for account %s reached maximum refresh attempts (%d)", accountID, maxAttempts)
}
}
return data, nil
}
// isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count and user invite status
func (am *DefaultAccountManager) isCacheFresh(accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool {
userDataMap := make(map[string]*idp.UserData, len(data))
for _, datum := range data { for _, datum := range data {
userDataMap[datum.ID] = struct{}{} userDataMap[datum.ID] = datum
} }
// the accountUsers ID list of non integration users from store, we check if cache has all of them // the accountUsers ID list of non integration users from store, we check if cache has all of them
// as result of for loop knownUsersCount will have number of users are not presented in the cashed // as result of for loop knownUsersCount will have number of users are not presented in the cashed
knownUsersCount := len(accountUsers) knownUsersCount := len(accountUsers)
for user := range accountUsers { for user, loggedInOnce := range accountUsers {
if _, ok := userDataMap[user]; ok { if datum, ok := userDataMap[user]; ok {
// check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed
if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple
log.Infof("user %s has a pending invite and has logged in once, cache invalid", user)
return false
}
knownUsersCount-- knownUsersCount--
continue continue
} }
@@ -1369,15 +1413,11 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
// if we know users that are not yet in cache more likely cache is outdated // if we know users that are not yet in cache more likely cache is outdated
if knownUsersCount > 0 { if knownUsersCount > 0 {
log.Debugf("cache doesn't know about %d users from store, reloading", knownUsersCount) log.Infof("cache invalid. Users unknown to the cache: %d", knownUsersCount)
// reload cache once avoiding loops return false
data, err = am.refreshCache(accountID)
if err != nil {
return nil, err
}
} }
return data, err return true
} }
func (am *DefaultAccountManager) removeUserFromCache(accountID, userID string) error { func (am *DefaultAccountManager) removeUserFromCache(accountID, userID string) error {
@@ -1425,29 +1465,14 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account,
} }
// handleExistingUserAccount handles existing User accounts and update its domain attributes. // handleExistingUserAccount handles existing User accounts and update its domain attributes.
//
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
func (am *DefaultAccountManager) handleExistingUserAccount( func (am *DefaultAccountManager) handleExistingUserAccount(
existingAcc *Account, existingAcc *Account,
domainAcc *Account, primaryDomain bool,
claims jwtclaims.AuthorizationClaims, claims jwtclaims.AuthorizationClaims,
) error { ) error {
var err error err := am.updateAccountDomainAttributes(existingAcc, claims, primaryDomain)
if err != nil {
if domainAcc != nil && existingAcc.Id != domainAcc.Id { return err
err = am.updateAccountDomainAttributes(existingAcc, claims, false)
if err != nil {
return err
}
} else {
err = am.updateAccountDomainAttributes(existingAcc, claims, true)
if err != nil {
return err
}
} }
// we should register the account ID to this user's metadata in our IDP manager // we should register the account ID to this user's metadata in our IDP manager
@@ -1547,7 +1572,7 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
return err return err
} }
unlock := am.Store.AcquireAccountLock(account.Id) unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock() defer unlock()
account, err = am.Store.GetAccountByUser(user.Id) account, err = am.Store.GetAccountByUser(user.Id)
@@ -1630,7 +1655,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
unlock := am.Store.AcquireAccountLock(newAcc.Id) unlock := am.Store.AcquireAccountWriteLock(newAcc.Id)
alreadyUnlocked := false alreadyUnlocked := false
defer func() { defer func() {
if !alreadyUnlocked { if !alreadyUnlocked {
@@ -1649,7 +1674,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
} }
if !user.IsServiceUser { if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(account, claims.UserId) err = am.redeemInvite(account, claims.UserId)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -1781,12 +1806,33 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
account, err := am.Store.GetAccountByUser(claims.UserId) account, err := am.Store.GetAccountByUser(claims.UserId)
if err == nil { if err == nil {
err = am.handleExistingUserAccount(account, domainAccount, claims) unlockAccount := am.Store.AcquireAccountWriteLock(account.Id)
defer unlockAccount()
account, err = am.Store.GetAccountByUser(claims.UserId)
if err != nil {
return nil, err
}
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
primaryDomain := domainAccount == nil || account.Id == domainAccount.Id
err = am.handleExistingUserAccount(account, primaryDomain, claims)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return account, nil return account, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
if domainAccount != nil {
unlockAccount := am.Store.AcquireAccountWriteLock(domainAccount.Id)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(claims.Domain)
if err != nil {
return nil, err
}
}
return am.handleNewUserAccount(domainAccount, claims) return am.handleNewUserAccount(domainAccount, claims)
} else { } else {
// other error // other error
@@ -1794,6 +1840,56 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
} }
} }
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
if err != nil {
return nil, nil, err
}
unlock := am.Store.AcquireAccountReadLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, nil, err
}
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account)
if err != nil {
return nil, nil, mapError(err)
}
err = am.MarkPeerConnected(peerPubKey, true, realIP, account)
if err != nil {
log.Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
return peer, netMap, nil
}
func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key)
if err != nil {
return err
}
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
err = am.MarkPeerConnected(peer.Key, false, nil, account)
if err != nil {
log.Warnf("failed marking peer as connected %s %v", peer.Key, err)
}
return nil
}
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers() // GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
return am.peersUpdateManager.GetAllConnectedPeers(), nil return am.peersUpdateManager.GetAllConnectedPeers(), nil

View File

@@ -1655,7 +1655,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "") _, err = manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@@ -1666,7 +1666,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
@@ -1732,8 +1735,10 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}, },
} }
account, err = manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger // when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil) err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second) failed := waitTimeout(wg, time.Second)
@@ -1745,7 +1750,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "") _, err = manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@@ -1756,7 +1761,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}

View File

@@ -35,7 +35,7 @@ func (d DNSSettings) Copy() DNSSettings {
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) { func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -57,7 +57,7 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string)
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error { func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)

View File

@@ -12,7 +12,7 @@ import (
// GetEvents returns a list of activity events of an account // GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) { func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)

View File

@@ -279,8 +279,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) {
return unlock return unlock
} }
// AcquireAccountLock acquires account lock and returns a function that releases the lock // AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) { func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID) log.Debugf("acquiring lock for account %s", accountID)
start := time.Now() start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
@@ -295,6 +295,12 @@ func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
return unlock return unlock
} }
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
// This method is still returns a write lock as file store can't handle read locks
func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) {
return s.AcquireAccountWriteLock(accountID)
}
func (s *FileStore) SaveAccount(account *Account) error { func (s *FileStore) SaveAccount(account *Account) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -572,6 +578,18 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
return account.Copy(), nil return account.Copy(), nil
} }
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.PeerKeyID2AccountID[peerKey]
if !ok {
return "", status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
}
return accountID, nil
}
// GetInstallationID returns the installation ID from the store // GetInstallationID returns the installation ID from the store
func (s *FileStore) GetInstallationID() string { func (s *FileStore) GetInstallationID() string {
return s.InstallationID return s.InstallationID

View File

@@ -22,7 +22,7 @@ func (e *GroupLinkError) Error() string {
// GetGroup object of the peers // GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) { func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -76,7 +76,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -109,7 +109,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n
// SaveGroup object of the peers // SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -214,7 +214,7 @@ func difference(a, b []string) []string {
// DeleteGroup object of the peers // DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountLock(accountId) unlock := am.Store.AcquireAccountWriteLock(accountId)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountId) account, err := am.Store.GetAccount(accountId)
@@ -323,7 +323,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
// ListGroups objects of the peers // ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) { func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -341,7 +341,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group,
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
// GroupDeletePeer removes peer from the group // GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)

View File

@@ -134,9 +134,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
return err return err
} }
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()}) peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP)
if err != nil { if err != nil {
return mapError(err) return err
} }
err = s.sendInitialSync(peerKey, peer, netMap, srv) err = s.sendInitialSync(peerKey, peer, netMap, srv)
@@ -149,11 +149,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.ephemeralManager.OnPeerConnected(peer) s.ephemeralManager.OnPeerConnected(peer)
err = s.accountManager.MarkPeerConnected(peerKey.String(), true, realIP)
if err != nil {
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
}
if s.config.TURNConfig.TimeBasedCredentials { if s.config.TURNConfig.TimeBasedCredentials {
s.turnCredentialsManager.SetupRefresh(peer.ID) s.turnCredentialsManager.SetupRefresh(peer.ID)
} }
@@ -207,7 +202,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) { func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(peer.ID) s.peersUpdateManager.CloseChannel(peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID) s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.MarkPeerConnected(peer.Key, false, nil) _ = s.accountManager.CancelPeerRoutines(peer)
s.ephemeralManager.OnPeerDisconnected(peer) s.ephemeralManager.OnPeerDisconnected(peer)
} }

View File

@@ -181,8 +181,8 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
} }
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) { func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
err := json.NewEncoder(w).Encode(toResponseBody(key)) err := json.NewEncoder(w).Encode(toResponseBody(key))
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)

View File

@@ -198,8 +198,6 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
serviceUser := r.URL.Query().Get("service_user") serviceUser := r.URL.Query().Get("service_user")
log.Debugf("UserCount: %v", len(data))
users := make([]*api.User, 0) users := make([]*api.User, 0)
for _, r := range data { for _, r := range data {
if r.NonDeletable { if r.NonDeletable {

View File

@@ -20,8 +20,8 @@ type ErrorResponse struct {
// WriteJSONObject simply writes object to the HTTP response in JSON format // WriteJSONObject simply writes object to the HTTP response in JSON format
func WriteJSONObject(w http.ResponseWriter, obj interface{}) { func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(obj) err := json.NewEncoder(w).Encode(obj)
if err != nil { if err != nil {
WriteError(err, w) WriteError(err, w)
@@ -63,8 +63,8 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
// WriteErrorResponse prepares and writes an error response i nJSON // WriteErrorResponse prepares and writes an error response i nJSON
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) { func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
w.WriteHeader(httpStatus)
w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.WriteHeader(httpStatus)
err := json.NewEncoder(w).Encode(&ErrorResponse{ err := json.NewEncoder(w).Encode(&ErrorResponse{
Message: errMsg, Message: errMsg,
Code: httpStatus, Code: httpStatus,

View File

@@ -31,7 +31,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin
return errors.New("invalid groups") return errors.New("invalid groups")
} }
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
a, err := am.Store.GetAccountByUser(userID) a, err := am.Store.GetAccountByUser(userID)

View File

@@ -13,6 +13,7 @@ type AuthorizationClaims struct {
Domain string Domain string
DomainCategory string DomainCategory string
LastLogin time.Time LastLogin time.Time
Invited bool
Raw jwt.MapClaims Raw jwt.MapClaims
} }

View File

@@ -20,6 +20,8 @@ const (
UserIDClaim = "sub" UserIDClaim = "sub"
// LastLoginSuffix claim for the last login // LastLoginSuffix claim for the last login
LastLoginSuffix = "nb_last_login" LastLoginSuffix = "nb_last_login"
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
Invited = "nb_invited"
) )
// ExtractClaims Extract function type // ExtractClaims Extract function type
@@ -100,6 +102,10 @@ func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
if ok { if ok {
jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string)) jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string))
} }
invitedBool, ok := claims[c.authAudience+Invited]
if ok {
jwtClaims.Invited = invitedBool.(bool)
}
return jwtClaims return jwtClaims
} }

View File

@@ -30,6 +30,10 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience st
if claims.LastLogin != (time.Time{}) { if claims.LastLogin != (time.Time{}) {
claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout) claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout)
} }
if claims.Invited {
claimMaps[audience+Invited] = true
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil) r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
require.NoError(t, err, "creating testing request failed") require.NoError(t, err, "creating testing request failed")
@@ -59,12 +63,14 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
AccountId: "testAcc", AccountId: "testAcc",
LastLogin: lastLogin, LastLogin: lastLogin,
DomainCategory: "public", DomainCategory: "public",
Invited: true,
Raw: jwt.MapClaims{ Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com", "https://login/wt_account_domain": "test.com",
"https://login/wt_account_domain_category": "public", "https://login/wt_account_domain_category": "public",
"https://login/wt_account_id": "testAcc", "https://login/wt_account_id": "testAcc",
"https://login/nb_last_login": lastLogin.Format(layout), "https://login/nb_last_login": lastLogin.Format(layout),
"sub": "test", "sub": "test",
"https://login/" + Invited: true,
}, },
}, },
testingFunc: require.EqualValues, testingFunc: require.EqualValues,

View File

@@ -0,0 +1,204 @@
package migration
import (
"database/sql"
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"net"
"strings"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
)
// MigrateFieldFromGobToJSON migrates a column from Gob encoding to JSON encoding.
// T is the type of the model that contains the field to be migrated.
// S is the type of the field to be migrated.
func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) error {
oldColumnName := fieldName
newColumnName := fieldName + "_tmp"
var model T
if !db.Migrator().HasTable(&model) {
log.Debugf("Table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
var item string
if err := db.Model(model).Select(oldColumnName).First(&item).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Debugf("No records in table %s, no migration needed", tableName)
return nil
}
return fmt.Errorf("fetch first record: %w", err)
}
var js json.RawMessage
var syntaxError *json.SyntaxError
err = json.Unmarshal([]byte(item), &js)
if err == nil || !errors.As(err, &syntaxError) {
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
return nil
}
if err := db.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil {
return fmt.Errorf("add column %s: %w", newColumnName, err)
}
var rows []map[string]any
if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil {
return fmt.Errorf("find rows: %w", err)
}
for _, row := range rows {
var field S
str, ok := row[oldColumnName].(string)
if !ok {
return fmt.Errorf("type assertion failed")
}
reader := strings.NewReader(str)
if err := gob.NewDecoder(reader).Decode(&field); err != nil {
return fmt.Errorf("gob decode error: %w", err)
}
jsonValue, err := json.Marshal(field)
if err != nil {
return fmt.Errorf("re-encode to JSON: %w", err)
}
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil {
return fmt.Errorf("update row: %w", err)
}
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil {
return fmt.Errorf("drop column %s: %w", oldColumnName, err)
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil {
return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err)
}
return nil
}); err != nil {
return err
}
log.Infof("Migration of %s.%s from gob to json completed", tableName, fieldName)
return nil
}
// MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding.
// T is the type of the model that contains the field to be migrated.
func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, indexName string) error {
oldColumnName := fieldName
newColumnName := fieldName + "_tmp"
var model T
if !db.Migrator().HasTable(&model) {
log.Printf("Table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
var item sql.NullString
if err := db.Model(&model).Select(oldColumnName).First(&item).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Printf("No records in table %s, no migration needed", tableName)
return nil
}
return fmt.Errorf("fetch first record: %w", err)
}
if item.Valid {
var js json.RawMessage
var syntaxError *json.SyntaxError
err = json.Unmarshal([]byte(item.String), &js)
if err == nil || !errors.As(err, &syntaxError) {
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
return nil
}
}
if err := db.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil {
return fmt.Errorf("add column %s: %w", newColumnName, err)
}
var rows []map[string]any
if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Printf("No records in table %s, no migration needed", tableName)
return nil
}
return fmt.Errorf("find rows: %w", err)
}
for _, row := range rows {
var blobValue string
if columnValue := row[oldColumnName]; columnValue != nil {
value, ok := columnValue.(string)
if !ok {
return fmt.Errorf("type assertion failed")
}
blobValue = value
}
columnIpValue := net.IP(blobValue)
if net.ParseIP(columnIpValue.String()) == nil {
log.Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName)
columnIpValue = net.IPv6loopback
}
jsonValue, err := json.Marshal(columnIpValue)
if err != nil {
return fmt.Errorf("re-encode to JSON: %w", err)
}
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil {
return fmt.Errorf("update row: %w", err)
}
}
if indexName != "" {
if err := tx.Migrator().DropIndex(&model, indexName); err != nil {
return fmt.Errorf("drop index %s: %w", indexName, err)
}
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil {
return fmt.Errorf("drop column %s: %w", oldColumnName, err)
}
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil {
return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err)
}
return nil
}); err != nil {
return err
}
log.Printf("Migration of %s.%s from blob to json completed", tableName, fieldName)
return nil
}

View File

@@ -0,0 +1,161 @@
package migration_test
import (
"encoding/gob"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route"
)
func setupDatabase(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{
PrepareStmt: true,
})
require.NoError(t, err, "Failed to open database")
return db
}
func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
db := setupDatabase(t)
err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
require.NoError(t, err, "Migration should not fail for an empty database")
}
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &route.Route{})
require.NoError(t, err, "Failed to auto-migrate tables")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
require.NoError(t, err, "Failed to parse CIDR")
type network struct {
server.Network
Net net.IPNet `gorm:"serializer:gob"`
}
type account struct {
server.Account
Network *network `gorm:"embedded;embeddedPrefix:network_"`
}
err = db.Save(&account{Account: server.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error
require.NoError(t, err, "Failed to insert Gob data")
var gobStr string
err = db.Model(&server.Account{}).Select("network_net").First(&gobStr).Error
assert.NoError(t, err, "Failed to fetch Gob data")
err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet)
require.NoError(t, err, "Failed to decode Gob data")
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
require.NoError(t, err, "Migration should not fail with Gob data")
var jsonStr string
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be migrated")
}
func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &route.Route{})
require.NoError(t, err, "Failed to auto-migrate tables")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
require.NoError(t, err, "Failed to parse CIDR")
err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error
require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
require.NoError(t, err, "Migration should not fail with JSON data")
var jsonStr string
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged")
}
func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
db := setupDatabase(t)
err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
require.NoError(t, err, "Migration should not fail for an empty database")
}
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables")
type location struct {
nbpeer.Location
ConnectionIP net.IP
}
type peer struct {
nbpeer.Peer
Location location `gorm:"embedded;embeddedPrefix:location_"`
}
type account struct {
server.Account
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
}
err = db.Save(&account{
Account: server.Account{Id: "123"},
Peers: []peer{
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}},
).Error
require.NoError(t, err, "Failed to insert blob data")
var blobValue string
err = db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&blobValue).Error
assert.NoError(t, err, "Failed to fetch blob data")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
require.NoError(t, err, "Migration should not fail with net.IP blob data")
var jsonStr string
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be migrated")
}
func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&server.Account{
Id: "1234",
PeersG: []nbpeer.Peer{
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}},
).Error
require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
require.NoError(t, err, "Migration should not fail with net.IP JSON data")
var jsonStr string
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged")
}

View File

@@ -22,80 +22,93 @@ type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(accountID string) ([]*server.User, error) ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error) GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
DeletePeerFunc func(accountID, peerKey, userID string) error SyncAndMarkPeerFunc func(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error)
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) DeletePeerFunc func(accountID, peerKey, userID string) error
GetPeerNetworkFunc func(peerKey string) (*server.Network, error) GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
SaveGroupFunc func(accountID, userID string, group *group.Group) error GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
DeleteGroupFunc func(accountID, userId, groupID string) error SaveGroupFunc func(accountID, userID string, group *group.Group) error
ListGroupsFunc func(accountID string) ([]*group.Group, error) DeleteGroupFunc func(accountID, userId, groupID string) error
GroupAddPeerFunc func(accountID, groupID, peerID string) error ListGroupsFunc func(accountID string) ([]*group.Group, error)
GroupDeletePeerFunc func(accountID, groupID, peerID string) error GroupAddPeerFunc func(accountID, groupID, peerID string) error
DeleteRuleFunc func(accountID, ruleID, userID string) error GroupDeletePeerFunc func(accountID, groupID, peerID string) error
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error) DeleteRuleFunc func(accountID, ruleID, userID string) error
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
DeletePolicyFunc func(accountID, policyID, userID string) error SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) DeletePolicyFunc func(accountID, policyID, userID string) error
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
MarkPATUsedFunc func(pat string) error GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error MarkPATUsedFunc func(pat string) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
SaveRouteFunc func(accountID, userID string, route *route.Route) error GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
DeleteRouteFunc func(accountID, routeID, userID string) error SaveRouteFunc func(accountID, userID string, route *route.Route) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) DeleteRouteFunc func(accountID, routeID, userID string) error
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error) DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
DeleteAccountFunc func(accountID, userID string) error CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
GetDNSDomainFunc func() string DeleteAccountFunc func(accountID, userID string) error
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetDNSDomainFunc func() string
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error) StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error) SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
GetAllConnectedPeersFunc func() (map[string]struct{}, error) InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
HasConnectedChannelFunc func(peerID string) bool GetAllConnectedPeersFunc func() (map[string]struct{}, error)
GetExternalCacheManagerFunc func() server.ExternalCacheManager HasConnectedChannelFunc func(peerID string) bool
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error) GetExternalCacheManagerFunc func() server.ExternalCacheManager
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
GetIdpManagerFunc func() idp.Manager ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
GroupValidationFunc func(accountId string, groups []string) (bool, error) GroupValidationFunc func(accountId string, groups []string) (bool, error)
} }
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(peerPubKey, realIP)
}
return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
// TODO implement me
panic("implement me")
}
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
approvedPeers := make(map[string]struct{}) approvedPeers := make(map[string]struct{})
for id := range account.Peers { for id := range account.Peers {
@@ -180,7 +193,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID(
} }
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error { func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *server.Account) error {
if am.MarkPeerConnectedFunc != nil { if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(peerKey, connected, realIP) return am.MarkPeerConnectedFunc(peerKey, connected, realIP)
} }
@@ -626,7 +639,7 @@ func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *
} }
// SyncPeer mocks SyncPeer of the AccountManager interface // SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) { func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) {
if am.SyncPeerFunc != nil { if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(sync) return am.SyncPeerFunc(sync)
} }

View File

@@ -19,7 +19,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -47,7 +47,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -94,7 +94,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
// SaveNameServerGroup saves nameserver group // SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
if nsGroupToSave == nil { if nsGroupToSave == nil {
@@ -129,7 +129,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
// DeleteNameServerGroup deletes nameserver group with nsGroupID // DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -159,7 +159,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use
// ListNameServerGroups returns a list of nameserver groups from account // ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) { func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)

View File

@@ -36,7 +36,7 @@ type NetworkMap struct {
type Network struct { type Network struct {
Identifier string `json:"id"` Identifier string `json:"id"`
Net net.IPNet `gorm:"serializer:gob"` Net net.IPNet `gorm:"serializer:json"`
Dns string Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps. // Used to synchronize state to the client apps.

View File

@@ -88,21 +88,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
} }
// MarkPeerConnected marks peer as connected (true) or disconnected (false) // MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP) error { func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error {
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return err
}
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return err
}
peer, err := account.FindPeerByPubKey(peerPubKey) peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil { if err != nil {
return err return err
@@ -156,7 +142,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -278,7 +264,7 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string,
// DeletePeer removes peer from the account by its IP // DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error { func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -362,7 +348,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
} }
unlock := am.Store.AcquireAccountLock(account.Id) unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock() defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
@@ -374,14 +360,14 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil { if am.idpManager != nil {
userdata, err := am.lookupUserInCache(userID, account) userdata, err := am.lookupUserInCache(userID, account)
if err == nil { if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
} }
} }
} }
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice. // This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountLock (e.g., database is slow) // Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again. // and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration. // We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry. // The connecting peer should be able to recover with a retry.
@@ -518,25 +504,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
} }
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) { func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) {
account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
}
return nil, nil, err
}
// we found the peer, and we follow a normal login flow
unlock := am.Store.AcquireAccountLock(account.Id)
defer unlock()
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
account, err = am.Store.GetAccount(account.Id)
if err != nil {
return nil, nil, err
}
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
if err != nil { if err != nil {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
@@ -603,7 +571,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
} }
// we found the peer, and we follow a normal login flow // we found the peer, and we follow a normal login flow
unlock := am.Store.AcquireAccountLock(account.Id) unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock() defer unlock()
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
@@ -760,7 +728,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
return err return err
} }
unlock := am.Store.AcquireAccountLock(account.Id) unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock() defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
@@ -795,7 +763,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
// GetPeer for a given accountID, peerID and userID error if not found. // GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -865,6 +833,10 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
return return
} }
for _, peer := range peers { for _, peer := range peers {
if !am.peersUpdateManager.HasChannel(peer.ID) {
log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap) remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain()) update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})

View File

@@ -13,13 +13,13 @@ type Peer struct {
// ID is an internal ID of the peer // ID is an internal ID of the peer
ID string `gorm:"primaryKey"` ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs // AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"` AccountID string `json:"-" gorm:"index"`
// WireGuard public key // WireGuard public key
Key string `gorm:"index"` Key string `gorm:"index"`
// A setup key this peer was registered with // A setup key this peer was registered with
SetupKey string SetupKey string
// IP address of the Peer // IP address of the Peer
IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"` IP net.IP `gorm:"serializer:json"`
// Meta is a Peer system meta data // Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
// Name is peer's name (machine name) // Name is peer's name (machine name)
@@ -61,7 +61,7 @@ type PeerStatus struct { //nolint:revive
// Location is a geo location information of a Peer based on public connection IP // Location is a geo location information of a Peer based on public connection IP
type Location struct { type Location struct {
ConnectionIP net.IP // from grpc peer or reverse proxy headers depends on setup ConnectionIP net.IP `gorm:"serializer:json"` // from grpc peer or reverse proxy headers depends on setup
CountryCode string CountryCode string
CityName string CityName string
GeoNameID uint // city level geoname id GeoNameID uint // city level geoname id

View File

@@ -314,7 +314,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
// GetPolicy from the store // GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) { func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -342,7 +342,7 @@ func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (
// SavePolicy in the store // SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error { func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -370,7 +370,7 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
// DeletePolicy from the store // DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error { func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -397,7 +397,7 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string
// ListPolicies from the store // ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) { func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)

View File

@@ -7,7 +7,7 @@ import (
) )
func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) { func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -34,7 +34,7 @@ func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, us
} }
func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error { func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -81,7 +81,7 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos
} }
func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error { func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -113,7 +113,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID,
} }
func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) { func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)

View File

@@ -14,7 +14,7 @@ import (
// GetRoute gets a route object from account and route IDs // GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -115,7 +115,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
// CreateRoute creates and saves a new route // CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -194,7 +194,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
// SaveRoute saves route // SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave *route.Route) error { func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave *route.Route) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
if routeToSave == nil { if routeToSave == nil {
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
// DeleteRoute deletes route with routeID // DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error { func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -283,7 +283,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
// ListRoutes returns a list of routes from account // ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) { func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)

View File

@@ -209,7 +209,7 @@ func Hash(s string) uint32 {
// and adds it to the specified account. A list of autoGroups IDs can be empty. // and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
keyDuration := DefaultSetupKeyDuration keyDuration := DefaultSetupKeyDuration
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
// (e.g. the key itself, creation date, ID, etc). // (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
if keyToSave == nil { if keyToSave == nil {
@@ -327,7 +327,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
// ListSetupKeys returns a list of all setup keys of the account // ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) { func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
@@ -359,7 +359,7 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) { func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)

View File

@@ -3,6 +3,8 @@ package server
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"net/netip"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
@@ -18,6 +20,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@@ -40,6 +43,8 @@ type installation struct {
InstallationIDValue string InstallationIDValue string
} }
type migrationFunc func(*gorm.DB) error
// NewSqliteStore restores a store from the file located in the datadir // NewSqliteStore restores a store from the file located in the datadir
func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) { func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
storeStr := "store.db?cache=shared" storeStr := "store.db?cache=shared"
@@ -50,8 +55,9 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore,
file := filepath.Join(dataDir, storeStr) file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{ db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent), Logger: logger.Default.LogMode(logger.Silent),
PrepareStmt: true, CreateBatchSize: 400,
PrepareStmt: true,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -64,13 +70,16 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore,
conns := runtime.NumCPU() conns := runtime.NumCPU()
sql.SetMaxOpenConns(conns) // TODO: make it configurable sql.SetMaxOpenConns(conns) // TODO: make it configurable
if err := migrate(db); err != nil {
return nil, fmt.Errorf("migrate: %w", err)
}
err = db.AutoMigrate( err = db.AutoMigrate(
&SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{}, &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{},
&Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
) )
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("auto migrate: %w", err)
} }
return &SqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil return &SqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil
@@ -118,17 +127,33 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
return unlock return unlock
} }
func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { func (s *SqliteStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
log.Tracef("acquiring lock for account %s", accountID) log.Tracef("acquiring write lock for account %s", accountID)
start := time.Now() start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
mtx := value.(*sync.Mutex) mtx := value.(*sync.RWMutex)
mtx.Lock() mtx.Lock()
unlock = func() { unlock = func() {
mtx.Unlock() mtx.Unlock()
log.Tracef("released lock for account %s in %v", accountID, time.Since(start)) log.Tracef("released write lock for account %s in %v", accountID, time.Since(start))
}
return unlock
}
func (s *SqliteStore) AcquireAccountReadLock(accountID string) (unlock func()) {
log.Tracef("acquiring read lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.RLock()
unlock = func() {
mtx.RUnlock()
log.Tracef("released read lock for account %s in %v", accountID, time.Since(start))
} }
return unlock return unlock
@@ -188,7 +213,8 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
result = tx. result = tx.
Session(&gorm.Session{FullSaveAssociations: true}). Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).Create(account) Clauses(clause.OnConflict{UpdateAll: true}).
Create(account)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
@@ -253,36 +279,43 @@ func (s *SqliteStore) GetInstallationID() string {
} }
func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peer nbpeer.Peer var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? AND id = ?", accountID, peerID).
Updates(peerCopy)
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { return result.Error
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting peer from store")
} }
peer.Status = &peerStatus if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
return s.db.Save(peer).Error return nil
} }
func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
var peer nbpeer.Peer // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerWithLocation.ID) var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization,
// updating the struct ensures the correct data format is inserted into the database.
peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? and id = ?", accountID, peerWithLocation.ID).
Updates(peerCopy)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { return result.Error
return status.Errorf(status.NotFound, "peer %s not found", peer.ID)
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting peer from store")
} }
peer.Location = peerWithLocation.Location if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID)
}
return s.db.Save(peer).Error return nil
} }
// DeleteHashedPAT2TokenIDIndex is noop in Sqlite // DeleteHashedPAT2TokenIDIndex is noop in Sqlite
@@ -390,8 +423,9 @@ func (s *SqliteStore) GetAllAccounts() (all []*Account) {
} }
func (s *SqliteStore) GetAccount(accountID string) (*Account, error) { func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
var account Account var account Account
result := s.db.Model(&account). result := s.db.Debug().Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload(clause.Associations). Preload(clause.Associations).
First(&account, "id = ?", accountID) First(&account, "id = ?", accountID)
@@ -511,6 +545,21 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
return s.GetAccount(peer.AccountID) return s.GetAccount(peer.AccountID)
} }
func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
var peer nbpeer.Peer
var accountID string
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
return accountID, nil
}
// SaveUserLastLogin stores the last login time for a user in DB. // SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
var user User var user User
@@ -542,3 +591,36 @@ func (s *SqliteStore) Close() error {
func (s *SqliteStore) GetStoreEngine() StoreEngine { func (s *SqliteStore) GetStoreEngine() StoreEngine {
return SqliteStoreEngine return SqliteStoreEngine
} }
// migrate migrates the SQLite database to the latest schema
func migrate(db *gorm.DB) error {
migrations := getMigrations()
for _, m := range migrations {
if err := m(db); err != nil {
return err
}
}
return nil
}
func getMigrations() []migrationFunc {
return []migrationFunc{
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net")
},
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network")
},
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups")
},
func(db *gorm.DB) error {
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
},
func(db *gorm.DB) error {
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
},
}
}

View File

@@ -2,16 +2,23 @@ package server
import ( import (
"fmt" "fmt"
"math/rand"
"net" "net"
"net/netip"
"path/filepath" "path/filepath"
"runtime" "runtime"
"testing" "testing"
"time" "time"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -30,6 +37,151 @@ func TestSqlite_NewStore(t *testing.T) {
} }
} }
func TestSqlite_SaveAccount_Large(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
account := newAccountWithId("account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
if err != nil {
t.Fatal(err)
}
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 2000
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
peer := &nbpeer.Peer{
ID: peerID,
Key: peerID,
SetupKey: "",
IP: netIP,
Name: peerID,
DNSLabel: peerID,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
account.Users[user.Id] = user
route := &route2.Route{
ID: fmt.Sprintf("network-id-%d", n),
Description: "base route",
NetID: fmt.Sprintf("network-id-%d", n),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
Groups: []string{groupALL.ID},
}
account.Routes[route.ID] = route
group = &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("group-id-%d", n),
Issued: "api",
Peers: nil,
}
account.Groups[group.ID] = group
nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
Groups: []string{group.ID},
Primary: false,
Domains: nil,
Enabled: false,
SearchDomainsEnabled: false,
}
account.NameServerGroups[nameserver.ID] = nameserver
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
}
err = store.SaveAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a != nil && len(a.Peers) != numPerAccount {
t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d",
numPerAccount, len(a.Peers))
return
}
if a != nil && len(a.Users) != numPerAccount+1 {
t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d",
numPerAccount+1, len(a.Users))
return
}
if a != nil && len(a.Routes) != numPerAccount {
t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d",
numPerAccount, len(a.Routes))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.SetupKeys) != numPerAccount+1 {
t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d",
numPerAccount+1, len(a.SetupKeys))
return
}
}
func randomIPv4() net.IP {
rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, 4)
for i := range b {
b[i] = byte(rand.Intn(256))
}
return net.IP(b)
}
func TestSqlite_SaveAccount(t *testing.T) { func TestSqlite_SaveAccount(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet") t.Skip("The SQLite store is not properly supported by Windows yet")
@@ -349,6 +501,74 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
} }
func TestMigrate(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
err := migrate(store.db)
require.NoError(t, err, "Migration should not fail on empty db")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
require.NoError(t, err, "Failed to parse CIDR")
type network struct {
Network
Net net.IPNet `gorm:"serializer:gob"`
}
type location struct {
nbpeer.Location
ConnectionIP net.IP
}
type peer struct {
nbpeer.Peer
Location location `gorm:"embedded;embeddedPrefix:location_"`
}
type account struct {
Account
Network *network `gorm:"embedded;embeddedPrefix:network_"`
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
}
act := &account{
Network: &network{
Net: *ipnet,
},
Peers: []peer{
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
},
}
err = store.db.Save(act).Error
require.NoError(t, err, "Failed to insert Gob data")
type route struct {
route2.Route
Network netip.Prefix `gorm:"serializer:gob"`
PeerGroups []string `gorm:"serializer:gob"`
}
prefix := netip.MustParsePrefix("11.0.0.0/24")
rt := &route{
Network: prefix,
PeerGroups: []string{"group1", "group2"},
}
err = store.db.Save(rt).Error
require.NoError(t, err, "Failed to insert Gob data")
err = migrate(store.db)
require.NoError(t, err, "Migration should not fail on gob populated db")
err = migrate(store.db)
require.NoError(t, err, "Migration should not fail on migrated db")
}
func newSqliteStore(t *testing.T) *SqliteStore { func newSqliteStore(t *testing.T) *SqliteStore {
t.Helper() t.Helper()

View File

@@ -19,6 +19,7 @@ type Store interface {
DeleteAccount(account *Account) error DeleteAccount(account *Account) error
GetAccountByUser(userID string) (*Account, error) GetAccountByUser(userID string) (*Account, error)
GetAccountByPeerPubKey(peerKey string) (*Account, error) GetAccountByPeerPubKey(peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(peerKey string) (string, error)
GetAccountByPeerID(peerID string) (*Account, error) GetAccountByPeerID(peerID string) (*Account, error)
GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(domain string) (*Account, error) GetAccountByPrivateDomain(domain string) (*Account, error)
@@ -29,8 +30,10 @@ type Store interface {
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
GetInstallationID() string GetInstallationID() string
SaveInstallationID(ID string) error SaveInstallationID(ID string) error
// AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock // AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock
AcquireAccountLock(accountID string) func() AcquireAccountWriteLock(accountID string) func()
// AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock
AcquireAccountReadLock(accountID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock() func() AcquireGlobalLock() func()
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error

View File

@@ -210,7 +210,7 @@ func NewOwnerUser(id string) *User {
// createServiceUser creates a new service user under the given account. // createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -266,7 +266,7 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, user *User
// inviteNewUser Invites a USer to a given account and creates reference in datastore // inviteNewUser Invites a USer to a given account and creates reference in datastore
func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) { func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
if am.idpManager == nil { if am.idpManager == nil {
@@ -367,7 +367,7 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
return nil, fmt.Errorf("failed to get account with token claims %v", err) return nil, fmt.Errorf("failed to get account with token claims %v", err)
} }
unlock := am.Store.AcquireAccountLock(account.Id) unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock() defer unlock()
account, err = am.Store.GetAccount(account.Id) account, err = am.Store.GetAccount(account.Id)
@@ -400,7 +400,7 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
// ListUsers returns lists of all users under the account. // ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name. // It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) { func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -427,7 +427,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t
if initiatorUserID == targetUserID { if initiatorUserID == targetUserID {
return status.Errorf(status.InvalidArgument, "self deletion is not allowed") return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
} }
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -537,7 +537,7 @@ func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetU
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error { func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
if am.idpManager == nil { if am.idpManager == nil {
@@ -577,7 +577,7 @@ func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID st
// CreatePAT creates a new PAT for the given user // CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
if tokenName == "" { if tokenName == "" {
@@ -627,7 +627,7 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID str
// DeletePAT deletes a specific PAT from a user // DeletePAT deletes a specific PAT from a user
func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error { func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -677,7 +677,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID str
// GetPAT returns a specific PAT from a user // GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -709,7 +709,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string
// GetAllPATs returns all PATs for a user // GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
@@ -747,7 +747,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountWriteLock(accountID)
defer unlock() defer unlock()
if update == nil { if update == nil {
@@ -960,7 +960,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
queriedUsers := make([]*idp.UserData, 0) queriedUsers := make([]*idp.UserData, 0)
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
users := make(map[string]struct{}, len(account.Users)) users := make(map[string]userLoggedInOnce, len(account.Users))
usersFromIntegration := make([]*idp.UserData, 0) usersFromIntegration := make([]*idp.UserData, 0)
for _, user := range account.Users { for _, user := range account.Users {
if user.Issued == UserIssuedIntegration { if user.Issued == UserIssuedIntegration {
@@ -968,14 +968,14 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
info, err := am.externalCacheManager.Get(am.ctx, key) info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil { if err != nil {
log.Infof("Get ExternalCache for key: %s, error: %s", key, err) log.Infof("Get ExternalCache for key: %s, error: %s", key, err)
users[user.Id] = struct{}{} users[user.Id] = true
continue continue
} }
usersFromIntegration = append(usersFromIntegration, info) usersFromIntegration = append(usersFromIntegration, info)
continue continue
} }
if !user.IsServiceUser { if !user.IsServiceUser {
users[user.Id] = struct{}{} users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
} }
} }
queriedUsers, err = am.lookupCache(users, accountID) queriedUsers, err = am.lookupCache(users, accountID)

View File

@@ -68,11 +68,11 @@ type Route struct {
ID string `gorm:"primaryKey"` ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs // AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
Network netip.Prefix `gorm:"serializer:gob"` Network netip.Prefix `gorm:"serializer:json"`
NetID string NetID string
Description string Description string
Peer string Peer string
PeerGroups []string `gorm:"serializer:gob"` PeerGroups []string `gorm:"serializer:json"`
NetworkType NetworkType NetworkType NetworkType
Masquerade bool Masquerade bool
Metric int Metric int
@@ -107,6 +107,12 @@ func (r *Route) Copy() *Route {
// IsEqual compares one route with the other // IsEqual compares one route with the other
func (r *Route) IsEqual(other *Route) bool { func (r *Route) IsEqual(other *Route) bool {
if r == nil && other == nil {
return true
} else if r == nil || other == nil {
return false
}
return other.ID == r.ID && return other.ID == r.ID &&
other.Description == r.Description && other.Description == r.Description &&
other.NetID == r.NetID && other.NetID == r.NetID &&