mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-08 09:49:54 +00:00
Compare commits
78 Commits
v0.30.3
...
add-static
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
75a02c3e38 | ||
|
|
052525093a | ||
|
|
b4b9aedf5a | ||
|
|
07f0f9fdbd | ||
|
|
c6641be94b | ||
|
|
89cf8a55e2 | ||
|
|
00c3b67182 | ||
|
|
9203690033 | ||
|
|
9683da54b0 | ||
|
|
0e48a772ff | ||
|
|
f118d81d32 | ||
|
|
ca12bc6953 | ||
|
|
9810386937 | ||
|
|
f1625b32bd | ||
|
|
0ecd5f2118 | ||
|
|
940d0c48c6 | ||
|
|
56cecf849e | ||
|
|
05c4aa7c2c | ||
|
|
2a5cb16494 | ||
|
|
9db1932664 | ||
|
|
1bbabf70b0 | ||
|
|
aa575d6f44 | ||
|
|
f66bbcc54c | ||
|
|
5dd6a08ea6 | ||
|
|
eb5d0569ae | ||
|
|
52ea2e84e9 | ||
|
|
78fab877c0 | ||
|
|
65a94f695f | ||
|
|
ec543f89fb | ||
|
|
a7d5c52203 | ||
|
|
582bb58714 | ||
|
|
121dfda915 | ||
|
|
a1c5287b7c | ||
|
|
12f442439a | ||
|
|
d9b691b8a5 | ||
|
|
4aee3c9e33 | ||
|
|
44e799c687 | ||
|
|
be78efbd42 | ||
|
|
6886691213 | ||
|
|
b48afd92fd | ||
|
|
39329e12a1 | ||
|
|
20a5afc359 | ||
|
|
6cb697eed6 | ||
|
|
e0bed2b0fb | ||
|
|
30f025e7dd | ||
|
|
b4d7605147 | ||
|
|
08b6e9d647 | ||
|
|
67ce14eaea | ||
|
|
669904cd06 | ||
|
|
4be826450b | ||
|
|
738387f2de | ||
|
|
baf0678ceb | ||
|
|
7fef8f6758 | ||
|
|
6829a64a2d | ||
|
|
cbf500024f | ||
|
|
509e184e10 | ||
|
|
3e88b7c56e | ||
|
|
b952d8693d | ||
|
|
5b46cc8e9c | ||
|
|
a9d06b883f | ||
|
|
5f06b202c3 | ||
|
|
0eb99c266a | ||
|
|
bac95ace18 | ||
|
|
9812de853b | ||
|
|
ad4f0a6fdf | ||
|
|
4c758c6e52 | ||
|
|
ec5095ba6b | ||
|
|
49a54624f8 | ||
|
|
729bcf2b01 | ||
|
|
a0cdb58303 | ||
|
|
39c99781cb | ||
|
|
01f24907c5 | ||
|
|
10480eb52f | ||
|
|
1e44c5b574 | ||
|
|
940f8b4547 | ||
|
|
46e37fa04c | ||
|
|
b9f205b2ce | ||
|
|
0fd874fa45 |
3
.github/FUNDING.yml
vendored
Normal file
3
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# These are supported funding model platforms
|
||||||
|
|
||||||
|
github: [netbirdio]
|
||||||
42
.github/workflows/golang-test-linux.yml
vendored
42
.github/workflows/golang-test-linux.yml
vendored
@@ -13,6 +13,7 @@ concurrency:
|
|||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
strategy:
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: [ '386','amd64' ]
|
arch: [ '386','amd64' ]
|
||||||
store: [ 'sqlite', 'postgres']
|
store: [ 'sqlite', 'postgres']
|
||||||
@@ -51,6 +52,47 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||||
|
|
||||||
|
benchmark:
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
store: [ 'sqlite', 'postgres' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/go/pkg/mod
|
||||||
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.16"
|
SIGN_PIPE_VER: "v0.0.17"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||||
|
|||||||
@@ -17,8 +17,12 @@
|
|||||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
|
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
|
||||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||||
|
</a>
|
||||||
|
<br>
|
||||||
|
<a href="https://gurubase.io/g/netbird">
|
||||||
|
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
@@ -30,7 +34,7 @@
|
|||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
<br/>
|
<br/>
|
||||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a>
|
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
|
||||||
<br/>
|
<br/>
|
||||||
|
|
||||||
</strong>
|
</strong>
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder, &system.StaticInfo{})
|
||||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,7 +132,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder, &system.StaticInfo{})
|
||||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -131,7 +131,7 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
|||||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
|
||||||
err := a.withBackOff(a.ctx, func() error {
|
err := a.withBackOff(a.ctx, func() error {
|
||||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "", &system.StaticInfo{})
|
||||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
// we got an answer from management, exit backoff earlier
|
// we got an answer from management, exit backoff earlier
|
||||||
return backoff.Permanent(backoffErr)
|
return backoff.Permanent(backoffErr)
|
||||||
@@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey, &system.StaticInfo{})
|
||||||
return
|
return
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -179,7 +179,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err = a.withBackOff(a.ctx, func() error {
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
err := internal.Login(a.ctx, a.config, "", jwtToken, &system.StaticInfo{})
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -201,6 +201,8 @@ func isWellKnown(addr netip.Addr) bool {
|
|||||||
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
|
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
|
||||||
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
|
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
|
||||||
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
|
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
|
||||||
|
|
||||||
|
"128.0.0.0", "8000::", // 2nd split subnet for default routes
|
||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(wellKnown, addr.String()) {
|
if slices.Contains(wellKnown, addr.String()) {
|
||||||
|
|||||||
@@ -137,9 +137,11 @@ var loginCmd = &cobra.Command{
|
|||||||
|
|
||||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.Config, setupKey string) error {
|
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.Config, setupKey string) error {
|
||||||
needsLogin := false
|
needsLogin := false
|
||||||
|
staticInfoChan := system.GetStaticInfoInBackground(ctx)
|
||||||
|
staticInfo := <-staticInfoChan
|
||||||
|
|
||||||
err := WithBackOff(func() error {
|
err := WithBackOff(func() error {
|
||||||
err := internal.Login(ctx, config, "", "")
|
err := internal.Login(ctx, config, "", "", staticInfo)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
needsLogin = true
|
needsLogin = true
|
||||||
return nil
|
return nil
|
||||||
@@ -162,7 +164,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
var lastError error
|
var lastError error
|
||||||
|
|
||||||
err = WithBackOff(func() error {
|
err = WithBackOff(func() error {
|
||||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
err := internal.Login(ctx, config, setupKey, jwtToken, staticInfo)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
lastError = err
|
lastError = err
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -13,10 +14,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type program struct {
|
type program struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
serv *grpc.Server
|
serv *grpc.Server
|
||||||
serverInstance *server.Server
|
serverInstance *server.Server
|
||||||
|
serverInstanceMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||||
|
|||||||
@@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
}
|
}
|
||||||
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
||||||
|
|
||||||
|
p.serverInstanceMu.Lock()
|
||||||
p.serverInstance = serverInstance
|
p.serverInstance = serverInstance
|
||||||
|
p.serverInstanceMu.Unlock()
|
||||||
|
|
||||||
log.Printf("started daemon server: %v", split[1])
|
log.Printf("started daemon server: %v", split[1])
|
||||||
if err := p.serv.Serve(listen); err != nil {
|
if err := p.serv.Serve(listen); err != nil {
|
||||||
@@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *program) Stop(srv service.Service) error {
|
func (p *program) Stop(srv service.Service) error {
|
||||||
|
p.serverInstanceMu.Lock()
|
||||||
if p.serverInstance != nil {
|
if p.serverInstance != nil {
|
||||||
in := new(proto.DownRequest)
|
in := new(proto.DownRequest)
|
||||||
_, err := p.serverInstance.Down(p.ctx, in)
|
_, err := p.serverInstance.Down(p.ctx, in)
|
||||||
@@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error {
|
|||||||
log.Errorf("failed to stop daemon: %v", err)
|
log.Errorf("failed to stop daemon: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
p.serverInstanceMu.Unlock()
|
||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
|
|||||||
@@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
|||||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||||
statusEval := false
|
statusEval := false
|
||||||
ipEval := false
|
ipEval := false
|
||||||
nameEval := false
|
nameEval := true
|
||||||
|
|
||||||
if statusFilter != "" {
|
if statusFilter != "" {
|
||||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||||
@@ -700,11 +700,13 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
|||||||
|
|
||||||
if len(prefixNamesFilter) > 0 {
|
if len(prefixNamesFilter) > 0 {
|
||||||
for prefixNameFilter := range prefixNamesFilterMap {
|
for prefixNameFilter := range prefixNamesFilterMap {
|
||||||
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||||
nameEval = true
|
nameEval = false
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
nameEval = false
|
||||||
}
|
}
|
||||||
|
|
||||||
return statusEval || ipEval || nameEval
|
return statusEval || ipEval || nameEval
|
||||||
|
|||||||
@@ -152,6 +152,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
staticInfoChan := system.GetStaticInfoInBackground(ctx)
|
||||||
|
|
||||||
config, err := internal.UpdateOrCreateConfig(ic)
|
config, err := internal.UpdateOrCreateConfig(ic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get config file: %v", err)
|
return fmt.Errorf("get config file: %v", err)
|
||||||
@@ -171,7 +173,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
r := peer.NewRecorder(config.ManagementURL.String())
|
r := peer.NewRecorder(config.ManagementURL.String())
|
||||||
r.GetFullStatus()
|
r.GetFullStatus()
|
||||||
|
|
||||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
connectClient := internal.NewConnectClient(ctx, config, r, <-staticInfoChan)
|
||||||
return connectClient.Run()
|
return connectClient.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
@@ -37,62 +38,55 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
|
|||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
// for the userspace packet filtering firewall
|
// for the userspace packet filtering firewall
|
||||||
fm, errFw := createNativeFirewall(iface)
|
fm, err := createNativeFirewall(iface, stateManager)
|
||||||
|
|
||||||
if fm != nil {
|
if !iface.IsUserspaceBind() {
|
||||||
if err := fm.Init(stateManager); err != nil {
|
return fm, err
|
||||||
log.Errorf("failed to init nftables manager: %s", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if iface.IsUserspaceBind() {
|
if err != nil {
|
||||||
return createUserspaceFirewall(iface, fm, errFw)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
}
|
}
|
||||||
|
return createUserspaceFirewall(iface, fm)
|
||||||
return fm, errFw
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||||
|
fm, err := createFW(iface)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create firewall: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = fm.Init(stateManager); err != nil {
|
||||||
|
return nil, fmt.Errorf("init firewall: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||||
switch check() {
|
switch check() {
|
||||||
case IPTABLES:
|
case IPTABLES:
|
||||||
return createIptablesFirewall(iface)
|
log.Info("creating an iptables firewall manager")
|
||||||
|
return nbiptables.Create(iface)
|
||||||
case NFTABLES:
|
case NFTABLES:
|
||||||
return createNftablesFirewall(iface)
|
log.Info("creating an nftables firewall manager")
|
||||||
|
return nbnftables.Create(iface)
|
||||||
default:
|
default:
|
||||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||||
return nil, fmt.Errorf("no firewall manager found")
|
return nil, errors.New("no firewall manager found")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createIptablesFirewall(iface IFaceMapper) (firewall.Manager, error) {
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
||||||
log.Info("creating an iptables firewall manager")
|
|
||||||
fm, err := nbiptables.Create(iface)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to create iptables manager: %s", err)
|
|
||||||
}
|
|
||||||
return fm, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func createNftablesFirewall(iface IFaceMapper) (firewall.Manager, error) {
|
|
||||||
log.Info("creating an nftables firewall manager")
|
|
||||||
fm, err := nbnftables.Create(iface)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to create nftables manager: %s", err)
|
|
||||||
}
|
|
||||||
return fm, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, errFw error) (firewall.Manager, error) {
|
|
||||||
var errUsp error
|
var errUsp error
|
||||||
if errFw == nil {
|
if fm != nil {
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||||
} else {
|
} else {
|
||||||
fm, errUsp = uspfilter.Create(iface)
|
fm, errUsp = uspfilter.Create(iface)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errUsp != nil {
|
if errUsp != nil {
|
||||||
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
|
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
|
||||||
return nil, errUsp
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fm.AllowNetbird(); err != nil {
|
if err := fm.AllowNetbird(); err != nil {
|
||||||
|
|||||||
@@ -352,14 +352,14 @@ func (m *aclManager) seedInitialEntries() {
|
|||||||
func (m *aclManager) seedInitialOptionalEntries() {
|
func (m *aclManager) seedInitialOptionalEntries() {
|
||||||
m.optionalEntries["FORWARD"] = []entry{
|
m.optionalEntries["FORWARD"] = []entry{
|
||||||
{
|
{
|
||||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
|
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
|
||||||
position: 2,
|
position: 2,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m.optionalEntries["PREROUTING"] = []entry{
|
m.optionalEntries["PREROUTING"] = []entry{
|
||||||
{
|
{
|
||||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
|
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
|
||||||
position: 1,
|
position: 1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// persist early to ensure cleanup of chains
|
// persist early to ensure cleanup of chains
|
||||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
go func() {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
}
|
log.Errorf("failed to persist state: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,22 +18,24 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
|
||||||
const (
|
|
||||||
ipv4Nat = "netbird-rt-nat"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// constants needed to manage and create iptable rules
|
// constants needed to manage and create iptable rules
|
||||||
const (
|
const (
|
||||||
tableFilter = "filter"
|
tableFilter = "filter"
|
||||||
tableNat = "nat"
|
tableNat = "nat"
|
||||||
|
tableMangle = "mangle"
|
||||||
chainPOSTROUTING = "POSTROUTING"
|
chainPOSTROUTING = "POSTROUTING"
|
||||||
|
chainPREROUTING = "PREROUTING"
|
||||||
chainRTNAT = "NETBIRD-RT-NAT"
|
chainRTNAT = "NETBIRD-RT-NAT"
|
||||||
chainRTFWD = "NETBIRD-RT-FWD"
|
chainRTFWD = "NETBIRD-RT-FWD"
|
||||||
|
chainRTPRE = "NETBIRD-RT-PRE"
|
||||||
routingFinalForwardJump = "ACCEPT"
|
routingFinalForwardJump = "ACCEPT"
|
||||||
routingFinalNatJump = "MASQUERADE"
|
routingFinalNatJump = "MASQUERADE"
|
||||||
|
|
||||||
|
jumpPre = "jump-pre"
|
||||||
|
jumpNat = "jump-nat"
|
||||||
matchSet = "--match-set"
|
matchSet = "--match-set"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -296,6 +298,8 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
}
|
}
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -321,24 +325,25 @@ func (r *router) Reset() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) cleanUpDefaultForwardRules() error {
|
func (r *router) cleanUpDefaultForwardRules() error {
|
||||||
err := r.cleanJumpRules()
|
if err := r.cleanJumpRules(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("clean jump rules: %w", err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("flushing routing related tables")
|
log.Debug("flushing routing related tables")
|
||||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
for _, chainInfo := range []struct {
|
||||||
table := r.getTableForChain(chain)
|
chain string
|
||||||
|
table string
|
||||||
ok, err := r.iptablesClient.ChainExists(table, chain)
|
}{
|
||||||
|
{chainRTFWD, tableFilter},
|
||||||
|
{chainRTNAT, tableNat},
|
||||||
|
{chainRTPRE, tableMangle},
|
||||||
|
} {
|
||||||
|
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed check chain %s, error: %v", chain, err)
|
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
return err
|
|
||||||
} else if ok {
|
} else if ok {
|
||||||
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
|
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -347,9 +352,16 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createContainers() error {
|
func (r *router) createContainers() error {
|
||||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
for _, chainInfo := range []struct {
|
||||||
if err := r.createAndSetupChain(chain); err != nil {
|
chain string
|
||||||
return fmt.Errorf("create chain %s: %w", chain, err)
|
table string
|
||||||
|
}{
|
||||||
|
{chainRTFWD, tableFilter},
|
||||||
|
{chainRTPRE, tableMangle},
|
||||||
|
{chainRTNAT, tableNat},
|
||||||
|
} {
|
||||||
|
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
|
||||||
|
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -357,6 +369,10 @@ func (r *router) createContainers() error {
|
|||||||
return fmt.Errorf("insert established rule: %w", err)
|
return fmt.Errorf("insert established rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.addPostroutingRules(); err != nil {
|
||||||
|
return fmt.Errorf("add static nat rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.addJumpRules(); err != nil {
|
if err := r.addJumpRules(); err != nil {
|
||||||
return fmt.Errorf("add jump rules: %w", err)
|
return fmt.Errorf("add jump rules: %w", err)
|
||||||
}
|
}
|
||||||
@@ -364,6 +380,32 @@ func (r *router) createContainers() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *router) addPostroutingRules() error {
|
||||||
|
// First rule for outbound masquerade
|
||||||
|
rule1 := []string{
|
||||||
|
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||||
|
"!", "-o", "lo",
|
||||||
|
"-j", routingFinalNatJump,
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
|
||||||
|
return fmt.Errorf("add outbound masquerade rule: %v", err)
|
||||||
|
}
|
||||||
|
r.rules["static-nat-outbound"] = rule1
|
||||||
|
|
||||||
|
// Second rule for return traffic masquerade
|
||||||
|
rule2 := []string{
|
||||||
|
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-j", routingFinalNatJump,
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
|
||||||
|
return fmt.Errorf("add return masquerade rule: %v", err)
|
||||||
|
}
|
||||||
|
r.rules["static-nat-return"] = rule2
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) createAndSetupChain(chain string) error {
|
func (r *router) createAndSetupChain(chain string) error {
|
||||||
table := r.getTableForChain(chain)
|
table := r.getTableForChain(chain)
|
||||||
|
|
||||||
@@ -375,10 +417,14 @@ func (r *router) createAndSetupChain(chain string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) getTableForChain(chain string) string {
|
func (r *router) getTableForChain(chain string) string {
|
||||||
if chain == chainRTNAT {
|
switch chain {
|
||||||
|
case chainRTNAT:
|
||||||
return tableNat
|
return tableNat
|
||||||
|
case chainRTPRE:
|
||||||
|
return tableMangle
|
||||||
|
default:
|
||||||
|
return tableFilter
|
||||||
}
|
}
|
||||||
return tableFilter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) insertEstablishedRule(chain string) error {
|
func (r *router) insertEstablishedRule(chain string) error {
|
||||||
@@ -396,25 +442,39 @@ func (r *router) insertEstablishedRule(chain string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) addJumpRules() error {
|
func (r *router) addJumpRules() error {
|
||||||
rule := []string{"-j", chainRTNAT}
|
// Jump to NAT chain
|
||||||
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
natRule := []string{"-j", chainRTNAT}
|
||||||
if err != nil {
|
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||||
return err
|
return fmt.Errorf("add nat jump rule: %v", err)
|
||||||
}
|
}
|
||||||
r.rules[ipv4Nat] = rule
|
r.rules[jumpNat] = natRule
|
||||||
|
|
||||||
|
// Jump to prerouting chain
|
||||||
|
preRule := []string{"-j", chainRTPRE}
|
||||||
|
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
||||||
|
return fmt.Errorf("add prerouting jump rule: %v", err)
|
||||||
|
}
|
||||||
|
r.rules[jumpPre] = preRule
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) cleanJumpRules() error {
|
func (r *router) cleanJumpRules() error {
|
||||||
rule, found := r.rules[ipv4Nat]
|
for _, ruleKey := range []string{jumpNat, jumpPre} {
|
||||||
if found {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
table := tableNat
|
||||||
if err != nil {
|
chain := chainPOSTROUTING
|
||||||
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
|
if ruleKey == jumpPre {
|
||||||
|
table = tableMangle
|
||||||
|
chain = chainPREROUTING
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
||||||
|
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -422,19 +482,35 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
|
markValue := nbnet.PreroutingFwmarkMasquerade
|
||||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
|
if pair.Inverse {
|
||||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := []string{"-i", r.wgIface.Name()}
|
||||||
|
if pair.Inverse {
|
||||||
|
rule = []string{"!", "-i", r.wgIface.Name()}
|
||||||
|
}
|
||||||
|
|
||||||
|
rule = append(rule,
|
||||||
|
"-m", "conntrack",
|
||||||
|
"--ctstate", "NEW",
|
||||||
|
"-s", pair.Source.String(),
|
||||||
|
"-d", pair.Destination.String(),
|
||||||
|
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
|
||||||
|
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.rules[ruleKey] = rule
|
r.rules[ruleKey] = rule
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -442,13 +518,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||||
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
|
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("nat rule %s not found", ruleKey)
|
log.Debugf("marking rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -480,16 +555,6 @@ func (r *router) updateState() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
|
||||||
intdir := "-i"
|
|
||||||
lointdir := "-o"
|
|
||||||
if inverse {
|
|
||||||
intdir = "-o"
|
|
||||||
lointdir = "-i"
|
|
||||||
}
|
|
||||||
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
|
|
||||||
}
|
|
||||||
|
|
||||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||||
var rule []string
|
var rule []string
|
||||||
|
|
||||||
|
|||||||
@@ -3,17 +3,18 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isIptablesSupported() bool {
|
func isIptablesSupported() bool {
|
||||||
@@ -34,14 +35,24 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
require.NoError(t, manager.init(nil))
|
require.NoError(t, manager.init(nil))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = manager.Reset()
|
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
require.Len(t, manager.rules, 2, "should have created rules map")
|
// Now 5 rules:
|
||||||
|
// 1. established rule in forward chain
|
||||||
|
// 2. jump rule to NAT chain
|
||||||
|
// 3. jump rule to PRE chain
|
||||||
|
// 4. static outbound masquerade rule
|
||||||
|
// 5. static return masquerade rule
|
||||||
|
require.Len(t, manager.rules, 5, "should have created rules map")
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||||
require.True(t, exists, "postrouting rule should exist")
|
require.True(t, exists, "postrouting jump rule should exist")
|
||||||
|
|
||||||
|
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
|
||||||
|
require.True(t, exists, "prerouting jump rule should exist")
|
||||||
|
|
||||||
pair := firewall.RouterPair{
|
pair := firewall.RouterPair{
|
||||||
ID: "abc",
|
ID: "abc",
|
||||||
@@ -49,22 +60,15 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
}
|
}
|
||||||
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
|
||||||
|
|
||||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
err = manager.AddNatRule(pair)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "adding NAT rule should not return error")
|
||||||
|
|
||||||
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
|
|
||||||
|
|
||||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIptablesManager_AddNatRule(t *testing.T) {
|
func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||||
|
|
||||||
if !isIptablesSupported() {
|
if !isIptablesSupported() {
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
@@ -79,52 +83,66 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
require.NoError(t, manager.init(nil))
|
require.NoError(t, manager.init(nil))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset()
|
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to reset iptables manager: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = manager.AddNatRule(testCase.InputPair)
|
err = manager.AddNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "forwarding pair should be inserted")
|
require.NoError(t, err, "marking rule should be inserted")
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
markingRule := []string{
|
||||||
|
"-i", ifaceMock.Name(),
|
||||||
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
"-m", "conntrack",
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
"--ctstate", "NEW",
|
||||||
if testCase.InputPair.Masquerade {
|
"-s", testCase.InputPair.Source.String(),
|
||||||
require.True(t, exists, "nat rule should be created")
|
"-d", testCase.InputPair.Destination.String(),
|
||||||
foundNatRule, foundNat := manager.rules[natRuleKey]
|
"-j", "MARK", "--set-mark",
|
||||||
require.True(t, foundNat, "nat rule should exist in the map")
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||||
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
|
|
||||||
} else {
|
|
||||||
require.False(t, exists, "nat rule should not be created")
|
|
||||||
_, foundNat := manager.rules[natRuleKey]
|
|
||||||
require.False(t, foundNat, "nat rule should not exist in the map")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
|
||||||
if testCase.InputPair.Masquerade {
|
if testCase.InputPair.Masquerade {
|
||||||
require.True(t, exists, "income nat rule should be created")
|
require.True(t, exists, "marking rule should be created")
|
||||||
foundNatRule, foundNat := manager.rules[inNatRuleKey]
|
foundRule, found := manager.rules[natRuleKey]
|
||||||
require.True(t, foundNat, "income nat rule should exist in the map")
|
require.True(t, found, "marking rule should exist in the map")
|
||||||
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
|
require.Equal(t, markingRule, foundRule, "stored marking rule should match")
|
||||||
} else {
|
} else {
|
||||||
require.False(t, exists, "nat rule should not be created")
|
require.False(t, exists, "marking rule should not be created")
|
||||||
_, foundNat := manager.rules[inNatRuleKey]
|
_, found := manager.rules[natRuleKey]
|
||||||
require.False(t, foundNat, "income nat rule should not exist in the map")
|
require.False(t, found, "marking rule should not exist in the map")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check inverse rule
|
||||||
|
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||||
|
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||||
|
inverseMarkingRule := []string{
|
||||||
|
"!", "-i", ifaceMock.Name(),
|
||||||
|
"-m", "conntrack",
|
||||||
|
"--ctstate", "NEW",
|
||||||
|
"-s", inversePair.Source.String(),
|
||||||
|
"-d", inversePair.Destination.String(),
|
||||||
|
"-j", "MARK", "--set-mark",
|
||||||
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||||
|
if testCase.InputPair.Masquerade {
|
||||||
|
require.True(t, exists, "inverse marking rule should be created")
|
||||||
|
foundRule, found := manager.rules[inverseRuleKey]
|
||||||
|
require.True(t, found, "inverse marking rule should exist in the map")
|
||||||
|
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
|
||||||
|
} else {
|
||||||
|
require.False(t, exists, "inverse marking rule should not be created")
|
||||||
|
_, found := manager.rules[inverseRuleKey]
|
||||||
|
require.False(t, found, "inverse marking rule should not exist in the map")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||||
|
|
||||||
if !isIptablesSupported() {
|
if !isIptablesSupported() {
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
@@ -137,42 +155,52 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
require.NoError(t, manager.init(nil))
|
require.NoError(t, manager.init(nil))
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = manager.Reset()
|
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
err = manager.AddNatRule(testCase.InputPair)
|
||||||
|
require.NoError(t, err, "should add NAT rule without error")
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
|
||||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
|
||||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
|
||||||
|
|
||||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
|
||||||
|
|
||||||
err = manager.Reset()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
err = manager.RemoveNatRule(testCase.InputPair)
|
err = manager.RemoveNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
markingRule := []string{
|
||||||
require.False(t, exists, "nat rule should not exist")
|
"-i", ifaceMock.Name(),
|
||||||
|
"-m", "conntrack",
|
||||||
|
"--ctstate", "NEW",
|
||||||
|
"-s", testCase.InputPair.Source.String(),
|
||||||
|
"-d", testCase.InputPair.Destination.String(),
|
||||||
|
"-j", "MARK", "--set-mark",
|
||||||
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||||
|
require.False(t, exists, "marking rule should not exist")
|
||||||
|
|
||||||
_, found := manager.rules[natRuleKey]
|
_, found := manager.rules[natRuleKey]
|
||||||
require.False(t, found, "nat rule should exist in the manager map")
|
require.False(t, found, "marking rule should not exist in the manager map")
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
// Check inverse rule removal
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||||
require.False(t, exists, "income nat rule should not exist")
|
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||||
|
inverseMarkingRule := []string{
|
||||||
|
"!", "-i", ifaceMock.Name(),
|
||||||
|
"-m", "conntrack",
|
||||||
|
"--ctstate", "NEW",
|
||||||
|
"-s", inversePair.Source.String(),
|
||||||
|
"-d", inversePair.Destination.String(),
|
||||||
|
"-j", "MARK", "--set-mark",
|
||||||
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
|
}
|
||||||
|
|
||||||
_, found = manager.rules[inNatRuleKey]
|
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||||
require.False(t, found, "income nat rule should exist in the manager map")
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||||
|
require.False(t, exists, "inverse marking rule should not exist")
|
||||||
|
|
||||||
|
_, found = manager.rules[inverseRuleKey]
|
||||||
|
require.False(t, found, "inverse marking rule should not exist in the map")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
ForwardingFormatPrefix = "netbird-fwd-"
|
ForwardingFormatPrefix = "netbird-fwd-"
|
||||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||||
|
PreroutingFormat = "netbird-prerouting-%s-%t"
|
||||||
NatFormat = "netbird-nat-%s-%t"
|
NatFormat = "netbird-nat-%s-%t"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -520,7 +520,7 @@ func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
|
|||||||
},
|
},
|
||||||
&expr.Immediate{
|
&expr.Immediate{
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
},
|
},
|
||||||
&expr.Meta{
|
&expr.Meta{
|
||||||
Key: expr.MetaKeyMARK,
|
Key: expr.MetaKeyMARK,
|
||||||
@@ -543,7 +543,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
|||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
},
|
},
|
||||||
&expr.Verdict{
|
&expr.Verdict{
|
||||||
Kind: expr.VerdictJump,
|
Kind: expr.VerdictJump,
|
||||||
|
|||||||
@@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// persist early
|
// persist early
|
||||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
go func() {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
}
|
log.Errorf("failed to persist state: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -197,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
|
|
||||||
var chain *nftables.Chain
|
var chain *nftables.Chain
|
||||||
for _, c := range chains {
|
for _, c := range chains {
|
||||||
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
|
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||||
chain = c
|
chain = c
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -230,23 +232,7 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
|
|
||||||
// SetLegacyManagement sets the route manager to use legacy management
|
// SetLegacyManagement sets the route manager to use legacy management
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
oldLegacy := m.router.legacyManagement
|
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||||
|
|
||||||
if oldLegacy != isLegacy {
|
|
||||||
m.router.legacyManagement = isLegacy
|
|
||||||
log.Debugf("Set legacy management to %v", isLegacy)
|
|
||||||
}
|
|
||||||
|
|
||||||
// client reconnected to a newer mgmt, we need to cleanup the legacy rules
|
|
||||||
if !isLegacy && oldLegacy {
|
|
||||||
if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
|
|
||||||
return fmt.Errorf("remove legacy routing rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Legacy routing rules removed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
@@ -290,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error {
|
|||||||
|
|
||||||
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||||
for _, c := range chains {
|
for _, c := range chains {
|
||||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||||
rules, err := m.rConn.GetRules(c.Table, c)
|
rules, err := m.rConn.GetRules(c.Table, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||||
@@ -365,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(m.wgIface.Name()),
|
Data: ifname(m.wgIface.Name()),
|
||||||
},
|
},
|
||||||
&expr.Verdict{},
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
UserData: []byte(allowNetbirdInputRuleID),
|
UserData: []byte(allowNetbirdInputRuleID),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func runIptablesSave(t *testing.T) (string, string) {
|
||||||
|
t.Helper()
|
||||||
|
var stdout, stderr bytes.Buffer
|
||||||
|
cmd := exec.Command("iptables-save")
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
err := cmd.Run()
|
||||||
|
require.NoError(t, err, "iptables-save failed to run")
|
||||||
|
|
||||||
|
return stdout.String(), stderr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
|
||||||
|
t.Helper()
|
||||||
|
// Check for any incompatibility warnings
|
||||||
|
require.NotContains(t,
|
||||||
|
stderr,
|
||||||
|
"incompatible",
|
||||||
|
"iptables-save produced compatibility warning. Full stderr: %s",
|
||||||
|
stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Verify standard tables are present
|
||||||
|
expectedTables := []string{
|
||||||
|
"*filter",
|
||||||
|
"*nat",
|
||||||
|
"*mangle",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range expectedTables {
|
||||||
|
require.Contains(t,
|
||||||
|
stdout,
|
||||||
|
table,
|
||||||
|
"iptables-save output missing expected table: %s\nFull stdout: %s",
|
||||||
|
table,
|
||||||
|
stdout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||||
|
t.Skipf("iptables-save not available on this system: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First ensure iptables-nft tables exist by running iptables-save
|
||||||
|
stdout, stderr := runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock)
|
||||||
|
require.NoError(t, err, "failed to create manager")
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := manager.Reset(nil)
|
||||||
|
require.NoError(t, err, "failed to reset manager state")
|
||||||
|
|
||||||
|
// Verify iptables output after reset
|
||||||
|
stdout, stderr := runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
})
|
||||||
|
|
||||||
|
ip := net.ParseIP("100.96.0.1")
|
||||||
|
_, err = manager.AddPeerFiltering(
|
||||||
|
ip,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []int{80}},
|
||||||
|
fw.RuleDirectionIN,
|
||||||
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
"test rule",
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
|
_, err = manager.AddRouteFiltering(
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||||
|
netip.MustParsePrefix("10.1.0.0/24"),
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []int{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add route filtering rule")
|
||||||
|
|
||||||
|
pair := fw.RouterPair{
|
||||||
|
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
Masquerade: true,
|
||||||
|
}
|
||||||
|
err = manager.AddNatRule(pair)
|
||||||
|
require.NoError(t, err, "failed to add NAT rule")
|
||||||
|
|
||||||
|
stdout, stderr = runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -124,7 +125,6 @@ func (r *router) createContainers() error {
|
|||||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||||
|
|
||||||
prio := *nftables.ChainPriorityNATSource - 1
|
prio := *nftables.ChainPriorityNATSource - 1
|
||||||
|
|
||||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||||
Name: chainNameRoutingNat,
|
Name: chainNameRoutingNat,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
@@ -133,6 +133,21 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeNAT,
|
Type: nftables.ChainTypeNAT,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Chain is created by acl manager
|
||||||
|
// TODO: move creation to a common place
|
||||||
|
r.chains[chainNamePrerouting] = &nftables.Chain{
|
||||||
|
Name: chainNamePrerouting,
|
||||||
|
Table: r.workTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the single NAT rule that matches on mark
|
||||||
|
if err := r.addPostroutingRules(); err != nil {
|
||||||
|
return fmt.Errorf("add single nat rule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.acceptForwardRules(); err != nil {
|
if err := r.acceptForwardRules(); err != nil {
|
||||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||||
}
|
}
|
||||||
@@ -422,59 +437,149 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||||
|
|
||||||
dir := expr.MetaKeyIIFNAME
|
op := expr.CmpOpEq
|
||||||
notDir := expr.MetaKeyOIFNAME
|
|
||||||
if pair.Inverse {
|
if pair.Inverse {
|
||||||
dir = expr.MetaKeyOIFNAME
|
op = expr.CmpOpNeq
|
||||||
notDir = expr.MetaKeyIIFNAME
|
|
||||||
}
|
}
|
||||||
|
|
||||||
lo := ifname("lo")
|
|
||||||
intf := ifname(r.wgIface.Name())
|
|
||||||
|
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
&expr.Meta{
|
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||||
Key: dir,
|
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeySTATE,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
},
|
},
|
||||||
&expr.Cmp{
|
&expr.Bitwise{
|
||||||
Op: expr.CmpOpEq,
|
SourceRegister: 1,
|
||||||
Register: 1,
|
DestRegister: 1,
|
||||||
Data: intf,
|
Len: 4,
|
||||||
},
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
|
||||||
&expr.Meta{
|
|
||||||
Key: notDir,
|
|
||||||
Register: 1,
|
|
||||||
},
|
},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpNeq,
|
Op: expr.CmpOpNeq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: lo,
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
|
||||||
|
// interface matching
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: op,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
exprs = append(exprs, sourceExp...)
|
exprs = append(exprs, sourceExp...)
|
||||||
exprs = append(exprs, destExp...)
|
exprs = append(exprs, destExp...)
|
||||||
|
|
||||||
|
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
|
||||||
|
if pair.Inverse {
|
||||||
|
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||||
|
}
|
||||||
|
|
||||||
exprs = append(exprs,
|
exprs = append(exprs,
|
||||||
&expr.Counter{}, &expr.Masq{},
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(markValue),
|
||||||
|
},
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
SourceRegister: true,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
|
||||||
if _, exists := r.rules[ruleKey]; exists {
|
if _, exists := r.rules[ruleKey]; exists {
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove routing rule: %w", err)
|
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Chain: r.chains[chainNameRoutingNat],
|
Chain: r.chains[chainNamePrerouting],
|
||||||
Exprs: exprs,
|
Exprs: exprs,
|
||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addPostroutingRules adds the masquerade rules
|
||||||
|
func (r *router) addPostroutingRules() error {
|
||||||
|
// First masquerade rule for traffic coming in from WireGuard interface
|
||||||
|
exprs := []expr.Any{
|
||||||
|
// Match on the first fwmark
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
|
||||||
|
},
|
||||||
|
|
||||||
|
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyOIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname("lo"),
|
||||||
|
},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Masq{},
|
||||||
|
}
|
||||||
|
|
||||||
|
r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
|
Exprs: exprs,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Second masquerade rule for traffic going out through WireGuard interface
|
||||||
|
exprs2 := []expr.Any{
|
||||||
|
// Match on the second fwmark
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
|
},
|
||||||
|
|
||||||
|
// Match WireGuard interface
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyOIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Masq{},
|
||||||
|
}
|
||||||
|
|
||||||
|
r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
|
Exprs: exprs2,
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -551,7 +656,10 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
}
|
}
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
@@ -720,18 +828,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveNatRule removes a nftables rule pair from nat chains
|
// RemoveNatRule removes the prerouting mark rule
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove nat rule: %w", err)
|
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
@@ -746,21 +854,20 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
|
|
||||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
err := r.conn.DelRule(rule)
|
err := r.conn.DelRule(rule)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
|
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("nftables: nat rule %s not found", ruleKey)
|
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -32,100 +33,87 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
t.Skip("nftables not supported on this OS")
|
t.Skip("nftables not supported on this OS")
|
||||||
}
|
}
|
||||||
|
|
||||||
table, err := createWorkTable()
|
|
||||||
require.NoError(t, err, "Failed to create work table")
|
|
||||||
|
|
||||||
defer deleteWorkTable()
|
|
||||||
|
|
||||||
for _, testCase := range test.InsertRuleTestCases {
|
for _, testCase := range test.InsertRuleTestCases {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
manager, err := newRouter(table, ifaceMock)
|
// need fw manager to init both acl mgr and router for all chains to be present
|
||||||
require.NoError(t, err, "failed to create router")
|
manager, err := Create(ifaceMock)
|
||||||
require.NoError(t, manager.init(table))
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Reset(nil))
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer func(manager *router) {
|
rtr := manager.router
|
||||||
require.NoError(t, manager.Reset(), "failed to reset rules")
|
err = rtr.AddNatRule(testCase.InputPair)
|
||||||
}(manager)
|
|
||||||
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
err = manager.AddNatRule(testCase.InputPair)
|
|
||||||
require.NoError(t, err, "pair should be inserted")
|
require.NoError(t, err, "pair should be inserted")
|
||||||
|
|
||||||
defer func(manager *router, pair firewall.RouterPair) {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
|
require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
|
||||||
}(manager, testCase.InputPair)
|
})
|
||||||
|
|
||||||
if testCase.InputPair.Masquerade {
|
if testCase.InputPair.Masquerade {
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
// Build expected expressions for connection tracking
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
conntrackExprs := []expr.Any{
|
||||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
&expr.Ct{
|
||||||
testingExpression = append(testingExpression,
|
Key: expr.CtKeySTATE,
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
SourceRegister: 1,
|
||||||
|
DestRegister: 1,
|
||||||
|
Len: 4,
|
||||||
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build interface matching expression
|
||||||
|
ifaceExprs := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(ifaceMock.Name()),
|
Data: ifname(ifaceMock.Name()),
|
||||||
},
|
},
|
||||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname("lo"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
|
||||||
found := 0
|
|
||||||
for _, chain := range manager.chains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
|
||||||
found = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
||||||
}
|
|
||||||
|
|
||||||
if testCase.InputPair.Masquerade {
|
// Build CIDR matching expressions
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
|
||||||
testingExpression = append(testingExpression,
|
|
||||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(ifaceMock.Name()),
|
|
||||||
},
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname("lo"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
// Combine all expressions in the correct order
|
||||||
|
// nolint:gocritic
|
||||||
|
testingExpression := append(conntrackExprs, ifaceExprs...)
|
||||||
|
testingExpression = append(testingExpression, sourceExp...)
|
||||||
|
testingExpression = append(testingExpression, destExp...)
|
||||||
|
|
||||||
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||||
found := 0
|
found := 0
|
||||||
for _, chain := range manager.chains {
|
for _, chain := range rtr.chains {
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
if chain.Name == chainNamePrerouting {
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
for _, rule := range rules {
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
for _, rule := range rules {
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
found = 1
|
// Compare expressions up to the mark setting expressions
|
||||||
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
|
||||||
|
found = 1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -135,68 +123,66 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
t.Skip("nftables not supported on this OS")
|
t.Skip("nftables not supported on this OS")
|
||||||
}
|
}
|
||||||
|
|
||||||
table, err := createWorkTable()
|
|
||||||
require.NoError(t, err, "Failed to create work table")
|
|
||||||
|
|
||||||
defer deleteWorkTable()
|
|
||||||
|
|
||||||
for _, testCase := range test.RemoveRuleTestCases {
|
for _, testCase := range test.RemoveRuleTestCases {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
manager, err := newRouter(table, ifaceMock)
|
manager, err := Create(ifaceMock)
|
||||||
require.NoError(t, err, "failed to create router")
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.init(table))
|
require.NoError(t, manager.Reset(nil))
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
|
||||||
|
|
||||||
defer func(manager *router) {
|
|
||||||
require.NoError(t, manager.Reset(), "failed to reset rules")
|
|
||||||
}(manager)
|
|
||||||
|
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
|
||||||
|
|
||||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
|
||||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
|
||||||
|
|
||||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
||||||
Table: manager.workTable,
|
|
||||||
Chain: manager.chains[chainNameRoutingNat],
|
|
||||||
Exprs: natExp,
|
|
||||||
UserData: []byte(natRuleKey),
|
|
||||||
})
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
|
rtr := manager.router
|
||||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
|
|
||||||
|
|
||||||
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
// First add the NAT rule using the router's method
|
||||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
err = rtr.AddNatRule(testCase.InputPair)
|
||||||
|
require.NoError(t, err, "should add NAT rule")
|
||||||
|
|
||||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
// Verify the rule was added
|
||||||
Table: manager.workTable,
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||||
Chain: manager.chains[chainNameRoutingNat],
|
found := false
|
||||||
Exprs: natExp,
|
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||||
UserData: []byte(inNatRuleKey),
|
require.NoError(t, err, "should list rules")
|
||||||
})
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
err = nftablesTestingClient.Flush()
|
found = true
|
||||||
require.NoError(t, err, "shouldn't return error")
|
break
|
||||||
|
|
||||||
err = manager.Reset()
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
err = manager.RemoveNatRule(testCase.InputPair)
|
|
||||||
require.NoError(t, err, "shouldn't return error")
|
|
||||||
|
|
||||||
for _, chain := range manager.chains {
|
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 {
|
|
||||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
|
||||||
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
require.True(t, found, "NAT rule should exist before removal")
|
||||||
|
|
||||||
|
// Now remove the rule
|
||||||
|
err = rtr.RemoveNatRule(testCase.InputPair)
|
||||||
|
require.NoError(t, err, "shouldn't return error when removing rule")
|
||||||
|
|
||||||
|
// Verify the rule was removed
|
||||||
|
found = false
|
||||||
|
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||||
|
require.NoError(t, err, "should list rules after removal")
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.False(t, found, "NAT rule should not exist after removal")
|
||||||
|
|
||||||
|
// Verify the static postrouting rules still exist
|
||||||
|
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
|
||||||
|
require.NoError(t, err, "should list postrouting rules")
|
||||||
|
foundCounter := false
|
||||||
|
for _, rule := range rules {
|
||||||
|
for _, e := range rule.Exprs {
|
||||||
|
if _, ok := e.(*expr.Counter); ok {
|
||||||
|
foundCounter = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if foundCounter {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.True(t, foundCounter, "static postrouting rule should remain")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -237,8 +237,11 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||||
func (m *Manager) SetLegacyManagement(_ bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
return nil
|
if m.nativeFirewall == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
|
|||||||
12
client/iface/bind/control_android.go
Normal file
12
client/iface/bind/control_android.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// ControlFns is not thread safe and should only be modified during init.
|
||||||
|
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||||
|
}
|
||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -24,8 +25,8 @@ type receiverCreator struct {
|
|||||||
iceBind *ICEBind
|
iceBind *ICEBind
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||||
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
|
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICEBind is a bind implementation with two main features:
|
// ICEBind is a bind implementation with two main features:
|
||||||
@@ -154,7 +155,7 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||||
s.muUDPMux.Lock()
|
s.muUDPMux.Lock()
|
||||||
defer s.muUDPMux.Unlock()
|
defer s.muUDPMux.Unlock()
|
||||||
|
|
||||||
@@ -166,16 +167,30 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
|
msgs := getMessages(msgsPool)
|
||||||
defer ipv4MsgsPool.Put(msgs)
|
|
||||||
for i := range bufs {
|
for i := range bufs {
|
||||||
(*msgs)[i].Buffers[0] = bufs[i]
|
(*msgs)[i].Buffers[0] = bufs[i]
|
||||||
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||||
}
|
}
|
||||||
|
defer putMessages(msgs, msgsPool)
|
||||||
var numMsgs int
|
var numMsgs int
|
||||||
if runtime.GOOS == "linux" {
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
if rxOffload {
|
||||||
if err != nil {
|
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
||||||
return 0, err
|
//nolint
|
||||||
|
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
msg := &(*msgs)[0]
|
msg := &(*msgs)[0]
|
||||||
@@ -191,11 +206,12 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
|
|||||||
// todo: handle err
|
// todo: handle err
|
||||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||||
if ok {
|
if ok {
|
||||||
sizes[i] = 0
|
continue
|
||||||
} else {
|
}
|
||||||
sizes[i] = msg.N
|
sizes[i] = msg.N
|
||||||
|
if sizes[i] == 0 {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||||
@@ -273,3 +289,15 @@ func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
|||||||
}
|
}
|
||||||
return newAddr, nil
|
return newAddr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||||
|
return msgsPool.Get().(*[]ipv6.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
||||||
|
for i := range *msgs {
|
||||||
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||||
|
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||||
|
}
|
||||||
|
msgsPool.Put(msgs)
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package bind
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -94,7 +95,10 @@ func (p *ProxyBind) close() error {
|
|||||||
|
|
||||||
p.Bind.RemoveEndpoint(p.wgAddr)
|
p.Bind.RemoveEndpoint(p.wgAddr)
|
||||||
|
|
||||||
return p.remoteConn.Close()
|
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||||
|
return rErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||||
@@ -104,8 +108,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
|
||||||
for {
|
for {
|
||||||
|
buf := make([]byte, 1500)
|
||||||
n, err := p.remoteConn.Read(buf)
|
n, err := p.remoteConn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
|
|||||||
|
|
||||||
e.cancel()
|
e.cancel()
|
||||||
|
|
||||||
if err := e.remoteConn.Close(); err != nil {
|
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error {
|
|||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
if err := p.remoteConn.Close(); err != nil {
|
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package acl
|
|||||||
import (
|
import (
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -10,14 +11,18 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||||
|
|
||||||
// Manager is a ACL rules manager
|
// Manager is a ACL rules manager
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
||||||
@@ -167,31 +172,40 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
||||||
var newRouteRules = make(map[id.RuleID]struct{})
|
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// Apply new rules - firewall manager will return existing rule ID if already present
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
id, err := d.applyRouteACL(rule)
|
id, err := d.applyRouteACL(rule)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("apply route ACL: %w", err)
|
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||||
|
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
|
||||||
|
} else {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
newRouteRules[id] = struct{}{}
|
newRouteRules[id] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clean up old firewall rules
|
||||||
for id := range d.routeRules {
|
for id := range d.routeRules {
|
||||||
if _, ok := newRouteRules[id]; !ok {
|
if _, exists := newRouteRules[id]; !exists {
|
||||||
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
||||||
log.Errorf("failed to delete route firewall rule: %v", err)
|
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
delete(d.routeRules, id)
|
// implicitly deleted from the map
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
d.routeRules = newRouteRules
|
d.routeRules = newRouteRules
|
||||||
return nil
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
||||||
if len(rule.SourceRanges) == 0 {
|
if len(rule.SourceRanges) == 0 {
|
||||||
return "", fmt.Errorf("source ranges is empty")
|
return "", ErrSourceRangesEmpty
|
||||||
}
|
}
|
||||||
|
|
||||||
var sources []netip.Prefix
|
var sources []netip.Prefix
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
|
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
|
||||||
return cfg, err
|
return cfg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
|||||||
|
|
||||||
// WriteOutConfig write put the prepared config to the given path
|
// WriteOutConfig write put the prepared config to the given path
|
||||||
func WriteOutConfig(path string, config *Config) error {
|
func WriteOutConfig(path string, config *Config) error {
|
||||||
return util.WriteJson(path, config)
|
return util.WriteJson(context.Background(), path, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||||
@@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if updated {
|
if updated {
|
||||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,19 +40,21 @@ type ConnectClient struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
engine *Engine
|
engine *Engine
|
||||||
engineMutex sync.Mutex
|
engineMutex sync.Mutex
|
||||||
|
staticInfo *system.StaticInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnectClient(
|
func NewConnectClient(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
config *Config,
|
config *Config,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
|
staticInfo *system.StaticInfo,
|
||||||
) *ConnectClient {
|
) *ConnectClient {
|
||||||
return &ConnectClient{
|
return &ConnectClient{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
config: config,
|
config: config,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
engineMutex: sync.Mutex{},
|
engineMutex: sync.Mutex{},
|
||||||
|
staticInfo: staticInfo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,7 +159,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
|
|
||||||
engineCtx, cancel := context.WithCancel(c.ctx)
|
engineCtx, cancel := context.WithCancel(c.ctx)
|
||||||
defer func() {
|
defer func() {
|
||||||
c.statusRecorder.MarkManagementDisconnected(state.err)
|
_, err := state.Status()
|
||||||
|
c.statusRecorder.MarkManagementDisconnected(err)
|
||||||
c.statusRecorder.CleanLocalPeerState()
|
c.statusRecorder.CleanLocalPeerState()
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
@@ -178,7 +181,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
|
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
|
||||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey)
|
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.staticInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug(err)
|
log.Debug(err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
@@ -207,7 +210,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
|
|
||||||
c.statusRecorder.MarkSignalDisconnected(nil)
|
c.statusRecorder.MarkSignalDisconnected(nil)
|
||||||
defer func() {
|
defer func() {
|
||||||
c.statusRecorder.MarkSignalDisconnected(state.err)
|
_, err := state.Status()
|
||||||
|
c.statusRecorder.MarkSignalDisconnected(err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||||
@@ -230,6 +234,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
|
|
||||||
relayURLs, token := parseRelayInfo(loginResp)
|
relayURLs, token := parseRelayInfo(loginResp)
|
||||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
|
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
|
||||||
|
c.statusRecorder.SetRelayMgr(relayManager)
|
||||||
if len(relayURLs) > 0 {
|
if len(relayURLs) > 0 {
|
||||||
if token != nil {
|
if token != nil {
|
||||||
if err := relayManager.UpdateToken(token); err != nil {
|
if err := relayManager.UpdateToken(token); err != nil {
|
||||||
@@ -240,9 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
|
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
|
||||||
if err = relayManager.Serve(); err != nil {
|
if err = relayManager.Serve(); err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
|
||||||
}
|
}
|
||||||
c.statusRecorder.SetRelayMgr(relayManager)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConfig := loginResp.GetPeerConfig()
|
peerConfig := loginResp.GetPeerConfig()
|
||||||
@@ -256,7 +259,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
|||||||
checks := loginResp.GetChecks()
|
checks := loginResp.GetChecks()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
|
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks, c.staticInfo)
|
||||||
|
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
@@ -423,14 +426,14 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
|
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
|
||||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, staticInfo *system.StaticInfo) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
serverPublicKey, err := client.GetServerPublicKey()
|
serverPublicKey, err := client.GetServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx, staticInfo)
|
||||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
|
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
@@ -323,12 +322,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// persist dns state right away
|
go func() {
|
||||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
// persist dns state right away
|
||||||
defer cancel()
|
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
log.Errorf("Failed to persist dns state: %v", err)
|
||||||
log.Errorf("Failed to persist dns state: %v", err)
|
}
|
||||||
}
|
}()
|
||||||
|
|
||||||
if s.searchDomainNotifier != nil {
|
if s.searchDomainNotifier != nil {
|
||||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||||
@@ -533,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// persist dns state right away
|
go func() {
|
||||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||||
defer cancel()
|
l.Errorf("Failed to persist dns state: %v", err)
|
||||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
}
|
||||||
l.Errorf("Failed to persist dns state: %v", err)
|
}()
|
||||||
}
|
|
||||||
|
|
||||||
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
||||||
s.addHostRootZone()
|
s.addHostRootZone()
|
||||||
|
|||||||
@@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
Port: 53,
|
Port: 53,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Domains: []string{"customdomain.com"},
|
Domains: []string{"google.com"},
|
||||||
Primary: false,
|
Primary: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
if ips[0] != zoneRecords[0].RData {
|
if ips[0] != zoneRecords[0].RData {
|
||||||
t.Fatalf("invalid zone record: %v", err)
|
t.Fatalf("invalid zone record: %v", err)
|
||||||
}
|
}
|
||||||
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
_, err = resolver.LookupHost(context.Background(), "google.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to resolve: %s", err)
|
t.Errorf("failed to resolve: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -38,7 +39,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@@ -171,7 +171,9 @@ type Engine struct {
|
|||||||
|
|
||||||
relayManager *relayClient.Manager
|
relayManager *relayClient.Manager
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
srWatcher *guard.SRWatcher
|
srWatcher *guard.SRWatcher
|
||||||
|
|
||||||
|
staticInfo *system.StaticInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -180,8 +182,8 @@ type Peer struct {
|
|||||||
WgAllowedIps string
|
WgAllowedIps string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewEngine creates a new Connection Engine
|
// newEngine creates a new Connection Engine
|
||||||
func NewEngine(
|
func newEngine(
|
||||||
clientCtx context.Context,
|
clientCtx context.Context,
|
||||||
clientCancel context.CancelFunc,
|
clientCancel context.CancelFunc,
|
||||||
signalClient signal.Client,
|
signalClient signal.Client,
|
||||||
@@ -203,6 +205,7 @@ func NewEngine(
|
|||||||
statusRecorder,
|
statusRecorder,
|
||||||
nil,
|
nil,
|
||||||
checks,
|
checks,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,6 +221,7 @@ func NewEngineWithProbes(
|
|||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
probes *ProbeHolder,
|
probes *ProbeHolder,
|
||||||
checks []*mgmProto.Checks,
|
checks []*mgmProto.Checks,
|
||||||
|
staticInfo *system.StaticInfo,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
clientCtx: clientCtx,
|
clientCtx: clientCtx,
|
||||||
@@ -237,6 +241,7 @@ func NewEngineWithProbes(
|
|||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
probes: probes,
|
probes: probes,
|
||||||
checks: checks,
|
checks: checks,
|
||||||
|
staticInfo: staticInfo,
|
||||||
}
|
}
|
||||||
if path := statemanager.GetDefaultStatePath(); path != "" {
|
if path := statemanager.GetDefaultStatePath(); path != "" {
|
||||||
engine.stateManager = statemanager.New(path)
|
engine.stateManager = statemanager.New(path)
|
||||||
@@ -297,7 +302,7 @@ func (e *Engine) Stop() error {
|
|||||||
if err := e.stateManager.Stop(ctx); err != nil {
|
if err := e.stateManager.Stop(ctx); err != nil {
|
||||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
return fmt.Errorf("failed to stop state manager: %w", err)
|
||||||
}
|
}
|
||||||
if err := e.stateManager.PersistState(ctx); err != nil {
|
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -538,6 +543,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
|
|
||||||
relayMsg := wCfg.GetRelay()
|
relayMsg := wCfg.GetRelay()
|
||||||
if relayMsg != nil {
|
if relayMsg != nil {
|
||||||
|
// when we receive token we expect valid address list too
|
||||||
c := &auth.Token{
|
c := &auth.Token{
|
||||||
Payload: relayMsg.GetTokenPayload(),
|
Payload: relayMsg.GetTokenPayload(),
|
||||||
Signature: relayMsg.GetTokenSignature(),
|
Signature: relayMsg.GetTokenSignature(),
|
||||||
@@ -546,9 +552,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
log.Errorf("failed to update relay token: %v", err)
|
log.Errorf("failed to update relay token: %v", err)
|
||||||
return fmt.Errorf("update relay token: %w", err)
|
return fmt.Errorf("update relay token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.relayManager.UpdateServerURLs(relayMsg.Urls)
|
||||||
|
|
||||||
|
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||||
|
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||||
|
_ = e.relayManager.Serve()
|
||||||
|
} else {
|
||||||
|
e.relayManager.UpdateServerURLs(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo update relay address in the relay manager
|
|
||||||
// todo update signal
|
// todo update signal
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -574,10 +587,10 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
}
|
}
|
||||||
e.checks = checks
|
e.checks = checks
|
||||||
|
|
||||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
info, err := system.GetInfoWithChecks(e.ctx, checks, e.staticInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get system info with checks: %v", err)
|
log.Warnf("failed to get system info with checks: %v", err)
|
||||||
info = system.GetInfo(e.ctx)
|
info = system.GetInfo(e.ctx, e.staticInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||||
@@ -641,6 +654,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return errors.New("wireguard interface is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
if e.wgInterface.Address().String() != conf.Address {
|
if e.wgInterface.Address().String() != conf.Address {
|
||||||
oldAddr := e.wgInterface.Address().String()
|
oldAddr := e.wgInterface.Address().String()
|
||||||
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
||||||
@@ -673,10 +690,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||||
func (e *Engine) receiveManagementEvents() {
|
func (e *Engine) receiveManagementEvents() {
|
||||||
go func() {
|
go func() {
|
||||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.staticInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get system info with checks: %v", err)
|
log.Warnf("failed to get system info with checks: %v", err)
|
||||||
info = system.GetInfo(e.ctx)
|
info = system.GetInfo(e.ctx, e.staticInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
// err = e.mgmClient.Sync(info, e.handleSync)
|
// err = e.mgmClient.Sync(info, e.handleSync)
|
||||||
@@ -1180,7 +1197,7 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||||
info := system.GetInfo(e.ctx)
|
info := system.GetInfo(e.ctx, e.staticInfo)
|
||||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@@ -1481,6 +1498,17 @@ func (e *Engine) stopDNSServer() {
|
|||||||
|
|
||||||
// isChecksEqual checks if two slices of checks are equal.
|
// isChecksEqual checks if two slices of checks are equal.
|
||||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||||
|
for _, check := range checks {
|
||||||
|
sort.Slice(check.Files, func(i, j int) bool {
|
||||||
|
return check.Files[i] < check.Files[j]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, oCheck := range oChecks {
|
||||||
|
sort.Slice(oCheck.Files, func(i, j int) bool {
|
||||||
|
return oCheck.Files[i] < oCheck.Files[j]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||||
return slices.Equal(checks.Files, oChecks.Files)
|
return slices.Equal(checks.Files, oChecks.Files)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||||
engine := NewEngine(
|
engine := newEngine(
|
||||||
ctx, cancel,
|
ctx, cancel,
|
||||||
&signal.MockClient{},
|
&signal.MockClient{},
|
||||||
&mgmt.MockClient{},
|
&mgmt.MockClient{},
|
||||||
@@ -229,7 +229,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||||
engine := NewEngine(
|
engine := newEngine(
|
||||||
ctx, cancel,
|
ctx, cancel,
|
||||||
&signal.MockClient{},
|
&signal.MockClient{},
|
||||||
&mgmt.MockClient{},
|
&mgmt.MockClient{},
|
||||||
@@ -434,7 +434,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
|
engine := newEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
|
||||||
WgIfaceName: "utun103",
|
WgIfaceName: "utun103",
|
||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
@@ -594,7 +594,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
engine := newEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||||
WgIfaceName: wgIfaceName,
|
WgIfaceName: wgIfaceName,
|
||||||
WgAddr: wgAddr,
|
WgAddr: wgAddr,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
@@ -774,7 +774,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
engine := newEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||||
WgIfaceName: wgIfaceName,
|
WgIfaceName: wgIfaceName,
|
||||||
WgAddr: wgAddr,
|
WgAddr: wgAddr,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
@@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_CheckFilesEqual(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
inputChecks1 []*mgmtProto.Checks
|
||||||
|
inputChecks2 []*mgmtProto.Checks
|
||||||
|
expectedBool bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Equal Files In Equal Order Should Return True",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Equal Files In Reverse Order Should Return True",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile2",
|
||||||
|
"testfile1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unequal Files Should Return False",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Compared With Empty Should Return False",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
|
||||||
|
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1025,7 +1118,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := system.GetInfo(ctx)
|
info := system.GetInfo(ctx, nil)
|
||||||
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil)
|
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1047,7 +1140,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
e, err := newEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||||
e.ctx = ctx
|
e.ctx = ctx
|
||||||
return e, err
|
return e, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// IsLoginRequired check that the server is support SSO or not
|
// IsLoginRequired check that the server is support SSO or not
|
||||||
func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string) (bool, error) {
|
func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string, staticInfo *system.StaticInfo) (bool, error) {
|
||||||
mgmClient, err := getMgmClient(ctx, privateKey, mgmURL)
|
mgmClient, err := getMgmClient(ctx, privateKey, mgmURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
@@ -38,7 +38,7 @@ func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, ss
|
|||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey)
|
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey, staticInfo)
|
||||||
if isLoginNeeded(err) {
|
if isLoginNeeded(err) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
@@ -46,7 +46,7 @@ func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, ss
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Login or register the client
|
// Login or register the client
|
||||||
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error {
|
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string, staticInfo *system.StaticInfo) error {
|
||||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -67,10 +67,10 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
|
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, staticInfo)
|
||||||
if serverKey != nil && isRegistrationNeeded(err) {
|
if serverKey != nil && isRegistrationNeeded(err) {
|
||||||
log.Debugf("peer registration required")
|
log.Debugf("peer registration required")
|
||||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, staticInfo)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,28 +99,28 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm
|
|||||||
return mgmClient, err
|
return mgmClient, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte) (*wgtypes.Key, error) {
|
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, staticInfo *system.StaticInfo) (*wgtypes.Key, error) {
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
serverKey, err := mgmClient.GetServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx, staticInfo)
|
||||||
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
|
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
|
||||||
return serverKey, err
|
return serverKey, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
// Otherwise tries to register with the provided setupKey via command line.
|
// Otherwise tries to register with the provided setupKey via command line.
|
||||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, staticInfo *system.StaticInfo) (*mgmProto.LoginResponse, error) {
|
||||||
validSetupKey, err := uuid.Parse(setupKey)
|
validSetupKey, err := uuid.Parse(setupKey)
|
||||||
if err != nil && jwtToken == "" {
|
if err != nil && jwtToken == "" {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("sending peer registration request to Management Service")
|
log.Debugf("sending peer registration request to Management Service")
|
||||||
info := system.GetInfo(ctx)
|
info := system.GetInfo(ctx, staticInfo)
|
||||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey)
|
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())
|
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())
|
||||||
|
|||||||
@@ -309,6 +309,11 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) {
|
||||||
|
conn.log.Errorf("remote ICE connection is nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
conn.log.Debugf("ICE connection is ready")
|
conn.log.Debugf("ICE connection is ready")
|
||||||
|
|
||||||
if conn.currentConnPriority > priority {
|
if conn.currentConnPriority > priority {
|
||||||
@@ -437,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
|
|
||||||
if conn.iceP2PIsActive() {
|
if conn.iceP2PIsActive() {
|
||||||
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||||
conn.wgProxyRelay = wgProxy
|
conn.setRelayedProxy(wgProxy)
|
||||||
conn.statusRelay.Set(StatusConnected)
|
conn.statusRelay.Set(StatusConnected)
|
||||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
return
|
return
|
||||||
@@ -460,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
conn.currentConnPriority = connPriorityRelay
|
conn.currentConnPriority = connPriorityRelay
|
||||||
conn.statusRelay.Set(StatusConnected)
|
conn.statusRelay.Set(StatusConnected)
|
||||||
conn.wgProxyRelay = wgProxy
|
conn.setRelayedProxy(wgProxy)
|
||||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
conn.log.Infof("start to communicate with peer via relay")
|
conn.log.Infof("start to communicate with peer via relay")
|
||||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||||
@@ -731,6 +736,15 @@ func (conn *Conn) logTraceConnState() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||||
|
if conn.wgProxyRelay != nil {
|
||||||
|
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||||
|
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn.wgProxyRelay = proxy
|
||||||
|
}
|
||||||
|
|
||||||
func isController(config ConnConfig) bool {
|
func isController(config ConnConfig) bool {
|
||||||
return config.LocalKey > config.Key
|
return config.LocalKey > config.Key
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"time"
|
||||||
|
|
||||||
"github.com/pion/ice/v3"
|
"github.com/pion/ice/v3"
|
||||||
"github.com/pion/randutil"
|
"github.com/pion/randutil"
|
||||||
"github.com/pion/stun/v2"
|
"github.com/pion/stun/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"runtime"
|
|
||||||
"time"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -77,10 +78,7 @@ func CandidateTypes() []ice.CandidateType {
|
|||||||
if hasICEForceRelayConn() {
|
if hasICEForceRelayConn() {
|
||||||
return []ice.CandidateType{ice.CandidateTypeRelay}
|
return []ice.CandidateType{ice.CandidateTypeRelay}
|
||||||
}
|
}
|
||||||
// TODO: remove this once we have refactored userspace proxy into the bind package
|
|
||||||
if runtime.GOOS == "ios" {
|
|
||||||
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
|
|
||||||
}
|
|
||||||
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
|
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
21
client/internal/peer/nilcheck.go
Normal file
21
client/internal/peer/nilcheck.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func remoteConnNil(log *log.Entry, conn net.Conn) bool {
|
||||||
|
if conn == nil {
|
||||||
|
log.Errorf("ice conn is nil")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn.RemoteAddr() == nil {
|
||||||
|
log.Errorf("ICE remote address is nil")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) {
|
|||||||
func (s *State) GetRoutes() map[string]struct{} {
|
func (s *State) GetRoutes() map[string]struct{} {
|
||||||
s.Mux.RLock()
|
s.Mux.RLock()
|
||||||
defer s.Mux.RUnlock()
|
defer s.Mux.RUnlock()
|
||||||
return s.routes
|
return maps.Clone(s.routes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalPeerState contains the latest state of the local peer
|
// LocalPeerState contains the latest state of the local peer
|
||||||
@@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
peerState.IP = receivedState.IP
|
peerState.IP = receivedState.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
if receivedState.GetRoutes() != nil {
|
|
||||||
peerState.SetRoutes(receivedState.GetRoutes())
|
|
||||||
}
|
|
||||||
|
|
||||||
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
||||||
|
|
||||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||||
@@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
d.notifyPeerListChanged()
|
||||||
if found && ch != nil {
|
return nil
|
||||||
close(ch)
|
}
|
||||||
d.changeNotify[receivedState.PubKey] = nil
|
|
||||||
|
func (d *Status) AddPeerStateRoute(peer string, route string) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
peerState, ok := d.peers[peer]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("peer doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
peerState.AddRoute(route)
|
||||||
|
d.peers[peer] = peerState
|
||||||
|
|
||||||
|
// todo: consider to make sense of this notification or not
|
||||||
|
d.notifyPeerListChanged()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
peerState, ok := d.peers[peer]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("peer doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
peerState.DeleteRoute(route)
|
||||||
|
d.peers[peer] = peerState
|
||||||
|
|
||||||
|
// todo: consider to make sense of this notification or not
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
if found && ch != nil {
|
|
||||||
close(ch)
|
|
||||||
d.changeNotify[receivedState.PubKey] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
if found && ch != nil {
|
|
||||||
close(ch)
|
|
||||||
d.changeNotify[receivedState.PubKey] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
if found && ch != nil {
|
|
||||||
close(ch)
|
|
||||||
d.changeNotify[receivedState.PubKey] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch, found := d.changeNotify[receivedState.PubKey]
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
if found && ch != nil {
|
|
||||||
close(ch)
|
|
||||||
d.changeNotify[receivedState.PubKey] = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
d.notifyPeerListChanged()
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() {
|
|||||||
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
ch, found := d.changeNotify[peer]
|
ch, found := d.changeNotify[peer]
|
||||||
if !found || ch == nil {
|
if found {
|
||||||
ch = make(chan struct{})
|
return ch
|
||||||
d.changeNotify[peer] = ch
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ch = make(chan struct{})
|
||||||
|
d.changeNotify[peer] = ch
|
||||||
return ch
|
return ch
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -669,25 +676,23 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
|||||||
// extend the list of stun, turn servers with relay address
|
// extend the list of stun, turn servers with relay address
|
||||||
relayStates := slices.Clone(d.relayStates)
|
relayStates := slices.Clone(d.relayStates)
|
||||||
|
|
||||||
var relayState relay.ProbeResult
|
|
||||||
|
|
||||||
// if the server connection is not established then we will use the general address
|
// if the server connection is not established then we will use the general address
|
||||||
// in case of connection we will use the instance specific address
|
// in case of connection we will use the instance specific address
|
||||||
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
|
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO add their status
|
// TODO add their status
|
||||||
if errors.Is(err, relayClient.ErrRelayClientNotConnected) {
|
for _, r := range d.relayMgr.ServerURLs() {
|
||||||
for _, r := range d.relayMgr.ServerURLs() {
|
relayStates = append(relayStates, relay.ProbeResult{
|
||||||
relayStates = append(relayStates, relay.ProbeResult{
|
URI: r,
|
||||||
URI: r,
|
Err: err,
|
||||||
})
|
})
|
||||||
}
|
|
||||||
return relayStates
|
|
||||||
}
|
}
|
||||||
relayState.Err = err
|
return relayStates
|
||||||
}
|
}
|
||||||
|
|
||||||
relayState.URI = instanceAddr
|
relayState := relay.ProbeResult{
|
||||||
|
URI: instanceAddr,
|
||||||
|
}
|
||||||
return append(relayStates, relayState)
|
return append(relayStates, relayState)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -755,6 +760,17 @@ func (d *Status) onConnectionChanged() {
|
|||||||
d.notifier.updateServerStates(d.managementState, d.signalState)
|
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
|
||||||
|
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||||
|
ch, found := d.changeNotify[peerID]
|
||||||
|
if !found {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(ch)
|
||||||
|
delete(d.changeNotify, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Status) notifyPeerListChanged() {
|
func (d *Status) notifyPeerListChanged() {
|
||||||
d.notifier.peerListChanged(d.numOfPeers())
|
d.notifier.peerListChanged(d.numOfPeers())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
|||||||
|
|
||||||
peerState.IP = ip
|
peerState.IP = ip
|
||||||
|
|
||||||
err := status.UpdatePeerState(peerState)
|
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -46,8 +46,6 @@ type WorkerICE struct {
|
|||||||
hasRelayOnLocally bool
|
hasRelayOnLocally bool
|
||||||
conn WorkerICECallbacks
|
conn WorkerICECallbacks
|
||||||
|
|
||||||
selectedPriority ConnPriority
|
|
||||||
|
|
||||||
agent *ice.Agent
|
agent *ice.Agent
|
||||||
muxAgent sync.Mutex
|
muxAgent sync.Mutex
|
||||||
|
|
||||||
@@ -57,6 +55,9 @@ type WorkerICE struct {
|
|||||||
|
|
||||||
localUfrag string
|
localUfrag string
|
||||||
localPwd string
|
localPwd string
|
||||||
|
|
||||||
|
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
|
||||||
|
lastKnownState ice.ConnectionState
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
|
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
|
||||||
@@ -92,10 +93,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
|
|
||||||
var preferredCandidateTypes []ice.CandidateType
|
var preferredCandidateTypes []ice.CandidateType
|
||||||
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
|
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
|
||||||
w.selectedPriority = connPriorityICEP2P
|
|
||||||
preferredCandidateTypes = icemaker.CandidateTypesP2P()
|
preferredCandidateTypes = icemaker.CandidateTypesP2P()
|
||||||
} else {
|
} else {
|
||||||
w.selectedPriority = connPriorityICETurn
|
|
||||||
preferredCandidateTypes = icemaker.CandidateTypes()
|
preferredCandidateTypes = icemaker.CandidateTypes()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,7 +155,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||||
}
|
}
|
||||||
w.log.Debugf("on ICE conn read to use ready")
|
w.log.Debugf("on ICE conn read to use ready")
|
||||||
go w.conn.OnConnReady(w.selectedPriority, ci)
|
go w.conn.OnConnReady(selectedPriority(pair), ci)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||||
@@ -194,8 +193,7 @@ func (w *WorkerICE) Close() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := w.agent.Close()
|
if err := w.agent.Close(); err != nil {
|
||||||
if err != nil {
|
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -215,15 +213,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
|||||||
|
|
||||||
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
|
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
|
||||||
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
||||||
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
|
switch state {
|
||||||
w.conn.OnStatusChanged(StatusDisconnected)
|
case ice.ConnectionStateConnected:
|
||||||
|
w.lastKnownState = ice.ConnectionStateConnected
|
||||||
w.muxAgent.Lock()
|
return
|
||||||
agentCancel()
|
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
||||||
_ = agent.Close()
|
if w.lastKnownState != ice.ConnectionStateDisconnected {
|
||||||
w.agent = nil
|
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||||
|
w.conn.OnStatusChanged(StatusDisconnected)
|
||||||
w.muxAgent.Unlock()
|
}
|
||||||
|
w.closeAgent(agentCancel)
|
||||||
|
default:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -249,6 +250,17 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
|||||||
return agent, nil
|
return agent, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
|
||||||
|
w.muxAgent.Lock()
|
||||||
|
defer w.muxAgent.Unlock()
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
if err := w.agent.Close(); err != nil {
|
||||||
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
|
}
|
||||||
|
w.agent = nil
|
||||||
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||||
// wait local endpoint configuration
|
// wait local endpoint configuration
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -378,3 +390,11 @@ func isRelayed(pair *ice.CandidatePair) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func selectedPriority(pair *ice.CandidatePair) ConnPriority {
|
||||||
|
if isRelayed(pair) {
|
||||||
|
return connPriorityICETurn
|
||||||
|
} else {
|
||||||
|
return connPriorityICEP2P
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
|||||||
tempScore = float64(metricDiff) * 10
|
tempScore = float64(metricDiff) * 10
|
||||||
}
|
}
|
||||||
|
|
||||||
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
|
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
|
||||||
latency := time.Second
|
latency := 999 * time.Millisecond
|
||||||
if peerStatus.latency != 0 {
|
if peerStatus.latency != 0 {
|
||||||
latency = peerStatus.latency
|
latency = peerStatus.latency
|
||||||
} else {
|
} else {
|
||||||
log.Warnf("peer %s has 0 latency", r.Peer)
|
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// avoid negative tempScore on the higher latency calculation
|
||||||
|
if latency > 1*time.Second {
|
||||||
|
latency = 999 * time.Millisecond
|
||||||
|
}
|
||||||
|
|
||||||
|
// higher latency is worse score
|
||||||
tempScore += 1 - latency.Seconds()
|
tempScore += 1 - latency.Seconds()
|
||||||
|
|
||||||
if !peerStatus.relayed {
|
if !peerStatus.relayed {
|
||||||
@@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case chosen == "":
|
case chosen == "":
|
||||||
var peers []string
|
var peers []string
|
||||||
@@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
|
|||||||
func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||||
for _, r := range c.routes {
|
for _, r := range c.routes {
|
||||||
_, found := c.routePeersNotifiers[r.Peer]
|
_, found := c.routePeersNotifiers[r.Peer]
|
||||||
if !found {
|
if found {
|
||||||
c.routePeersNotifiers[r.Peer] = make(chan struct{})
|
continue
|
||||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
closerChan := make(chan struct{})
|
||||||
|
c.routePeersNotifiers[r.Peer] = closerChan
|
||||||
|
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
|
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
|
||||||
c.removeStateRoute()
|
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
|
||||||
|
log.Warnf("Failed to update peer state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
||||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||||
@@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
|||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
||||||
}
|
}
|
||||||
if err := c.handler.RemoveRoute(); err != nil {
|
if err := c.handler.RemoveRoute(); err != nil {
|
||||||
@@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Otherwise, remove the allowed IPs from the previous peer first
|
// Otherwise, remove the allowed IPs from the previous peer first
|
||||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||||
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -268,37 +282,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.addStateRoute()
|
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("add peer state route: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) addStateRoute() {
|
|
||||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to get peer state: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
state.AddRoute(c.handler.String())
|
|
||||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
|
||||||
log.Warnf("Failed to update peer state: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientNetwork) removeStateRoute() {
|
|
||||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to get peer state: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
state.DeleteRoute(c.handler.String())
|
|
||||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
|
||||||
log.Warnf("Failed to update peer state: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||||
go func() {
|
go func() {
|
||||||
c.routeUpdate <- update
|
c.routeUpdate <- update
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
currentRoute: "route1",
|
currentRoute: "route1",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "relayed routes with latency 0 should maintain previous choice",
|
||||||
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
relayed: true,
|
||||||
|
latency: 0 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
relayed: true,
|
||||||
|
latency: 0 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "route1",
|
||||||
|
expectedRouteID: "route1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "p2p routes with latency 0 should maintain previous choice",
|
||||||
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
latency: 0 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
latency: 0 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "route1",
|
||||||
|
expectedRouteID: "route1",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "current route with bad score should be changed to route with better score",
|
name: "current route with bad score should be changed to route with better score",
|
||||||
statuses: map[route.ID]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
@@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fill the test data with random routes
|
||||||
|
for _, tc := range testCases {
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
dummyRoute := &route.Route{
|
||||||
|
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
|
||||||
|
Metric: route.MinMetric,
|
||||||
|
Peer: fmt.Sprintf("dummy_p1_%d", i),
|
||||||
|
}
|
||||||
|
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||||
|
}
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
dummyRoute := &route.Route{
|
||||||
|
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
|
||||||
|
Metric: route.MinMetric,
|
||||||
|
Peer: fmt.Sprintf("dummy_p1_%d", i),
|
||||||
|
}
|
||||||
|
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
|
||||||
|
dummyStatus := routerPeerStatus{
|
||||||
|
connected: false,
|
||||||
|
relayed: true,
|
||||||
|
latency: 0,
|
||||||
|
}
|
||||||
|
tc.statuses[id] = dummyStatus
|
||||||
|
}
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
|
||||||
|
dummyStatus := routerPeerStatus{
|
||||||
|
connected: false,
|
||||||
|
relayed: true,
|
||||||
|
latency: 0,
|
||||||
|
}
|
||||||
|
tc.statuses[id] = dummyStatus
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
currentRoute := &route.Route{
|
currentRoute := &route.Route{
|
||||||
|
|||||||
@@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
|
|||||||
type Counter[Key comparable, I, O any] struct {
|
type Counter[Key comparable, I, O any] struct {
|
||||||
// refCountMap keeps track of the reference Ref for keys
|
// refCountMap keeps track of the reference Ref for keys
|
||||||
refCountMap map[Key]Ref[O]
|
refCountMap map[Key]Ref[O]
|
||||||
refCountMu sync.Mutex
|
mu sync.Mutex
|
||||||
// idMap keeps track of the keys associated with an ID for removal
|
// idMap keeps track of the keys associated with an ID for removal
|
||||||
idMap map[string][]Key
|
idMap map[string][]Key
|
||||||
idMu sync.Mutex
|
|
||||||
add AddFunc[Key, I, O]
|
add AddFunc[Key, I, O]
|
||||||
remove RemoveFunc[Key, O]
|
remove RemoveFunc[Key, O]
|
||||||
}
|
}
|
||||||
@@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
|
|||||||
func (rm *Counter[Key, I, O]) LoadData(
|
func (rm *Counter[Key, I, O]) LoadData(
|
||||||
existingCounter *Counter[Key, I, O],
|
existingCounter *Counter[Key, I, O],
|
||||||
) {
|
) {
|
||||||
rm.refCountMu.Lock()
|
rm.mu.Lock()
|
||||||
defer rm.refCountMu.Unlock()
|
defer rm.mu.Unlock()
|
||||||
rm.idMu.Lock()
|
|
||||||
defer rm.idMu.Unlock()
|
|
||||||
|
|
||||||
rm.refCountMap = existingCounter.refCountMap
|
rm.refCountMap = existingCounter.refCountMap
|
||||||
rm.idMap = existingCounter.idMap
|
rm.idMap = existingCounter.idMap
|
||||||
@@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData(
|
|||||||
// Get retrieves the current reference count and associated data for a key.
|
// Get retrieves the current reference count and associated data for a key.
|
||||||
// If the key doesn't exist, it returns a zero value Ref and false.
|
// If the key doesn't exist, it returns a zero value Ref and false.
|
||||||
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||||
rm.refCountMu.Lock()
|
rm.mu.Lock()
|
||||||
defer rm.refCountMu.Unlock()
|
defer rm.mu.Unlock()
|
||||||
|
|
||||||
ref, ok := rm.refCountMap[key]
|
ref, ok := rm.refCountMap[key]
|
||||||
return ref, ok
|
return ref, ok
|
||||||
@@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
|||||||
// Increment increments the reference count for the given key.
|
// Increment increments the reference count for the given key.
|
||||||
// If this is the first reference to the key, the AddFunc is called.
|
// If this is the first reference to the key, the AddFunc is called.
|
||||||
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
|
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
|
||||||
rm.refCountMu.Lock()
|
rm.mu.Lock()
|
||||||
defer rm.refCountMu.Unlock()
|
defer rm.mu.Unlock()
|
||||||
|
|
||||||
|
return rm.increment(key, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
|
||||||
ref := rm.refCountMap[key]
|
ref := rm.refCountMap[key]
|
||||||
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
|
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
|
||||||
|
|
||||||
@@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
|
|||||||
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
|
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
|
||||||
// If this is the first reference to the key, the AddFunc is called.
|
// If this is the first reference to the key, the AddFunc is called.
|
||||||
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
|
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
|
||||||
rm.idMu.Lock()
|
rm.mu.Lock()
|
||||||
defer rm.idMu.Unlock()
|
defer rm.mu.Unlock()
|
||||||
|
|
||||||
ref, err := rm.Increment(key, in)
|
ref, err := rm.increment(key, in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ref, fmt.Errorf("with ID: %w", err)
|
return ref, fmt.Errorf("with ID: %w", err)
|
||||||
}
|
}
|
||||||
@@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
|
|||||||
// Decrement decrements the reference count for the given key.
|
// Decrement decrements the reference count for the given key.
|
||||||
// If the reference count reaches 0, the RemoveFunc is called.
|
// If the reference count reaches 0, the RemoveFunc is called.
|
||||||
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
|
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
|
||||||
rm.refCountMu.Lock()
|
rm.mu.Lock()
|
||||||
defer rm.refCountMu.Unlock()
|
defer rm.mu.Unlock()
|
||||||
|
return rm.decrement(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
|
||||||
ref, ok := rm.refCountMap[key]
|
ref, ok := rm.refCountMap[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
logCallerF("No reference found for key %v", key)
|
logCallerF("No reference found for key %v", key)
|
||||||
@@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
|
|||||||
// DecrementWithID decrements the reference count for all keys associated with the given ID.
|
// DecrementWithID decrements the reference count for all keys associated with the given ID.
|
||||||
// If the reference count reaches 0, the RemoveFunc is called.
|
// If the reference count reaches 0, the RemoveFunc is called.
|
||||||
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
||||||
rm.idMu.Lock()
|
rm.mu.Lock()
|
||||||
defer rm.idMu.Unlock()
|
defer rm.mu.Unlock()
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for _, key := range rm.idMap[id] {
|
for _, key := range rm.idMap[id] {
|
||||||
if _, err := rm.Decrement(key); err != nil {
|
if _, err := rm.decrement(key); err != nil {
|
||||||
merr = multierror.Append(merr, err)
|
merr = multierror.Append(merr, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
|||||||
|
|
||||||
// Flush removes all references and calls RemoveFunc for each key.
|
// Flush removes all references and calls RemoveFunc for each key.
|
||||||
func (rm *Counter[Key, I, O]) Flush() error {
|
func (rm *Counter[Key, I, O]) Flush() error {
|
||||||
rm.refCountMu.Lock()
|
rm.mu.Lock()
|
||||||
defer rm.refCountMu.Unlock()
|
defer rm.mu.Unlock()
|
||||||
rm.idMu.Lock()
|
|
||||||
defer rm.idMu.Unlock()
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for key := range rm.refCountMap {
|
for key := range rm.refCountMap {
|
||||||
@@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error {
|
|||||||
|
|
||||||
// Clear removes all references without calling RemoveFunc.
|
// Clear removes all references without calling RemoveFunc.
|
||||||
func (rm *Counter[Key, I, O]) Clear() {
|
func (rm *Counter[Key, I, O]) Clear() {
|
||||||
rm.refCountMu.Lock()
|
rm.mu.Lock()
|
||||||
defer rm.refCountMu.Unlock()
|
defer rm.mu.Unlock()
|
||||||
rm.idMu.Lock()
|
|
||||||
defer rm.idMu.Unlock()
|
|
||||||
|
|
||||||
clear(rm.refCountMap)
|
clear(rm.refCountMap)
|
||||||
clear(rm.idMap)
|
clear(rm.idMap)
|
||||||
@@ -217,6 +217,9 @@ func (rm *Counter[Key, I, O]) Clear() {
|
|||||||
|
|
||||||
// MarshalJSON implements the json.Marshaler interface for Counter.
|
// MarshalJSON implements the json.Marshaler interface for Counter.
|
||||||
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||||
|
rm.mu.Lock()
|
||||||
|
defer rm.mu.Unlock()
|
||||||
|
|
||||||
return json.Marshal(struct {
|
return json.Marshal(struct {
|
||||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||||
IDMap map[string][]Key `json:"idMap"`
|
IDMap map[string][]Key `json:"idMap"`
|
||||||
|
|||||||
@@ -2,31 +2,28 @@ package systemops
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ShutdownState struct {
|
type ShutdownState ExclusionCounter
|
||||||
Counter *ExclusionCounter `json:"counter,omitempty"`
|
|
||||||
mu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ShutdownState) Name() string {
|
func (s *ShutdownState) Name() string {
|
||||||
return "route_state"
|
return "route_state"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Cleanup() error {
|
func (s *ShutdownState) Cleanup() error {
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
|
|
||||||
if s.Counter == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sysops := NewSysOps(nil, nil)
|
sysops := NewSysOps(nil, nil)
|
||||||
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
||||||
sysops.refCounter.LoadData(s.Counter)
|
sysops.refCounter.LoadData((*ExclusionCounter)(s))
|
||||||
|
|
||||||
return sysops.refCounter.Flush()
|
return sysops.refCounter.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||||
|
return (*ExclusionCounter)(s).MarshalJSON()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
|
||||||
|
return (*ExclusionCounter)(s).UnmarshalJSON(data)
|
||||||
|
}
|
||||||
|
|||||||
@@ -57,30 +57,19 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
|||||||
return nexthop, refcounter.ErrIgnore
|
return nexthop, refcounter.ErrIgnore
|
||||||
}
|
}
|
||||||
|
|
||||||
r.updateState(stateManager)
|
|
||||||
|
|
||||||
return nexthop, err
|
return nexthop, err
|
||||||
},
|
},
|
||||||
func(prefix netip.Prefix, nexthop Nexthop) error {
|
r.removeFromRouteTable,
|
||||||
// remove from state even if we have trouble removing it from the route table
|
|
||||||
// it could be already gone
|
|
||||||
r.updateState(stateManager)
|
|
||||||
|
|
||||||
return r.removeFromRouteTable(prefix, nexthop)
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
r.refCounter = refCounter
|
r.refCounter = refCounter
|
||||||
|
|
||||||
return r.setupHooks(initAddresses)
|
return r.setupHooks(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateState updates state on every change so it will be persisted regularly
|
||||||
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
|
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
|
||||||
state := getState(stateManager)
|
if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
|
||||||
|
|
||||||
state.Counter = r.refCounter
|
|
||||||
|
|
||||||
if err := stateManager.UpdateState(state); err != nil {
|
|
||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -336,7 +325,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
|||||||
return r.removeFromRouteTable(prefix, nextHop)
|
return r.removeFromRouteTable(prefix, nextHop)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||||
prefix, err := util.GetPrefixFromIP(ip)
|
prefix, err := util.GetPrefixFromIP(ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -347,6 +336,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
|
|||||||
return fmt.Errorf("adding route reference: %v", err)
|
return fmt.Errorf("adding route reference: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState(stateManager)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
afterHook := func(connID nbnet.ConnectionID) error {
|
afterHook := func(connID nbnet.ConnectionID) error {
|
||||||
@@ -354,6 +345,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
|
|||||||
return fmt.Errorf("remove route reference: %w", err)
|
return fmt.Errorf("remove route reference: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState(stateManager)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -532,14 +525,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
|
|||||||
// Return true if the longest matching prefix is from vpnRoutes
|
// Return true if the longest matching prefix is from vpnRoutes
|
||||||
return isVpn, longestPrefix
|
return isVpn, longestPrefix
|
||||||
}
|
}
|
||||||
|
|
||||||
func getState(stateManager *statemanager.Manager) *ShutdownState {
|
|
||||||
var shutdownState *ShutdownState
|
|
||||||
if state := stateManager.GetState(shutdownState); state != nil {
|
|
||||||
shutdownState = state.(*ShutdownState)
|
|
||||||
} else {
|
|
||||||
shutdownState = &ShutdownState{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return shutdownState
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ type ruleParams struct {
|
|||||||
|
|
||||||
// isLegacy determines whether to use the legacy routing setup
|
// isLegacy determines whether to use the legacy routing setup
|
||||||
func isLegacy() bool {
|
func isLegacy() bool {
|
||||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
|
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
|
||||||
}
|
}
|
||||||
|
|
||||||
// setIsLegacy sets the legacy routing setup
|
// setIsLegacy sets the legacy routing setup
|
||||||
@@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = addRoutingTableName(); err != nil {
|
|
||||||
log.Errorf("Error adding routing table name: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
originalValues, err := sysctl.Setup(r.wgInterface)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Error setting up sysctl: %v", err)
|
|
||||||
sysctlFailed = true
|
|
||||||
}
|
|
||||||
originalSysctl = originalValues
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
||||||
@@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = addRoutingTableName(); err != nil {
|
||||||
|
log.Errorf("Error adding routing table name: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
originalValues, err := sysctl.Setup(r.wgInterface)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Error setting up sysctl: %v", err)
|
||||||
|
sysctlFailed = true
|
||||||
|
}
|
||||||
|
originalSysctl = originalValues
|
||||||
|
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -450,7 +450,7 @@ func addRule(params ruleParams) error {
|
|||||||
rule.Invert = params.invert
|
rule.Invert = params.invert
|
||||||
rule.SuppressPrefixlen = params.suppressPrefix
|
rule.SuppressPrefixlen = params.suppressPrefix
|
||||||
|
|
||||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||||
return fmt.Errorf("add routing rule: %w", err)
|
return fmt.Errorf("add routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -467,7 +467,7 @@ func removeRule(params ruleParams) error {
|
|||||||
rule.Priority = params.priority
|
rule.Priority = params.priority
|
||||||
rule.SuppressPrefixlen = params.suppressPrefix
|
rule.SuppressPrefixlen = params.suppressPrefix
|
||||||
|
|
||||||
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) {
|
||||||
return fmt.Errorf("remove routing rule: %w", err)
|
return fmt.Errorf("remove routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// State interface defines the methods that all state types must implement
|
// State interface defines the methods that all state types must implement
|
||||||
@@ -73,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error {
|
|||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
if m.cancel != nil {
|
if m.cancel == nil {
|
||||||
m.cancel()
|
return nil
|
||||||
|
}
|
||||||
|
m.cancel()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case <-m.done:
|
case <-m.done:
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -178,25 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
bs, err := marshalWithPanicRecovery(m.states)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal states: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
done := make(chan error, 1)
|
done := make(chan error, 1)
|
||||||
|
start := time.Now()
|
||||||
go func() {
|
go func() {
|
||||||
data, err := json.MarshalIndent(m.states, "", " ")
|
done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
|
||||||
if err != nil {
|
|
||||||
done <- fmt.Errorf("marshal states: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// nolint:gosec
|
|
||||||
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
|
|
||||||
done <- fmt.Errorf("write state file: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
done <- nil
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -208,7 +202,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
|
log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
|
||||||
|
|
||||||
clear(m.dirty)
|
clear(m.dirty)
|
||||||
|
|
||||||
@@ -296,3 +290,19 @@ func (m *Manager) PerformCleanup() error {
|
|||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func marshalWithPanicRecovery(v any) ([]byte, error) {
|
||||||
|
var bs []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic during marshal: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
bs, err = json.Marshal(v)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return bs, err
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,32 +4,20 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetDefaultStatePath returns the path to the state file based on the operating system
|
// GetDefaultStatePath returns the path to the state file based on the operating system
|
||||||
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
|
// It returns an empty string if the path cannot be determined.
|
||||||
func GetDefaultStatePath() string {
|
func GetDefaultStatePath() string {
|
||||||
var path string
|
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "windows":
|
case "windows":
|
||||||
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
||||||
case "darwin", "linux":
|
case "darwin", "linux":
|
||||||
path = "/var/lib/netbird/state.json"
|
return "/var/lib/netbird/state.json"
|
||||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||||
path = "/var/db/netbird/state.json"
|
return "/var/db/netbird/state.json"
|
||||||
// ios/android don't need state
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dir := filepath.Dir(path)
|
return ""
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
||||||
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return path
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ func (c *Client) Run(fd int32, interfaceName string) error {
|
|||||||
c.onHostDnsFn = func([]string) {}
|
c.onHostDnsFn = func([]string) {}
|
||||||
cfg.WgIface = interfaceName
|
cfg.WgIface = interfaceName
|
||||||
|
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, &system.StaticInfo{})
|
||||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager)
|
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,7 +204,7 @@ func (c *Client) IsLoginRequired() bool {
|
|||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
|
|
||||||
needsLogin, _ := internal.IsLoginRequired(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg.SSHKey)
|
needsLogin, _ := internal.IsLoginRequired(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg.SSHKey, &system.StaticInfo{})
|
||||||
return needsLogin
|
return needsLogin
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -244,7 +244,7 @@ func (c *Client) LoginForMobile() string {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwtToken := tokenInfo.GetTokenToUse()
|
jwtToken := tokenInfo.GetTokenToUse()
|
||||||
_ = internal.Login(ctx, cfg, "", jwtToken)
|
_ = internal.Login(ctx, cfg, "", jwtToken, &system.StaticInfo{})
|
||||||
c.loginComplete = true
|
c.loginComplete = true
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
|||||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
|
||||||
err := a.withBackOff(a.ctx, func() error {
|
err := a.withBackOff(a.ctx, func() error {
|
||||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "", &system.StaticInfo{})
|
||||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
// we got an answer from management, exit backoff earlier
|
// we got an answer from management, exit backoff earlier
|
||||||
return backoff.Permanent(backoffErr)
|
return backoff.Permanent(backoffErr)
|
||||||
@@ -123,7 +123,7 @@ func (a *Auth) Login() error {
|
|||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey, &system.StaticInfo{})
|
||||||
return
|
return
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -136,7 +136,7 @@ func (a *Auth) Login() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err = a.withBackOff(a.ctx, func() error {
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
err := internal.Login(a.ctx, a.config, "", jwtToken, &system.StaticInfo{})
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
7
client/server/panic_generic.go
Normal file
7
client/server/panic_generic.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
func handlePanicLog() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
83
client/server/panic_windows.go
Normal file
83
client/server/panic_windows.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG"
|
||||||
|
// STD_ERROR_HANDLE ((DWORD)-12) = 4294967284
|
||||||
|
stdErrorHandle = ^uintptr(11)
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||||
|
|
||||||
|
// https://learn.microsoft.com/en-us/windows/console/setstdhandle
|
||||||
|
setStdHandleFn = kernel32.NewProc("SetStdHandle")
|
||||||
|
)
|
||||||
|
|
||||||
|
func handlePanicLog() error {
|
||||||
|
logPath := os.Getenv(windowsPanicLogEnvVar)
|
||||||
|
if logPath == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the directory exists
|
||||||
|
logDir := filepath.Dir(logPath)
|
||||||
|
if err := os.MkdirAll(logDir, 0750); err != nil {
|
||||||
|
return fmt.Errorf("create panic log directory: %w", err)
|
||||||
|
}
|
||||||
|
if err := util.EnforcePermission(logPath); err != nil {
|
||||||
|
return fmt.Errorf("enforce permission on panic log file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open log file with append mode
|
||||||
|
f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open panic log file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect stderr to the file
|
||||||
|
if err = redirectStderr(f); err != nil {
|
||||||
|
if closeErr := f.Close(); closeErr != nil {
|
||||||
|
log.Warnf("failed to close file after redirect error: %v", closeErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("redirect stderr: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("successfully configured panic logging to: %s", logPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// redirectStderr redirects stderr to the provided file
|
||||||
|
func redirectStderr(f *os.File) error {
|
||||||
|
// Get the current process's stderr handle
|
||||||
|
if err := setStdHandle(f); err != nil {
|
||||||
|
return fmt.Errorf("failed to set stderr handle: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also set os.Stderr for Go's standard library
|
||||||
|
os.Stderr = f
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setStdHandle(f *os.File) error {
|
||||||
|
handle := f.Fd()
|
||||||
|
r0, _, e1 := setStdHandleFn.Call(stdErrorHandle, handle)
|
||||||
|
if r0 == 0 {
|
||||||
|
if e1 != nil {
|
||||||
|
return e1
|
||||||
|
}
|
||||||
|
return syscall.EINVAL
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -68,6 +68,7 @@ type Server struct {
|
|||||||
relayProbe *internal.Probe
|
relayProbe *internal.Probe
|
||||||
wgProbe *internal.Probe
|
wgProbe *internal.Probe
|
||||||
lastProbe time.Time
|
lastProbe time.Time
|
||||||
|
staticInfo *system.StaticInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
type oauthAuthFlow struct {
|
type oauthAuthFlow struct {
|
||||||
@@ -79,6 +80,8 @@ type oauthAuthFlow struct {
|
|||||||
|
|
||||||
// New server instance constructor.
|
// New server instance constructor.
|
||||||
func New(ctx context.Context, configPath, logFile string) *Server {
|
func New(ctx context.Context, configPath, logFile string) *Server {
|
||||||
|
staticInfoChan := system.GetStaticInfoInBackground(ctx)
|
||||||
|
staticInfo := <-staticInfoChan
|
||||||
return &Server{
|
return &Server{
|
||||||
rootCtx: ctx,
|
rootCtx: ctx,
|
||||||
latestConfigInput: internal.ConfigInput{
|
latestConfigInput: internal.ConfigInput{
|
||||||
@@ -89,6 +92,7 @@ func New(ctx context.Context, configPath, logFile string) *Server {
|
|||||||
signalProbe: internal.NewProbe(),
|
signalProbe: internal.NewProbe(),
|
||||||
relayProbe: internal.NewProbe(),
|
relayProbe: internal.NewProbe(),
|
||||||
wgProbe: internal.NewProbe(),
|
wgProbe: internal.NewProbe(),
|
||||||
|
staticInfo: staticInfo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,6 +101,10 @@ func (s *Server) Start() error {
|
|||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
state := internal.CtxGetState(s.rootCtx)
|
||||||
|
|
||||||
|
if err := handlePanicLog(); err != nil {
|
||||||
|
log.Warnf("failed to redirect stderr: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := restoreResidualState(s.rootCtx); err != nil {
|
if err := restoreResidualState(s.rootCtx); err != nil {
|
||||||
log.Warnf(errRestoreResidualState, err)
|
log.Warnf(errRestoreResidualState, err)
|
||||||
}
|
}
|
||||||
@@ -191,7 +199,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
|
|||||||
|
|
||||||
runOperation := func() error {
|
runOperation := func() error {
|
||||||
log.Tracef("running client connection")
|
log.Tracef("running client connection")
|
||||||
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
|
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, s.staticInfo)
|
||||||
|
|
||||||
probes := internal.ProbeHolder{
|
probes := internal.ProbeHolder{
|
||||||
MgmProbe: s.mgmProbe,
|
MgmProbe: s.mgmProbe,
|
||||||
@@ -268,7 +276,7 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
|
|||||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||||
var status internal.StatusType
|
var status internal.StatusType
|
||||||
err := internal.Login(ctx, s.config, setupKey, jwtToken)
|
err := internal.Login(ctx, s.config, setupKey, jwtToken, s.staticInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
log.Warnf("failed login: %v", err)
|
log.Warnf("failed login: %v", err)
|
||||||
@@ -622,6 +630,8 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
|||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
s.oauthAuthFlow = oauthAuthFlow{}
|
||||||
|
|
||||||
if s.actCancel == nil {
|
if s.actCancel == nil {
|
||||||
return nil, fmt.Errorf("service is not up")
|
return nil, fmt.Errorf("service is not up")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,6 +61,14 @@ type Info struct {
|
|||||||
Files []File // for posture checks
|
Files []File // for posture checks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StaticInfo is an object that contains machine information that does not change
|
||||||
|
type StaticInfo struct {
|
||||||
|
SystemSerialNumber string
|
||||||
|
SystemProductName string
|
||||||
|
SystemManufacturer string
|
||||||
|
Environment Environment
|
||||||
|
}
|
||||||
|
|
||||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||||
func extractUserAgent(ctx context.Context) string {
|
func extractUserAgent(ctx context.Context) string {
|
||||||
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
||||||
@@ -142,7 +150,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
||||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, staticInfo *StaticInfo) (*Info, error) {
|
||||||
processCheckPaths := make([]string, 0)
|
processCheckPaths := make([]string, 0)
|
||||||
for _, check := range checks {
|
for _, check := range checks {
|
||||||
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
|
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
|
||||||
@@ -153,8 +161,17 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
info := GetInfo(ctx)
|
info := GetInfo(ctx, staticInfo)
|
||||||
info.Files = files
|
info.Files = files
|
||||||
|
|
||||||
return info, nil
|
return info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStaticInfoInBackground retrieves and parses the system information in the background
|
||||||
|
func GetStaticInfoInBackground(ctx context.Context) <-chan *StaticInfo {
|
||||||
|
ch := make(chan *StaticInfo)
|
||||||
|
go func() {
|
||||||
|
ch <- getStaticInfo(ctx)
|
||||||
|
}()
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
func GetInfo(ctx context.Context) *Info {
|
func GetInfo(ctx context.Context, _ *StaticInfo) *Info {
|
||||||
kernel := "android"
|
kernel := "android"
|
||||||
osInfo := uname()
|
osInfo := uname()
|
||||||
if len(osInfo) == 2 {
|
if len(osInfo) == 2 {
|
||||||
@@ -44,6 +44,10 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
return gio
|
return gio
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getStaticInfo(ctx context.Context) *StaticInfo {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||||
return []File{}, nil
|
return []File{}, nil
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
func GetInfo(ctx context.Context) *Info {
|
func GetInfo(ctx context.Context, staticInfo *StaticInfo) *Info {
|
||||||
utsname := unix.Utsname{}
|
utsname := unix.Utsname{}
|
||||||
err := unix.Uname(&utsname)
|
err := unix.Uname(&utsname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -41,26 +41,22 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
log.Warnf("failed to discover network addresses: %s", err)
|
log.Warnf("failed to discover network addresses: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
serialNum, prodName, manufacturer := sysInfo()
|
gio := &Info{
|
||||||
|
Kernel: sysName,
|
||||||
env := Environment{
|
OSVersion: strings.TrimSpace(string(swVersion)),
|
||||||
Cloud: detect_cloud.Detect(ctx),
|
Platform: machine,
|
||||||
Platform: detect_platform.Detect(ctx),
|
OS: sysName,
|
||||||
|
GoOS: runtime.GOOS,
|
||||||
|
CPUs: runtime.NumCPU(),
|
||||||
|
KernelVersion: release,
|
||||||
|
NetworkAddresses: addrs,
|
||||||
}
|
}
|
||||||
|
|
||||||
gio := &Info{
|
if staticInfo != nil {
|
||||||
Kernel: sysName,
|
gio.SystemSerialNumber = staticInfo.SystemSerialNumber
|
||||||
OSVersion: strings.TrimSpace(string(swVersion)),
|
gio.SystemProductName = staticInfo.SystemProductName
|
||||||
Platform: machine,
|
gio.SystemManufacturer = staticInfo.SystemManufacturer
|
||||||
OS: sysName,
|
gio.Environment = staticInfo.Environment
|
||||||
GoOS: runtime.GOOS,
|
|
||||||
CPUs: runtime.NumCPU(),
|
|
||||||
KernelVersion: release,
|
|
||||||
NetworkAddresses: addrs,
|
|
||||||
SystemSerialNumber: serialNum,
|
|
||||||
SystemProductName: prodName,
|
|
||||||
SystemManufacturer: manufacturer,
|
|
||||||
Environment: env,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
systemHostname, _ := os.Hostname()
|
systemHostname, _ := os.Hostname()
|
||||||
@@ -71,6 +67,21 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
return gio
|
return gio
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getStaticInfo(ctx context.Context) *StaticInfo {
|
||||||
|
serialNum, prodName, manufacturer := sysInfo()
|
||||||
|
env := Environment{
|
||||||
|
Cloud: detect_cloud.Detect(ctx),
|
||||||
|
Platform: detect_platform.Detect(ctx),
|
||||||
|
}
|
||||||
|
|
||||||
|
return &StaticInfo{
|
||||||
|
SystemSerialNumber: serialNum,
|
||||||
|
SystemProductName: prodName,
|
||||||
|
SystemManufacturer: manufacturer,
|
||||||
|
Environment: env,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||||
out, _ := exec.Command("/usr/sbin/ioreg", "-l").Output() // err ignored for brevity
|
out, _ := exec.Command("/usr/sbin/ioreg", "-l").Output() // err ignored for brevity
|
||||||
for _, l := range strings.Split(string(out), "\n") {
|
for _, l := range strings.Split(string(out), "\n") {
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
func GetInfo(ctx context.Context) *Info {
|
func GetInfo(ctx context.Context, staticInfo *StaticInfo) *Info {
|
||||||
out := _getInfo()
|
out := _getInfo()
|
||||||
for strings.Contains(out, "broken pipe") {
|
for strings.Contains(out, "broken pipe") {
|
||||||
out = _getInfo()
|
out = _getInfo()
|
||||||
@@ -29,16 +29,11 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
osStr = strings.ReplaceAll(osStr, "\r\n", "")
|
osStr = strings.ReplaceAll(osStr, "\r\n", "")
|
||||||
osInfo := strings.Split(osStr, " ")
|
osInfo := strings.Split(osStr, " ")
|
||||||
|
|
||||||
env := Environment{
|
|
||||||
Cloud: detect_cloud.Detect(ctx),
|
|
||||||
Platform: detect_platform.Detect(ctx),
|
|
||||||
}
|
|
||||||
|
|
||||||
osName, osVersion := readOsReleaseFile()
|
osName, osVersion := readOsReleaseFile()
|
||||||
|
|
||||||
systemHostname, _ := os.Hostname()
|
systemHostname, _ := os.Hostname()
|
||||||
|
|
||||||
return &Info{
|
info := &Info{
|
||||||
GoOS: runtime.GOOS,
|
GoOS: runtime.GOOS,
|
||||||
Kernel: osInfo[0],
|
Kernel: osInfo[0],
|
||||||
Platform: runtime.GOARCH,
|
Platform: runtime.GOARCH,
|
||||||
@@ -49,7 +44,25 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
WiretrusteeVersion: version.NetbirdVersion(),
|
WiretrusteeVersion: version.NetbirdVersion(),
|
||||||
UIVersion: extractUserAgent(ctx),
|
UIVersion: extractUserAgent(ctx),
|
||||||
KernelVersion: osInfo[1],
|
KernelVersion: osInfo[1],
|
||||||
Environment: env,
|
}
|
||||||
|
if staticInfo != nil {
|
||||||
|
info.SystemSerialNumber = staticInfo.SystemSerialNumber
|
||||||
|
info.SystemProductName = staticInfo.SystemProductName
|
||||||
|
info.SystemManufacturer = staticInfo.SystemManufacturer
|
||||||
|
info.Environment = staticInfo.Environment
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStaticInfo(ctx context.Context) *StaticInfo {
|
||||||
|
env := Environment{
|
||||||
|
Cloud: detect_cloud.Detect(ctx),
|
||||||
|
Platform: detect_platform.Detect(ctx),
|
||||||
|
}
|
||||||
|
|
||||||
|
return &StaticInfo{
|
||||||
|
Environment: env,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
func GetInfo(ctx context.Context) *Info {
|
func GetInfo(ctx context.Context, _ *StaticInfo) *Info {
|
||||||
|
|
||||||
// Convert fixed-size byte arrays to Go strings
|
// Convert fixed-size byte arrays to Go strings
|
||||||
sysName := extractOsName(ctx, "sysName")
|
sysName := extractOsName(ctx, "sysName")
|
||||||
@@ -25,6 +25,10 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
return gio
|
return gio
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getStaticInfo(ctx context.Context) *StaticInfo {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||||
return []File{}, nil
|
return []File{}, nil
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (s SysInfoWrapper) GetSysInfo() SysInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
func GetInfo(ctx context.Context) *Info {
|
func GetInfo(ctx context.Context, staticInfo *StaticInfo) *Info {
|
||||||
info := _getInfo()
|
info := _getInfo()
|
||||||
for strings.Contains(info, "broken pipe") {
|
for strings.Contains(info, "broken pipe") {
|
||||||
info = _getInfo()
|
info = _getInfo()
|
||||||
@@ -65,14 +65,6 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
log.Warnf("failed to discover network addresses: %s", err)
|
log.Warnf("failed to discover network addresses: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
si := SysInfoWrapper{}
|
|
||||||
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
|
|
||||||
|
|
||||||
env := Environment{
|
|
||||||
Cloud: detect_cloud.Detect(ctx),
|
|
||||||
Platform: detect_platform.Detect(ctx),
|
|
||||||
}
|
|
||||||
|
|
||||||
gio := &Info{
|
gio := &Info{
|
||||||
Kernel: osInfo[0],
|
Kernel: osInfo[0],
|
||||||
Platform: osInfo[2],
|
Platform: osInfo[2],
|
||||||
@@ -85,13 +77,32 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
UIVersion: extractUserAgent(ctx),
|
UIVersion: extractUserAgent(ctx),
|
||||||
KernelVersion: osInfo[1],
|
KernelVersion: osInfo[1],
|
||||||
NetworkAddresses: addrs,
|
NetworkAddresses: addrs,
|
||||||
|
}
|
||||||
|
|
||||||
|
if staticInfo != nil {
|
||||||
|
gio.SystemSerialNumber = staticInfo.SystemSerialNumber
|
||||||
|
gio.SystemProductName = staticInfo.SystemProductName
|
||||||
|
gio.SystemManufacturer = staticInfo.SystemManufacturer
|
||||||
|
gio.Environment = staticInfo.Environment
|
||||||
|
}
|
||||||
|
|
||||||
|
return gio
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStaticInfo(ctx context.Context) *StaticInfo {
|
||||||
|
si := SysInfoWrapper{}
|
||||||
|
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
|
||||||
|
env := Environment{
|
||||||
|
Cloud: detect_cloud.Detect(ctx),
|
||||||
|
Platform: detect_platform.Detect(ctx),
|
||||||
|
}
|
||||||
|
|
||||||
|
return &StaticInfo{
|
||||||
SystemSerialNumber: serialNum,
|
SystemSerialNumber: serialNum,
|
||||||
SystemProductName: prodName,
|
SystemProductName: prodName,
|
||||||
SystemManufacturer: manufacturer,
|
SystemManufacturer: manufacturer,
|
||||||
Environment: env,
|
Environment: env,
|
||||||
}
|
}
|
||||||
|
|
||||||
return gio
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func _getInfo() string {
|
func _getInfo() string {
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func Test_LocalWTVersion(t *testing.T) {
|
func Test_LocalWTVersion(t *testing.T) {
|
||||||
got := GetInfo(context.TODO())
|
got := GetInfo(context.TODO(), nil)
|
||||||
want := "development"
|
want := "development"
|
||||||
assert.Equal(t, want, got.WiretrusteeVersion)
|
assert.Equal(t, want, got.WiretrusteeVersion)
|
||||||
}
|
}
|
||||||
@@ -21,7 +21,7 @@ func Test_UIVersion(t *testing.T) {
|
|||||||
"user-agent": {want},
|
"user-agent": {want},
|
||||||
})
|
})
|
||||||
|
|
||||||
got := GetInfo(ctx)
|
got := GetInfo(ctx, nil)
|
||||||
assert.Equal(t, want, got.UIVersion)
|
assert.Equal(t, want, got.UIVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ func Test_CustomHostname(t *testing.T) {
|
|||||||
ctx := context.WithValue(context.Background(), DeviceNameCtxKey, "custom-host")
|
ctx := context.WithValue(context.Background(), DeviceNameCtxKey, "custom-host")
|
||||||
want := "custom-host"
|
want := "custom-host"
|
||||||
|
|
||||||
got := GetInfo(ctx)
|
got := GetInfo(ctx, nil)
|
||||||
assert.Equal(t, want, got.Hostname)
|
assert.Equal(t, want, got.Hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ type Win32_BIOS struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
func GetInfo(ctx context.Context) *Info {
|
func GetInfo(ctx context.Context, staticInfo *StaticInfo) *Info {
|
||||||
osName, osVersion := getOSNameAndVersion()
|
osName, osVersion := getOSNameAndVersion()
|
||||||
buildVersion := getBuildVersion()
|
buildVersion := getBuildVersion()
|
||||||
|
|
||||||
@@ -42,39 +42,22 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
log.Warnf("failed to discover network addresses: %s", err)
|
log.Warnf("failed to discover network addresses: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
serialNum, err := sysNumber()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to get system serial number: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
prodName, err := sysProductName()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to get system product name: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
manufacturer, err := sysManufacturer()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to get system manufacturer: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
env := Environment{
|
|
||||||
Cloud: detect_cloud.Detect(ctx),
|
|
||||||
Platform: detect_platform.Detect(ctx),
|
|
||||||
}
|
|
||||||
|
|
||||||
gio := &Info{
|
gio := &Info{
|
||||||
Kernel: "windows",
|
Kernel: "windows",
|
||||||
OSVersion: osVersion,
|
OSVersion: osVersion,
|
||||||
Platform: "unknown",
|
Platform: "unknown",
|
||||||
OS: osName,
|
OS: osName,
|
||||||
GoOS: runtime.GOOS,
|
GoOS: runtime.GOOS,
|
||||||
CPUs: runtime.NumCPU(),
|
CPUs: runtime.NumCPU(),
|
||||||
KernelVersion: buildVersion,
|
KernelVersion: buildVersion,
|
||||||
NetworkAddresses: addrs,
|
NetworkAddresses: addrs,
|
||||||
SystemSerialNumber: serialNum,
|
}
|
||||||
SystemProductName: prodName,
|
|
||||||
SystemManufacturer: manufacturer,
|
if staticInfo != nil {
|
||||||
Environment: env,
|
gio.SystemSerialNumber = staticInfo.SystemSerialNumber
|
||||||
|
gio.SystemProductName = staticInfo.SystemProductName
|
||||||
|
gio.SystemManufacturer = staticInfo.SystemManufacturer
|
||||||
|
gio.Environment = staticInfo.Environment
|
||||||
}
|
}
|
||||||
|
|
||||||
systemHostname, _ := os.Hostname()
|
systemHostname, _ := os.Hostname()
|
||||||
@@ -85,6 +68,41 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
return gio
|
return gio
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getStaticInfo(ctx context.Context) *StaticInfo {
|
||||||
|
serialNum, prodName, manufacturer := sysInfo()
|
||||||
|
env := Environment{
|
||||||
|
Cloud: detect_cloud.Detect(ctx),
|
||||||
|
Platform: detect_platform.Detect(ctx),
|
||||||
|
}
|
||||||
|
|
||||||
|
return &StaticInfo{
|
||||||
|
SystemSerialNumber: serialNum,
|
||||||
|
SystemProductName: prodName,
|
||||||
|
SystemManufacturer: manufacturer,
|
||||||
|
Environment: env,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||||
|
var err error
|
||||||
|
serialNumber, err = sysNumber()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get system serial number: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
productName, err = sysProductName()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get system product name: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
manufacturer, err = sysManufacturer()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get system manufacturer: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return serialNumber, productName, manufacturer
|
||||||
|
}
|
||||||
|
|
||||||
func getOSNameAndVersion() (string, string) {
|
func getOSNameAndVersion() (string, string) {
|
||||||
var dst []Win32_OperatingSystem
|
var dst []Win32_OperatingSystem
|
||||||
query := wmi.CreateQuery(&dst, "")
|
query := wmi.CreateQuery(&dst, "")
|
||||||
|
|||||||
126
funding.json
Normal file
126
funding.json
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
{
|
||||||
|
"version": "v1.0.0",
|
||||||
|
"entity": {
|
||||||
|
"type": "organisation",
|
||||||
|
"role": "owner",
|
||||||
|
"name": "NetBird GmbH",
|
||||||
|
"email": "hello@netbird.io",
|
||||||
|
"phone": "",
|
||||||
|
"description": "NetBird GmbH is a Berlin-based software company specializing in the development of open-source network security solutions. Network security is utterly complex and expensive, accessible only to companies with multi-million dollar IT budgets. In contrast, there are millions of companies left behind. Our mission is to create an advanced network and cybersecurity platform that is both easy-to-use and affordable for teams of all sizes and budgets. By leveraging the open-source strategy and technological advancements, NetBird aims to set the industry standard for connecting and securing IT infrastructure.",
|
||||||
|
"webpageUrl": {
|
||||||
|
"url": "https://github.com/netbirdio"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"projects": [
|
||||||
|
{
|
||||||
|
"guid": "netbird",
|
||||||
|
"name": "NetBird",
|
||||||
|
"description": "NetBird is a configuration-free peer-to-peer private network and a centralized access control system combined in a single open-source platform. It makes it easy to create secure WireGuard-based private networks for your organization or home.",
|
||||||
|
"webpageUrl": {
|
||||||
|
"url": "https://github.com/netbirdio/netbird"
|
||||||
|
},
|
||||||
|
"repositoryUrl": {
|
||||||
|
"url": "https://github.com/netbirdio/netbird"
|
||||||
|
},
|
||||||
|
"licenses": [
|
||||||
|
"BSD-3"
|
||||||
|
],
|
||||||
|
"tags": [
|
||||||
|
"network-security",
|
||||||
|
"vpn",
|
||||||
|
"developer-tools",
|
||||||
|
"ztna",
|
||||||
|
"zero-trust",
|
||||||
|
"remote-access",
|
||||||
|
"wireguard",
|
||||||
|
"peer-to-peer",
|
||||||
|
"private-networking",
|
||||||
|
"software-defined-networking"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"funding": {
|
||||||
|
"channels": [
|
||||||
|
{
|
||||||
|
"guid": "github-sponsors",
|
||||||
|
"type": "payment-provider",
|
||||||
|
"address": "https://github.com/sponsors/netbirdio",
|
||||||
|
"description": ""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"guid": "bank-transfer",
|
||||||
|
"type": "bank",
|
||||||
|
"address": "",
|
||||||
|
"description": "Contact us at hello@netbird.io for bank transfer details."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"plans": [
|
||||||
|
{
|
||||||
|
"guid": "support-yearly",
|
||||||
|
"status": "active",
|
||||||
|
"name": "Support Open Source Development and Maintenance - Yearly",
|
||||||
|
"description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.",
|
||||||
|
"amount": 100000,
|
||||||
|
"currency": "USD",
|
||||||
|
"frequency": "yearly",
|
||||||
|
"channels": [
|
||||||
|
"github-sponsors",
|
||||||
|
"bank-transfer"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"guid": "support-one-time-year",
|
||||||
|
"status": "active",
|
||||||
|
"name": "Support Open Source Development and Maintenance - One Year",
|
||||||
|
"description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.",
|
||||||
|
"amount": 100000,
|
||||||
|
"currency": "USD",
|
||||||
|
"frequency": "one-time",
|
||||||
|
"channels": [
|
||||||
|
"github-sponsors",
|
||||||
|
"bank-transfer"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"guid": "support-one-time-monthly",
|
||||||
|
"status": "active",
|
||||||
|
"name": "Support Open Source Development and Maintenance - Monthly",
|
||||||
|
"description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.",
|
||||||
|
"amount": 10000,
|
||||||
|
"currency": "USD",
|
||||||
|
"frequency": "monthly",
|
||||||
|
"channels": [
|
||||||
|
"github-sponsors",
|
||||||
|
"bank-transfer"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"guid": "support-monthly",
|
||||||
|
"status": "active",
|
||||||
|
"name": "Support Open Source Development and Maintenance - One Month",
|
||||||
|
"description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.",
|
||||||
|
"amount": 10000,
|
||||||
|
"currency": "USD",
|
||||||
|
"frequency": "monthly",
|
||||||
|
"channels": [
|
||||||
|
"github-sponsors",
|
||||||
|
"bank-transfer"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"guid": "goodwill",
|
||||||
|
"status": "active",
|
||||||
|
"name": "Goodwill Plan",
|
||||||
|
"description": "Pay anything you wish to show your goodwill for the project.",
|
||||||
|
"amount": 0,
|
||||||
|
"currency": "USD",
|
||||||
|
"frequency": "monthly",
|
||||||
|
"channels": [
|
||||||
|
"github-sponsors",
|
||||||
|
"bank-transfer"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"history": null
|
||||||
|
}
|
||||||
|
}
|
||||||
11
go.mod
11
go.mod
@@ -60,7 +60,7 @@ require (
|
|||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/nadoo/ipset v0.5.0
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
@@ -71,7 +71,6 @@ require (
|
|||||||
github.com/pion/transport/v3 v3.0.1
|
github.com/pion/transport/v3 v3.0.1
|
||||||
github.com/pion/turn/v3 v3.0.1
|
github.com/pion/turn/v3 v3.0.1
|
||||||
github.com/prometheus/client_golang v1.19.1
|
github.com/prometheus/client_golang v1.19.1
|
||||||
github.com/r3labs/diff/v3 v3.0.1
|
|
||||||
github.com/rs/xid v1.3.0
|
github.com/rs/xid v1.3.0
|
||||||
github.com/shirou/gopsutil/v3 v3.24.4
|
github.com/shirou/gopsutil/v3 v3.24.4
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||||
@@ -156,7 +155,7 @@ require (
|
|||||||
github.com/go-text/typesetting v0.1.0 // indirect
|
github.com/go-text/typesetting v0.1.0 // indirect
|
||||||
github.com/gogo/protobuf v1.3.2 // indirect
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||||
github.com/google/btree v1.0.1 // indirect
|
github.com/google/btree v1.1.2 // indirect
|
||||||
github.com/google/s2a-go v0.1.7 // indirect
|
github.com/google/s2a-go v0.1.7 // indirect
|
||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||||
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
|
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
|
||||||
@@ -211,8 +210,6 @@ require (
|
|||||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||||
github.com/vishvananda/netns v0.0.4 // indirect
|
github.com/vishvananda/netns v0.0.4 // indirect
|
||||||
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
|
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
|
||||||
github.com/yuin/goldmark v1.7.1 // indirect
|
github.com/yuin/goldmark v1.7.1 // indirect
|
||||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||||
go.opencensus.io v0.24.0 // indirect
|
go.opencensus.io v0.24.0 // indirect
|
||||||
@@ -231,7 +228,7 @@ require (
|
|||||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||||
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
|
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 // indirect
|
||||||
k8s.io/apimachinery v0.26.2 // indirect
|
k8s.io/apimachinery v0.26.2 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -239,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
|||||||
|
|
||||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||||
|
|
||||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed
|
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9
|
||||||
|
|
||||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||||
|
|
||||||
|
|||||||
22
go.sum
22
go.sum
@@ -297,8 +297,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
|
|||||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||||
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
||||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
@@ -521,14 +521,14 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||||
@@ -605,8 +605,6 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
|
|||||||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||||
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
|
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
|
||||||
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
|
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
|
||||||
github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg=
|
|
||||||
github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo=
|
|
||||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||||
@@ -699,10 +697,6 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg
|
|||||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||||
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
|
|
||||||
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
|
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
|
||||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
|
||||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
@@ -1238,8 +1232,8 @@ gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4
|
|||||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||||
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
||||||
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
|
||||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||||
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||||
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||||
|
|||||||
@@ -873,7 +873,7 @@ services:
|
|||||||
zitadel:
|
zitadel:
|
||||||
restart: 'always'
|
restart: 'always'
|
||||||
networks: [netbird]
|
networks: [netbird]
|
||||||
image: 'ghcr.io/zitadel/zitadel:v2.54.10'
|
image: 'ghcr.io/zitadel/zitadel:v2.64.1'
|
||||||
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE'
|
||||||
env_file:
|
env_file:
|
||||||
- ./zitadel.env
|
- ./zitadel.env
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
sysInfo := system.GetInfo(context.TODO())
|
sysInfo := system.GetInfo(context.TODO(), nil)
|
||||||
_, err = client.Login(*key, sysInfo, nil)
|
_, err = client.Login(*key, sysInfo, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expecting err on unregistered login, got nil")
|
t.Error("expecting err on unregistered login, got nil")
|
||||||
@@ -202,7 +202,7 @@ func TestClient_LoginRegistered(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
info := system.GetInfo(context.TODO())
|
info := system.GetInfo(context.TODO(), nil)
|
||||||
resp, err := client.Register(*key, ValidKey, "", info, nil)
|
resp, err := client.Register(*key, ValidKey, "", info, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -232,7 +232,7 @@ func TestClient_Sync(t *testing.T) {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
info := system.GetInfo(context.TODO())
|
info := system.GetInfo(context.TODO(), nil)
|
||||||
_, err = client.Register(*serverKey, ValidKey, "", info, nil)
|
_, err = client.Register(*serverKey, ValidKey, "", info, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
@@ -248,7 +248,7 @@ func TestClient_Sync(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
info = system.GetInfo(context.TODO())
|
info = system.GetInfo(context.TODO(), nil)
|
||||||
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil)
|
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -346,7 +346,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
info := system.GetInfo(context.TODO())
|
info := system.GetInfo(context.TODO(), nil)
|
||||||
_, err = testClient.Register(*key, ValidKey, "", info, nil)
|
_, err = testClient.Register(*key, ValidKey, "", info, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error while trying to register client: %v", err)
|
t.Errorf("error while trying to register client: %v", err)
|
||||||
|
|||||||
@@ -110,11 +110,10 @@ type AccountManager interface {
|
|||||||
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
|
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
|
||||||
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
||||||
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||||
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
|
|
||||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
||||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
|
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error)
|
||||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
||||||
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
@@ -140,7 +139,7 @@ type AccountManager interface {
|
|||||||
HasConnectedChannel(peerID string) bool
|
HasConnectedChannel(peerID string) bool
|
||||||
GetExternalCacheManager() ExternalCacheManager
|
GetExternalCacheManager() ExternalCacheManager
|
||||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
|
||||||
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||||
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||||
GetIdpManager() idp.Manager
|
GetIdpManager() idp.Manager
|
||||||
@@ -153,6 +152,7 @@ type AccountManager interface {
|
|||||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
|
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
|
||||||
|
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultAccountManager struct {
|
type DefaultAccountManager struct {
|
||||||
@@ -965,7 +965,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UserGroupsAddToPeers adds groups to all peers of user
|
// UserGroupsAddToPeers adds groups to all peers of user
|
||||||
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string {
|
||||||
|
groupUpdates := make(map[string][]string)
|
||||||
|
|
||||||
userPeers := make(map[string]struct{})
|
userPeers := make(map[string]struct{})
|
||||||
for pid, peer := range a.Peers {
|
for pid, peer := range a.Peers {
|
||||||
if peer.UserID == userID {
|
if peer.UserID == userID {
|
||||||
@@ -979,6 +981,8 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldPeers := group.Peers
|
||||||
|
|
||||||
groupPeers := make(map[string]struct{})
|
groupPeers := make(map[string]struct{})
|
||||||
for _, pid := range group.Peers {
|
for _, pid := range group.Peers {
|
||||||
groupPeers[pid] = struct{}{}
|
groupPeers[pid] = struct{}{}
|
||||||
@@ -992,16 +996,25 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
|||||||
for pid := range groupPeers {
|
for pid := range groupPeers {
|
||||||
group.Peers = append(group.Peers, pid)
|
group.Peers = append(group.Peers, pid)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupUpdates[gid] = difference(group.Peers, oldPeers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return groupUpdates
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserGroupsRemoveFromPeers removes groups from all peers of user
|
// UserGroupsRemoveFromPeers removes groups from all peers of user
|
||||||
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string {
|
||||||
|
groupUpdates := make(map[string][]string)
|
||||||
|
|
||||||
for _, gid := range groups {
|
for _, gid := range groups {
|
||||||
group, ok := a.Groups[gid]
|
group, ok := a.Groups[gid]
|
||||||
if !ok || group.Name == "All" {
|
if !ok || group.Name == "All" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldPeers := group.Peers
|
||||||
|
|
||||||
update := make([]string, 0, len(group.Peers))
|
update := make([]string, 0, len(group.Peers))
|
||||||
for _, pid := range group.Peers {
|
for _, pid := range group.Peers {
|
||||||
peer, ok := a.Peers[pid]
|
peer, ok := a.Peers[pid]
|
||||||
@@ -1013,7 +1026,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
group.Peers = update
|
group.Peers = update
|
||||||
|
groupUpdates[gid] = difference(oldPeers, group.Peers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return groupUpdates
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||||
@@ -1175,6 +1191,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("groups propagation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
updatedAccount := account.UpdateSettings(newSettings)
|
updatedAccount := account.UpdateSettings(newSettings)
|
||||||
|
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
err = am.Store.SaveAccount(ctx, account)
|
||||||
@@ -1185,21 +1206,39 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return updatedAccount, nil
|
return updatedAccount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
|
||||||
event := activity.AccountPeerInactivityExpirationEnabled
|
if newSettings.GroupsPropagationEnabled {
|
||||||
if !newSettings.PeerInactivityExpirationEnabled {
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
|
||||||
event = activity.AccountPeerInactivityExpirationDisabled
|
// Todo: retroactively add user groups to all peers
|
||||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
|
||||||
} else {
|
} else {
|
||||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
return nil
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
}
|
||||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
|
||||||
|
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||||
|
|
||||||
|
if newSettings.PeerInactivityExpirationEnabled {
|
||||||
|
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||||
|
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
|
||||||
|
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||||
|
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||||
|
event := activity.AccountPeerInactivityExpirationEnabled
|
||||||
|
if !newSettings.PeerInactivityExpirationEnabled {
|
||||||
|
event = activity.AccountPeerInactivityExpirationDisabled
|
||||||
|
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||||
|
} else {
|
||||||
|
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1248,7 +1287,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed getting account %s expiring peers", account.Id)
|
log.Errorf("failed getting account %s expiring peers", accountID)
|
||||||
return account.GetNextInactivePeerExpiration()
|
return account.GetNextInactivePeerExpiration()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1434,7 +1473,7 @@ func isNil(i idp.Manager) bool {
|
|||||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||||
if !isNil(am.idpManager) {
|
if !isNil(am.idpManager) {
|
||||||
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
|
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -2028,7 +2067,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
return fmt.Errorf("error getting user: %w", err)
|
return fmt.Errorf("error getting user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, err := transaction.GetAccountGroups(ctx, accountID)
|
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error getting account groups: %w", err)
|
return fmt.Errorf("error getting account groups: %w", err)
|
||||||
}
|
}
|
||||||
@@ -2058,7 +2097,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
|
|
||||||
// Propagate changes to peers if group propagation is enabled
|
// Propagate changes to peers if group propagation is enabled
|
||||||
if settings.GroupsPropagationEnabled {
|
if settings.GroupsPropagationEnabled {
|
||||||
groups, err = transaction.GetAccountGroups(ctx, accountID)
|
groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error getting account groups: %w", err)
|
return fmt.Errorf("error getting account groups: %w", err)
|
||||||
}
|
}
|
||||||
@@ -2082,7 +2121,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
return fmt.Errorf("error saving groups: %w", err)
|
return fmt.Errorf("error saving groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2100,7 +2139,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, g := range addNewGroups {
|
for _, g := range addNewGroups {
|
||||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||||
} else {
|
} else {
|
||||||
@@ -2113,7 +2152,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, g := range removeOldGroups {
|
for _, g := range removeOldGroups {
|
||||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||||
} else {
|
} else {
|
||||||
@@ -2126,14 +2165,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
if settings.GroupsPropagationEnabled {
|
if settings.GroupsPropagationEnabled {
|
||||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error getting account: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
|
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2289,12 +2333,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, status.NewGetAccountError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
|
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
|
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
|
||||||
@@ -2313,12 +2357,12 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return status.NewGetAccountError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -2334,6 +2378,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
|||||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
|
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||||
|
defer unlockPeer()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -2397,12 +2444,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
|
|||||||
|
|
||||||
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
|
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
|
||||||
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
|
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
|
||||||
updatedAccount, err := am.Store.GetAccount(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
am.updateAccountPeers(ctx, updatedAccount)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||||
|
|||||||
@@ -6,13 +6,17 @@ import (
|
|||||||
b64 "encoding/base64"
|
b64 "encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -29,14 +33,18 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type MocIntegratedValidator struct {
|
type MocIntegratedValidator struct {
|
||||||
|
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
|
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||||
return update, nil
|
if a.ValidatePeerFunc != nil {
|
||||||
|
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
|
||||||
|
}
|
||||||
|
return update, false, nil
|
||||||
}
|
}
|
||||||
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
||||||
validatedPeers := make(map[string]struct{})
|
validatedPeers := make(map[string]struct{})
|
||||||
@@ -978,6 +986,110 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||||
|
claims := jwtclaims.AuthorizationClaims{
|
||||||
|
Domain: "example.com",
|
||||||
|
UserId: "pvt-domain-user",
|
||||||
|
DomainCategory: PrivateCategory,
|
||||||
|
}
|
||||||
|
|
||||||
|
publicClaims := jwtclaims.AuthorizationClaims{
|
||||||
|
Domain: "test.com",
|
||||||
|
UserId: "public-domain-user",
|
||||||
|
DomainCategory: PublicCategory,
|
||||||
|
}
|
||||||
|
|
||||||
|
am, err := createManager(b)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pid, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
users := genUsers("priv", 100)
|
||||||
|
|
||||||
|
acc, err := am.Store.GetAccount(context.Background(), id)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
acc.Users = users
|
||||||
|
|
||||||
|
err = am.Store.SaveAccount(context.Background(), acc)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userP := genUsers("pub", 100)
|
||||||
|
|
||||||
|
pacc, err := am.Store.GetAccount(context.Background(), pid)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pacc.Users = userP
|
||||||
|
|
||||||
|
err = am.Store.SaveAccount(context.Background(), pacc)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run("public without account ID", func(b *testing.B) {
|
||||||
|
// b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("private without account ID", func(b *testing.B) {
|
||||||
|
// b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("private with account ID", func(b *testing.B) {
|
||||||
|
claims.AccountId = id
|
||||||
|
// b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func genUsers(p string, n int) map[string]*User {
|
||||||
|
users := map[string]*User{}
|
||||||
|
now := time.Now()
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
users[fmt.Sprintf("%s-%d", p, i)] = &User{
|
||||||
|
Id: fmt.Sprintf("%s-%d", p, i),
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
LastLogin: now,
|
||||||
|
CreatedAt: now,
|
||||||
|
Issued: "api",
|
||||||
|
AutoGroups: []string{"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return users
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountManager_AddPeer(t *testing.T) {
|
func TestAccountManager_AddPeer(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1010,7 +1122,6 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
expectedPeerKey := key.PublicKey().String()
|
expectedPeerKey := key.PublicKey().String()
|
||||||
expectedSetupKey := setupKey.Key
|
|
||||||
|
|
||||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||||
Key: expectedPeerKey,
|
Key: expectedPeerKey,
|
||||||
@@ -1035,10 +1146,6 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
|||||||
t.Errorf("expecting just added peer's IP %s to be in a network range %s", peer.IP.String(), account.Network.Net.String())
|
t.Errorf("expecting just added peer's IP %s to be in a network range %s", peer.IP.String(), account.Network.Net.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.SetupKey != expectedSetupKey {
|
|
||||||
t.Errorf("expecting just added peer to have SetupKey = %s, got %s", expectedSetupKey, peer.SetupKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
if account.Network.CurrentSerial() != 1 {
|
if account.Network.CurrentSerial() != 1 {
|
||||||
t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial())
|
t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial())
|
||||||
}
|
}
|
||||||
@@ -1135,8 +1242,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policy := Policy{
|
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
ID: "policy",
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
@@ -1147,8 +1253,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
@@ -1217,19 +1322,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
|||||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
|
||||||
policy := Policy{
|
|
||||||
Enabled: true,
|
|
||||||
Rules: []*PolicyRule{
|
|
||||||
{
|
|
||||||
Enabled: true,
|
|
||||||
Sources: []string{"groupA"},
|
|
||||||
Destinations: []string{"groupA"},
|
|
||||||
Bidirectional: true,
|
|
||||||
Action: PolicyTrafficActionAccept,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -1242,7 +1334,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
t.Errorf("delete default rule: %v", err)
|
t.Errorf("delete default rule: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1263,7 +1367,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policy := Policy{
|
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
@@ -1274,9 +1378,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
|
if err != nil {
|
||||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
|
||||||
t.Errorf("save policy: %v", err)
|
t.Errorf("save policy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1310,13 +1413,20 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
|||||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
|
||||||
group := group.Group{
|
err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||||
ID: "groupA",
|
ID: "groupA",
|
||||||
Name: "GroupA",
|
Name: "GroupA",
|
||||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err, "failed to save group")
|
||||||
|
|
||||||
|
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||||
|
t.Errorf("delete default rule: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policy := Policy{
|
policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
@@ -1327,14 +1437,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
|
if err != nil {
|
||||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
|
||||||
t.Errorf("delete default rule: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
|
||||||
t.Errorf("save policy: %v", err)
|
t.Errorf("save policy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1357,7 +1461,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
|
||||||
t.Errorf("delete group: %v", err)
|
t.Errorf("delete group: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -2367,7 +2471,6 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
|||||||
LoginExpired: false,
|
LoginExpired: false,
|
||||||
},
|
},
|
||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: true,
|
||||||
SetupKey: "key",
|
|
||||||
},
|
},
|
||||||
"peer-2": {
|
"peer-2": {
|
||||||
Status: &nbpeer.PeerStatus{
|
Status: &nbpeer.PeerStatus{
|
||||||
@@ -2375,7 +2478,6 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
|||||||
LoginExpired: false,
|
LoginExpired: false,
|
||||||
},
|
},
|
||||||
LoginExpirationEnabled: true,
|
LoginExpirationEnabled: true,
|
||||||
SetupKey: "key",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expiration: time.Second,
|
expiration: time.Second,
|
||||||
@@ -2529,7 +2631,6 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
|||||||
LoginExpired: false,
|
LoginExpired: false,
|
||||||
},
|
},
|
||||||
InactivityExpirationEnabled: true,
|
InactivityExpirationEnabled: true,
|
||||||
SetupKey: "key",
|
|
||||||
},
|
},
|
||||||
"peer-2": {
|
"peer-2": {
|
||||||
Status: &nbpeer.PeerStatus{
|
Status: &nbpeer.PeerStatus{
|
||||||
@@ -2537,7 +2638,6 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
|||||||
LoginExpired: false,
|
LoginExpired: false,
|
||||||
},
|
},
|
||||||
InactivityExpirationEnabled: true,
|
InactivityExpirationEnabled: true,
|
||||||
SetupKey: "key",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expiration: time.Second,
|
expiration: time.Second,
|
||||||
@@ -2615,7 +2715,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 0)
|
assert.Len(t, user.AutoGroups, 0)
|
||||||
|
|
||||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||||
assert.NoError(t, err, "unable to get group")
|
assert.NoError(t, err, "unable to get group")
|
||||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||||
})
|
})
|
||||||
@@ -2635,7 +2735,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 1)
|
assert.Len(t, user.AutoGroups, 1)
|
||||||
|
|
||||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||||
assert.NoError(t, err, "unable to get group")
|
assert.NoError(t, err, "unable to get group")
|
||||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||||
})
|
})
|
||||||
@@ -2674,7 +2774,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||||
assert.NoError(t, err, "unable to sync jwt groups")
|
assert.NoError(t, err, "unable to sync jwt groups")
|
||||||
|
|
||||||
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
|
groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Len(t, groups, 3, "new group3 should be added")
|
assert.Len(t, groups, 3, "new group3 should be added")
|
||||||
|
|
||||||
@@ -2886,3 +2986,218 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage)
|
|||||||
t.Error("Timed out waiting for update message")
|
t.Error("Timed out waiting for update message")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||||
|
benchCases := []struct {
|
||||||
|
name string
|
||||||
|
peers int
|
||||||
|
groups int
|
||||||
|
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||||
|
minMsPerOpLocal float64
|
||||||
|
maxMsPerOpLocal float64
|
||||||
|
minMsPerOpCICD float64
|
||||||
|
maxMsPerOpCICD float64
|
||||||
|
}{
|
||||||
|
{"Small", 50, 5, 1, 3, 4, 10},
|
||||||
|
{"Medium", 500, 100, 7, 13, 10, 60},
|
||||||
|
{"Large", 5000, 200, 65, 80, 60, 170},
|
||||||
|
{"Small single", 50, 10, 1, 3, 4, 60},
|
||||||
|
{"Medium single", 500, 10, 7, 13, 10, 26},
|
||||||
|
{"Large 5", 5000, 15, 65, 80, 60, 170},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
b.Run(bc.name, func(b *testing.B) {
|
||||||
|
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to get account: %v", err)
|
||||||
|
}
|
||||||
|
peerChannels := make(map[string]chan *UpdateMessage)
|
||||||
|
for peerID := range account.Peers {
|
||||||
|
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||||
|
}
|
||||||
|
manager.peersUpdateManager.peerChannels = peerChannels
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
start := time.Now()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
||||||
|
assert.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||||
|
b.ReportMetric(msPerOp, "ms/op")
|
||||||
|
|
||||||
|
minExpected := bc.minMsPerOpLocal
|
||||||
|
maxExpected := bc.maxMsPerOpLocal
|
||||||
|
if os.Getenv("CI") == "true" {
|
||||||
|
minExpected = bc.minMsPerOpCICD
|
||||||
|
maxExpected = bc.maxMsPerOpCICD
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp < minExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp > maxExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||||
|
benchCases := []struct {
|
||||||
|
name string
|
||||||
|
peers int
|
||||||
|
groups int
|
||||||
|
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||||
|
minMsPerOpLocal float64
|
||||||
|
maxMsPerOpLocal float64
|
||||||
|
minMsPerOpCICD float64
|
||||||
|
maxMsPerOpCICD float64
|
||||||
|
}{
|
||||||
|
{"Small", 50, 5, 102, 110, 102, 120},
|
||||||
|
{"Medium", 500, 100, 105, 140, 105, 170},
|
||||||
|
{"Large", 5000, 200, 160, 200, 160, 270},
|
||||||
|
{"Small single", 50, 10, 102, 110, 102, 120},
|
||||||
|
{"Medium single", 500, 10, 105, 140, 105, 170},
|
||||||
|
{"Large 5", 5000, 15, 160, 200, 160, 270},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
b.Run(bc.name, func(b *testing.B) {
|
||||||
|
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to get account: %v", err)
|
||||||
|
}
|
||||||
|
peerChannels := make(map[string]chan *UpdateMessage)
|
||||||
|
for peerID := range account.Peers {
|
||||||
|
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||||
|
}
|
||||||
|
manager.peersUpdateManager.peerChannels = peerChannels
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
start := time.Now()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
|
||||||
|
WireGuardPubKey: account.Peers["peer-1"].Key,
|
||||||
|
SSHKey: "someKey",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||||
|
UserID: "regular_user",
|
||||||
|
SetupKey: "",
|
||||||
|
ConnectionIP: net.IP{1, 1, 1, 1},
|
||||||
|
})
|
||||||
|
assert.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||||
|
b.ReportMetric(msPerOp, "ms/op")
|
||||||
|
|
||||||
|
minExpected := bc.minMsPerOpLocal
|
||||||
|
maxExpected := bc.maxMsPerOpLocal
|
||||||
|
if os.Getenv("CI") == "true" {
|
||||||
|
minExpected = bc.minMsPerOpCICD
|
||||||
|
maxExpected = bc.maxMsPerOpCICD
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp < minExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp > maxExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||||
|
benchCases := []struct {
|
||||||
|
name string
|
||||||
|
peers int
|
||||||
|
groups int
|
||||||
|
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||||
|
minMsPerOpLocal float64
|
||||||
|
maxMsPerOpLocal float64
|
||||||
|
minMsPerOpCICD float64
|
||||||
|
maxMsPerOpCICD float64
|
||||||
|
}{
|
||||||
|
{"Small", 50, 5, 107, 120, 107, 140},
|
||||||
|
{"Medium", 500, 100, 105, 140, 105, 170},
|
||||||
|
{"Large", 5000, 200, 180, 220, 180, 320},
|
||||||
|
{"Small single", 50, 10, 107, 120, 105, 140},
|
||||||
|
{"Medium single", 500, 10, 105, 140, 105, 170},
|
||||||
|
{"Large 5", 5000, 15, 180, 220, 180, 320},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
b.Run(bc.name, func(b *testing.B) {
|
||||||
|
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to get account: %v", err)
|
||||||
|
}
|
||||||
|
peerChannels := make(map[string]chan *UpdateMessage)
|
||||||
|
for peerID := range account.Peers {
|
||||||
|
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||||
|
}
|
||||||
|
manager.peersUpdateManager.peerChannels = peerChannels
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
start := time.Now()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
|
||||||
|
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
|
||||||
|
SSHKey: "someKey",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||||
|
UserID: "regular_user",
|
||||||
|
SetupKey: "",
|
||||||
|
ConnectionIP: net.IP{1, 1, 1, 1},
|
||||||
|
})
|
||||||
|
assert.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||||
|
b.ReportMetric(msPerOp, "ms/op")
|
||||||
|
|
||||||
|
minExpected := bc.minMsPerOpLocal
|
||||||
|
maxExpected := bc.maxMsPerOpLocal
|
||||||
|
if os.Getenv("CI") == "true" {
|
||||||
|
minExpected = bc.minMsPerOpCICD
|
||||||
|
maxExpected = bc.maxMsPerOpCICD
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp < minExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp > maxExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -146,6 +146,11 @@ const (
|
|||||||
AccountPeerInactivityExpirationEnabled Activity = 65
|
AccountPeerInactivityExpirationEnabled Activity = 65
|
||||||
AccountPeerInactivityExpirationDisabled Activity = 66
|
AccountPeerInactivityExpirationDisabled Activity = 66
|
||||||
AccountPeerInactivityExpirationDurationUpdated Activity = 67
|
AccountPeerInactivityExpirationDurationUpdated Activity = 67
|
||||||
|
|
||||||
|
SetupKeyDeleted Activity = 68
|
||||||
|
|
||||||
|
UserGroupPropagationEnabled Activity = 69
|
||||||
|
UserGroupPropagationDisabled Activity = 70
|
||||||
)
|
)
|
||||||
|
|
||||||
var activityMap = map[Activity]Code{
|
var activityMap = map[Activity]Code{
|
||||||
@@ -219,6 +224,10 @@ var activityMap = map[Activity]Code{
|
|||||||
AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"},
|
AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"},
|
||||||
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
|
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
|
||||||
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
|
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
|
||||||
|
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
|
||||||
|
|
||||||
|
UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"},
|
||||||
|
UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringCode returns a string code of the activity
|
// StringCode returns a string code of the activity
|
||||||
|
|||||||
@@ -1,82 +0,0 @@
|
|||||||
package differs
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/r3labs/diff/v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NetIPAddr is a custom differ for netip.Addr
|
|
||||||
type NetIPAddr struct {
|
|
||||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (differ NetIPAddr) Match(a, b reflect.Value) bool {
|
|
||||||
return diff.AreType(a, b, reflect.TypeOf(netip.Addr{}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
|
||||||
if a.Kind() == reflect.Invalid {
|
|
||||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if b.Kind() == reflect.Invalid {
|
|
||||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fromAddr, ok1 := a.Interface().(netip.Addr)
|
|
||||||
toAddr, ok2 := b.Interface().(netip.Addr)
|
|
||||||
if !ok1 || !ok2 {
|
|
||||||
return fmt.Errorf("invalid type for netip.Addr")
|
|
||||||
}
|
|
||||||
|
|
||||||
if fromAddr.String() != toAddr.String() {
|
|
||||||
cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
|
||||||
differ.DiffFunc = dfunc //nolint
|
|
||||||
}
|
|
||||||
|
|
||||||
// NetIPPrefix is a custom differ for netip.Prefix
|
|
||||||
type NetIPPrefix struct {
|
|
||||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (differ NetIPPrefix) Match(a, b reflect.Value) bool {
|
|
||||||
return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
|
||||||
if a.Kind() == reflect.Invalid {
|
|
||||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if b.Kind() == reflect.Invalid {
|
|
||||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fromPrefix, ok1 := a.Interface().(netip.Prefix)
|
|
||||||
toPrefix, ok2 := b.Interface().(netip.Prefix)
|
|
||||||
if !ok1 || !ok2 {
|
|
||||||
return fmt.Errorf("invalid type for netip.Addr")
|
|
||||||
}
|
|
||||||
|
|
||||||
if fromPrefix.String() != toPrefix.String() {
|
|
||||||
cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
|
||||||
differ.DiffFunc = dfunc //nolint
|
|
||||||
}
|
|
||||||
@@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
if user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return nil, status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||||
@@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
|||||||
|
|
||||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
|
||||||
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
|
|
||||||
}
|
|
||||||
|
|
||||||
if dnsSettingsToSave == nil {
|
if dnsSettingsToSave == nil {
|
||||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
|
if err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
oldSettings := account.DNSSettings.Copy()
|
|
||||||
account.DNSSettings = dnsSettingsToSave.Copy()
|
|
||||||
|
|
||||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
|
||||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, id := range addedGroups {
|
if user.AccountID != accountID {
|
||||||
group := account.GetGroup(id)
|
return status.NewUserNotPartOfAccountError()
|
||||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, id := range removedGroups {
|
if !user.HasAdminPower() {
|
||||||
group := account.GetGroup(id)
|
return status.NewAdminPermissionError()
|
||||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
var updateAccountPeers bool
|
||||||
am.updateAccountPeers(ctx, account)
|
var eventsToStore []func()
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||||
|
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||||
|
|
||||||
|
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||||
|
eventsToStore = append(eventsToStore, events...)
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, storeEvent := range eventsToStore {
|
||||||
|
storeEvent()
|
||||||
|
}
|
||||||
|
|
||||||
|
if updateAccountPeers {
|
||||||
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// prepareDNSSettingsEvents prepares a list of event functions to be stored.
|
||||||
|
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
|
||||||
|
var eventsToStore []func()
|
||||||
|
|
||||||
|
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||||
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, groupID := range addedGroups {
|
||||||
|
group, ok := groups[groupID]
|
||||||
|
if !ok {
|
||||||
|
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, groupID := range removedGroups {
|
||||||
|
group, ok := groups[groupID]
|
||||||
|
if !ok {
|
||||||
|
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return eventsToStore
|
||||||
|
}
|
||||||
|
|
||||||
|
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||||
|
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||||
|
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasPeers {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateDNSSettings validates the DNS settings.
|
||||||
|
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
|
||||||
|
if len(settings.DisabledManagementGroups) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return validateGroups(settings.DisabledManagementGroups, groups)
|
||||||
|
}
|
||||||
|
|
||||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
||||||
protoUpdate := &proto.DNSConfig{
|
protoUpdate := &proto.DNSConfig{
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@@ -521,23 +522,64 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
// Creating DNS settings with groups that have no peers should not update account peers or send peer update
|
||||||
ID: "groupA",
|
t.Run("creating dns setting with unused groups", func(t *testing.T) {
|
||||||
Name: "GroupA",
|
done := make(chan struct{})
|
||||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
go func() {
|
||||||
})
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
assert.NoError(t, err)
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
_, err = manager.CreateNameServerGroup(
|
_, err = manager.CreateNameServerGroup(
|
||||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{
|
||||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||||
NSType: dns.UDPNameServerType,
|
NSType: dns.UDPNameServerType,
|
||||||
Port: dns.DefaultDNSPort,
|
Port: dns.DefaultDNSPort,
|
||||||
}},
|
}},
|
||||||
[]string{"groupA"},
|
[]string{"groupB"},
|
||||||
true, []string{}, true, userID, false,
|
true, []string{}, true, userID, false,
|
||||||
)
|
)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Creating DNS settings with groups that have peers should update account peers and send peer update
|
||||||
|
t.Run("creating dns setting with used groups", func(t *testing.T) {
|
||||||
|
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = manager.CreateNameServerGroup(
|
||||||
|
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||||
|
NSType: dns.UDPNameServerType,
|
||||||
|
Port: dns.DefaultDNSPort,
|
||||||
|
}},
|
||||||
|
[]string{"groupA"},
|
||||||
|
true, []string{}, true, userID, false,
|
||||||
|
)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Saving DNS settings with groups that have peers should update account peers and send peer update
|
// Saving DNS settings with groups that have peers should update account peers and send peer update
|
||||||
t.Run("saving dns setting with used groups", func(t *testing.T) {
|
t.Run("saving dns setting with used groups", func(t *testing.T) {
|
||||||
@@ -559,27 +601,6 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Saving unchanged DNS settings with used groups should update account peers and not send peer update
|
|
||||||
// since there is no change in the network map
|
|
||||||
t.Run("saving unchanged dns setting with used groups", func(t *testing.T) {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
peerShouldNotReceiveUpdate(t, updMsg)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
|
||||||
DisabledManagementGroups: []string{"groupA", "groupB"},
|
|
||||||
})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates
|
// Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates
|
||||||
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
|
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
|||||||
// It is recommended to call it with locking FileStore.mux
|
// It is recommended to call it with locking FileStore.mux
|
||||||
func (s *FileStore) persist(ctx context.Context, file string) error {
|
func (s *FileStore) persist(ctx context.Context, file string) error {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
err := util.WriteJson(file, s)
|
err := util.WriteJson(context.Background(), file, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,11 +6,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
@@ -27,18 +28,17 @@ func (e *GroupLinkError) Error() string {
|
|||||||
|
|
||||||
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||||
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
|
if user.AccountID != accountID {
|
||||||
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
return status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -49,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
|||||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllGroups returns all groups in an account
|
// GetAllGroups returns all groups in an account
|
||||||
@@ -58,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
|||||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||||
return am.Store.GetAccountGroups(ctx, accountID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
|
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveGroup object of the peers
|
// SaveGroup object of the peers
|
||||||
@@ -77,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
|
|||||||
// SaveGroups adds new groups to the account.
|
// SaveGroups adds new groups to the account.
|
||||||
// Note: This function does not acquire the global lock.
|
// Note: This function does not acquire the global lock.
|
||||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
|
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID {
|
||||||
|
return status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return status.NewAdminPermissionError()
|
||||||
|
}
|
||||||
|
|
||||||
var eventsToStore []func()
|
var eventsToStore []func()
|
||||||
|
var groupsToSave []*nbgroup.Group
|
||||||
|
var updateAccountPeers bool
|
||||||
|
|
||||||
for _, newGroup := range newGroups {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
groupIDs := make([]string, 0, len(groups))
|
||||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
for _, newGroup := range groups {
|
||||||
}
|
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||||
|
return err
|
||||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
|
||||||
existingGroup, err := account.FindGroupByName(newGroup.Name)
|
|
||||||
if err != nil {
|
|
||||||
s, ok := status.FromError(err)
|
|
||||||
if !ok || s.ErrorType != status.NotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Avoid duplicate groups only for the API issued groups.
|
newGroup.AccountID = accountID
|
||||||
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
groupsToSave = append(groupsToSave, newGroup)
|
||||||
if existingGroup != nil {
|
groupIDs = append(groupIDs, newGroup.ID)
|
||||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
newGroup.ID = xid.New().String()
|
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||||
|
eventsToStore = append(eventsToStore, events...)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, peerID := range newGroup.Peers {
|
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
|
||||||
if account.Peers[peerID] == nil {
|
if err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
return err
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
oldGroup := account.Groups[newGroup.ID]
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
account.Groups[newGroup.ID] = newGroup
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
|
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
|
||||||
eventsToStore = append(eventsToStore, events...)
|
})
|
||||||
}
|
if err != nil {
|
||||||
|
|
||||||
newGroupIDs := make([]string, 0, len(newGroups))
|
|
||||||
for _, newGroup := range newGroups {
|
|
||||||
newGroupIDs = append(newGroupIDs, newGroup.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if areGroupChangesAffectPeers(account, newGroupIDs) {
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, storeEvent := range eventsToStore {
|
for _, storeEvent := range eventsToStore {
|
||||||
storeEvent()
|
storeEvent()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if updateAccountPeers {
|
||||||
|
am.updateAccountPeers(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
|
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
|
||||||
var eventsToStore []func()
|
var eventsToStore []func()
|
||||||
|
|
||||||
addedPeers := make([]string, 0)
|
addedPeers := make([]string, 0)
|
||||||
removedPeers := make([]string, 0)
|
removedPeers := make([]string, 0)
|
||||||
|
|
||||||
if oldGroup != nil {
|
oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
|
||||||
|
if err == nil && oldGroup != nil {
|
||||||
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||||
} else {
|
} else {
|
||||||
@@ -159,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range addedPeers {
|
modifiedPeers := slices.Concat(addedPeers, removedPeers)
|
||||||
peer := account.Peers[p]
|
peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
|
||||||
if peer == nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peerID := range addedPeers {
|
||||||
|
peer, ok := peers[peerID]
|
||||||
|
if !ok {
|
||||||
|
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
peerCopy := peer // copy to avoid closure issues
|
|
||||||
eventsToStore = append(eventsToStore, func() {
|
eventsToStore = append(eventsToStore, func() {
|
||||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
|
meta := map[string]any{
|
||||||
map[string]any{
|
"group": newGroup.Name, "group_id": newGroup.ID,
|
||||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
}
|
||||||
})
|
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range removedPeers {
|
for _, peerID := range removedPeers {
|
||||||
peer := account.Peers[p]
|
peer, ok := peers[peerID]
|
||||||
if peer == nil {
|
if !ok {
|
||||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
peerCopy := peer // copy to avoid closure issues
|
|
||||||
eventsToStore = append(eventsToStore, func() {
|
eventsToStore = append(eventsToStore, func() {
|
||||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
|
meta := map[string]any{
|
||||||
map[string]any{
|
"group": newGroup.Name, "group_id": newGroup.ID,
|
||||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
}
|
||||||
})
|
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,42 +210,10 @@ func difference(a, b []string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroup object of the peers.
|
// DeleteGroup object of the peers.
|
||||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
|
||||||
account, err := am.Store.GetAccount(ctx, accountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if allGroup.ID == groupID {
|
|
||||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
delete(account.Groups, groupID)
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroups deletes groups from an account.
|
// DeleteGroups deletes groups from an account.
|
||||||
@@ -254,93 +222,94 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
|||||||
//
|
//
|
||||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||||
// Errors are collected and returned at the end.
|
// Errors are collected and returned at the end.
|
||||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
|
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
|
||||||
account, err := am.Store.GetAccount(ctx, accountId)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID {
|
||||||
|
return status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return status.NewAdminPermissionError()
|
||||||
|
}
|
||||||
|
|
||||||
var allErrors error
|
var allErrors error
|
||||||
|
var groupIDsToDelete []string
|
||||||
|
var deletedGroups []*nbgroup.Group
|
||||||
|
|
||||||
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
for _, groupID := range groupIDs {
|
for _, groupID := range groupIDs {
|
||||||
group, ok := account.Groups[groupID]
|
group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
|
||||||
if !ok {
|
if err != nil {
|
||||||
continue
|
allErrors = errors.Join(allErrors, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
||||||
|
allErrors = errors.Join(allErrors, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
groupIDsToDelete = append(groupIDsToDelete, groupID)
|
||||||
|
deletedGroups = append(deletedGroups, group)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateDeleteGroup(account, group, userId); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
|
return err
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(account.Groups, groupID)
|
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
|
||||||
deletedGroups = append(deletedGroups, group)
|
})
|
||||||
}
|
if err != nil {
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, g := range deletedGroups {
|
for _, group := range deletedGroups {
|
||||||
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||||
}
|
}
|
||||||
|
|
||||||
return allErrors
|
return allErrors
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListGroups objects of the peers
|
|
||||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
groups := make([]*nbgroup.Group, 0, len(account.Groups))
|
|
||||||
for _, item := range account.Groups {
|
|
||||||
groups = append(groups, item)
|
|
||||||
}
|
|
||||||
|
|
||||||
return groups, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupAddPeer appends peer to the group
|
// GroupAddPeer appends peer to the group
|
||||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
var group *nbgroup.Group
|
||||||
|
var updateAccountPeers bool
|
||||||
|
var err error
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated := group.AddPeer(peerID); !updated {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
if updateAccountPeers {
|
||||||
if !ok {
|
am.updateAccountPeers(ctx, accountID)
|
||||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
|
||||||
}
|
|
||||||
|
|
||||||
add := true
|
|
||||||
for _, itemID := range group.Peers {
|
|
||||||
if itemID == peerID {
|
|
||||||
add = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if add {
|
|
||||||
group.Peers = append(group.Peers, peerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -351,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
|||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
var group *nbgroup.Group
|
||||||
|
var updateAccountPeers bool
|
||||||
|
var err error
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated := group.RemovePeer(peerID); !updated {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
if updateAccountPeers {
|
||||||
if !ok {
|
am.updateAccountPeers(ctx, accountID)
|
||||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
for i, itemID := range group.Peers {
|
|
||||||
if itemID == peerID {
|
|
||||||
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
|
||||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
|
// validateNewGroup validates the new group for existence and required fields.
|
||||||
|
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
|
||||||
|
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||||
|
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||||
|
}
|
||||||
|
|
||||||
|
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||||
|
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prevent duplicate groups for API-issued groups.
|
||||||
|
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
||||||
|
if existingGroup != nil {
|
||||||
|
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
newGroup.ID = xid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peerID := range newGroup.Peers {
|
||||||
|
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
|
||||||
// disable a deleting integration group if the initiator is not an admin service user
|
// disable a deleting integration group if the initiator is not an admin service user
|
||||||
if group.Issued == nbgroup.GroupIssuedIntegration {
|
if group.Issued == nbgroup.GroupIssuedIntegration {
|
||||||
executingUser := account.Users[userID]
|
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if executingUser == nil {
|
if err != nil {
|
||||||
return status.Errorf(status.NotFound, "user not found")
|
return err
|
||||||
}
|
}
|
||||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||||
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
|
if group.IsGroupAll() {
|
||||||
|
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
|
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"name server groups", linkedDns.Name}
|
return &GroupLinkError{"name server groups", linkedDns.Name}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
|
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"policy", linkedPolicy.Name}
|
return &GroupLinkError{"policy", linkedPolicy.Name}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
|
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"setup key", linkedSetupKey.Name}
|
return &GroupLinkError{"setup key", linkedSetupKey.Name}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
|
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"user", linkedUser.Id}
|
return &GroupLinkError{"user", linkedUser.Id}
|
||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
|
return checkGroupLinkedToSettings(ctx, transaction, group)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
|
||||||
|
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
|
||||||
|
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
|
||||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Settings.Extra != nil {
|
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||||
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
|
if err != nil {
|
||||||
return &GroupLinkError{"integrated validator", group.Name}
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||||
|
return &GroupLinkError{"integrated validator", group.Name}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||||
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
|
func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
|
||||||
|
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
||||||
return true, r
|
return true, r
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||||
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
|
||||||
|
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, policy := range policies {
|
for _, policy := range policies {
|
||||||
for _, rule := range policy.Rules {
|
for _, rule := range policy.Rules {
|
||||||
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
|
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
|
||||||
@@ -446,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||||
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
|
func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||||
|
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, dns := range nameServerGroups {
|
for _, dns := range nameServerGroups {
|
||||||
for _, g := range dns.Groups {
|
for _, g := range dns.Groups {
|
||||||
if g == groupID {
|
if g == groupID {
|
||||||
@@ -454,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||||
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
|
func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
|
||||||
|
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, setupKey := range setupKeys {
|
for _, setupKey := range setupKeys {
|
||||||
if slices.Contains(setupKey.AutoGroups, groupID) {
|
if slices.Contains(setupKey.AutoGroups, groupID) {
|
||||||
return true, setupKey
|
return true, setupKey
|
||||||
@@ -468,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||||
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
|
||||||
|
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
if slices.Contains(user.AutoGroups, groupID) {
|
if slices.Contains(user.AutoGroups, groupID) {
|
||||||
return true, user
|
return true, user
|
||||||
@@ -477,8 +537,36 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||||
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||||
|
if len(groupIDs) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, groupID := range groupIDs {
|
||||||
|
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||||
for _, groupID := range groupIDs {
|
for _, groupID := range groupIDs {
|
||||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||||
return true
|
return true
|
||||||
@@ -487,21 +575,18 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
|
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||||
for _, groupID := range groupIDs {
|
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
|
||||||
return true
|
if err != nil {
|
||||||
}
|
return false, err
|
||||||
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
|
}
|
||||||
return true
|
|
||||||
}
|
for _, group := range groups {
|
||||||
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
|
if group.HasPeers() {
|
||||||
return true
|
return true, nil
|
||||||
}
|
|
||||||
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,3 +49,35 @@ func (g *Group) Copy() *Group {
|
|||||||
func (g *Group) HasPeers() bool {
|
func (g *Group) HasPeers() bool {
|
||||||
return len(g.Peers) > 0
|
return len(g.Peers) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsGroupAll checks if the group is a default "All" group.
|
||||||
|
func (g *Group) IsGroupAll() bool {
|
||||||
|
return g.Name == "All"
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPeer adds peerID to Peers if not present, returning true if added.
|
||||||
|
func (g *Group) AddPeer(peerID string) bool {
|
||||||
|
if peerID == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, itemID := range g.Peers {
|
||||||
|
if itemID == peerID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
g.Peers = append(g.Peers, peerID)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeer removes peerID from Peers if present, returning true if removed.
|
||||||
|
func (g *Group) RemovePeer(peerID string) bool {
|
||||||
|
for i, itemID := range g.Peers {
|
||||||
|
if itemID == peerID {
|
||||||
|
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
90
management/server/group/group_test.go
Normal file
90
management/server/group/group_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package group
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAddPeer(t *testing.T) {
|
||||||
|
t.Run("add new peer to empty slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{}}
|
||||||
|
peerID := "peer1"
|
||||||
|
assert.True(t, group.AddPeer(peerID))
|
||||||
|
assert.Contains(t, group.Peers, peerID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add new peer to nil slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: nil}
|
||||||
|
peerID := "peer1"
|
||||||
|
assert.True(t, group.AddPeer(peerID))
|
||||||
|
assert.Contains(t, group.Peers, peerID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add new peer to non-empty slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||||
|
peerID := "peer3"
|
||||||
|
assert.True(t, group.AddPeer(peerID))
|
||||||
|
assert.Contains(t, group.Peers, peerID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add duplicate peer", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||||
|
peerID := "peer1"
|
||||||
|
assert.False(t, group.AddPeer(peerID))
|
||||||
|
assert.Equal(t, 2, len(group.Peers))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add empty peer", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||||
|
peerID := ""
|
||||||
|
assert.False(t, group.AddPeer(peerID))
|
||||||
|
assert.Equal(t, 2, len(group.Peers))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemovePeer(t *testing.T) {
|
||||||
|
t.Run("remove existing peer from slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
|
||||||
|
peerID := "peer2"
|
||||||
|
assert.True(t, group.RemovePeer(peerID))
|
||||||
|
assert.NotContains(t, group.Peers, peerID)
|
||||||
|
assert.Equal(t, 2, len(group.Peers))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove peer from empty slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{}}
|
||||||
|
peerID := "peer1"
|
||||||
|
assert.False(t, group.RemovePeer(peerID))
|
||||||
|
assert.Equal(t, 0, len(group.Peers))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove peer from nil slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: nil}
|
||||||
|
peerID := "peer1"
|
||||||
|
assert.False(t, group.RemovePeer(peerID))
|
||||||
|
assert.Nil(t, group.Peers)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove non-existent peer", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||||
|
peerID := "peer3"
|
||||||
|
assert.False(t, group.RemovePeer(peerID))
|
||||||
|
assert.Equal(t, 2, len(group.Peers))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove peer from single-item slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1"}}
|
||||||
|
peerID := "peer1"
|
||||||
|
assert.True(t, group.RemovePeer(peerID))
|
||||||
|
assert.Equal(t, 0, len(group.Peers))
|
||||||
|
assert.NotContains(t, group.Peers, peerID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove empty peer", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||||
|
peerID := ""
|
||||||
|
assert.False(t, group.RemovePeer(peerID))
|
||||||
|
assert.Equal(t, 2, len(group.Peers))
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -8,12 +8,13 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -207,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "delete non-existent group",
|
name: "delete non-existent group",
|
||||||
groupIDs: []string{"non-existent-group"},
|
groupIDs: []string{"non-existent-group"},
|
||||||
expectedDeleted: []string{"non-existent-group"},
|
expectedReasons: []string{"group: non-existent-group not found"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "delete multiple groups with mixed results",
|
name: "delete multiple groups with mixed results",
|
||||||
@@ -499,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// adding a group to policy
|
// adding a group to policy
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
ID: "policy",
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
@@ -511,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, false)
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Saving a group linked to policy should update account peers and send peer update
|
// Saving a group linked to policy should update account peers and send peer update
|
||||||
@@ -536,29 +536,6 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Saving an unchanged group should trigger account peers update and not send peer update
|
|
||||||
// since there is no change in the network map
|
|
||||||
t.Run("saving unchanged group", func(t *testing.T) {
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
peerShouldNotReceiveUpdate(t, updMsg)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
|
||||||
ID: "groupA",
|
|
||||||
Name: "GroupA",
|
|
||||||
Peers: []string{peer1.ID, peer2.ID},
|
|
||||||
})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// adding peer to a used group should update account peers and send peer update
|
// adding peer to a used group should update account peers and send peer update
|
||||||
t.Run("adding peer to linked group", func(t *testing.T) {
|
t.Run("adding peer to linked group", func(t *testing.T) {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
pb "github.com/golang/protobuf/proto" // nolint
|
pb "github.com/golang/protobuf/proto" // nolint
|
||||||
@@ -38,6 +39,7 @@ type GRPCServer struct {
|
|||||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
appMetrics telemetry.AppMetrics
|
appMetrics telemetry.AppMetrics
|
||||||
ephemeralManager *EphemeralManager
|
ephemeralManager *EphemeralManager
|
||||||
|
peerLocks sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Management server
|
// NewServer creates a new Management server
|
||||||
@@ -148,6 +150,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||||
|
|
||||||
|
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||||
|
defer func() {
|
||||||
|
if unlock != nil {
|
||||||
|
unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
@@ -171,6 +180,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
|
|
||||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||||
return mapError(ctx, err)
|
return mapError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,11 +200,15 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unlock()
|
||||||
|
unlock = nil
|
||||||
|
|
||||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||||
|
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
// condition when there are some updates
|
// condition when there are some updates
|
||||||
@@ -245,10 +259,18 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||||
|
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||||
|
}
|
||||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||||
s.secretsManager.CancelRefresh(peer.ID)
|
s.secretsManager.CancelRefresh(peer.ID)
|
||||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
|
||||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||||
|
|
||||||
|
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||||
@@ -274,6 +296,24 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
|||||||
return claims.UserId, nil
|
return claims.UserId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||||
|
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||||
|
mtx := value.(*sync.RWMutex)
|
||||||
|
mtx.Lock()
|
||||||
|
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||||
|
start = time.Now()
|
||||||
|
|
||||||
|
unlock = func() {
|
||||||
|
mtx.Unlock()
|
||||||
|
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||||
|
}
|
||||||
|
|
||||||
|
return unlock
|
||||||
|
}
|
||||||
|
|
||||||
// maps internal internalStatus.Error to gRPC status.Error
|
// maps internal internalStatus.Error to gRPC status.Error
|
||||||
func mapError(ctx context.Context, err error) error {
|
func mapError(ctx context.Context, err error) error {
|
||||||
if e, ok := internalStatus.FromError(err); ok {
|
if e, ok := internalStatus.FromError(err); ok {
|
||||||
|
|||||||
@@ -439,17 +439,13 @@ components:
|
|||||||
example: 5
|
example: 5
|
||||||
required:
|
required:
|
||||||
- accessible_peers_count
|
- accessible_peers_count
|
||||||
SetupKey:
|
SetupKeyBase:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
id:
|
id:
|
||||||
description: Setup Key ID
|
description: Setup Key ID
|
||||||
type: string
|
type: string
|
||||||
example: 2531583362
|
example: 2531583362
|
||||||
key:
|
|
||||||
description: Setup Key value
|
|
||||||
type: string
|
|
||||||
example: A616097E-FCF0-48FA-9354-CA4A61142761
|
|
||||||
name:
|
name:
|
||||||
description: Setup key name identifier
|
description: Setup key name identifier
|
||||||
type: string
|
type: string
|
||||||
@@ -518,23 +514,31 @@ components:
|
|||||||
- updated_at
|
- updated_at
|
||||||
- usage_limit
|
- usage_limit
|
||||||
- ephemeral
|
- ephemeral
|
||||||
|
SetupKeyClear:
|
||||||
|
allOf:
|
||||||
|
- $ref: '#/components/schemas/SetupKeyBase'
|
||||||
|
- type: object
|
||||||
|
properties:
|
||||||
|
key:
|
||||||
|
description: Setup Key as plain text
|
||||||
|
type: string
|
||||||
|
example: A616097E-FCF0-48FA-9354-CA4A61142761
|
||||||
|
required:
|
||||||
|
- key
|
||||||
|
SetupKey:
|
||||||
|
allOf:
|
||||||
|
- $ref: '#/components/schemas/SetupKeyBase'
|
||||||
|
- type: object
|
||||||
|
properties:
|
||||||
|
key:
|
||||||
|
description: Setup Key as secret
|
||||||
|
type: string
|
||||||
|
example: A6160****
|
||||||
|
required:
|
||||||
|
- key
|
||||||
SetupKeyRequest:
|
SetupKeyRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
name:
|
|
||||||
description: Setup Key name
|
|
||||||
type: string
|
|
||||||
example: Default key
|
|
||||||
type:
|
|
||||||
description: Setup key type, one-off for single time usage and reusable
|
|
||||||
type: string
|
|
||||||
example: reusable
|
|
||||||
expires_in:
|
|
||||||
description: Expiration time in seconds
|
|
||||||
type: integer
|
|
||||||
minimum: 86400
|
|
||||||
maximum: 31536000
|
|
||||||
example: 86400
|
|
||||||
revoked:
|
revoked:
|
||||||
description: Setup key revocation status
|
description: Setup key revocation status
|
||||||
type: boolean
|
type: boolean
|
||||||
@@ -545,21 +549,9 @@ components:
|
|||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
example: "ch8i4ug6lnn4g9hqv7m0"
|
example: "ch8i4ug6lnn4g9hqv7m0"
|
||||||
usage_limit:
|
|
||||||
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
|
||||||
type: integer
|
|
||||||
example: 0
|
|
||||||
ephemeral:
|
|
||||||
description: Indicate that the peer will be ephemeral or not
|
|
||||||
type: boolean
|
|
||||||
example: true
|
|
||||||
required:
|
required:
|
||||||
- name
|
|
||||||
- type
|
|
||||||
- expires_in
|
|
||||||
- revoked
|
- revoked
|
||||||
- auto_groups
|
- auto_groups
|
||||||
- usage_limit
|
|
||||||
CreateSetupKeyRequest:
|
CreateSetupKeyRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -1944,7 +1936,7 @@ paths:
|
|||||||
content:
|
content:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/SetupKey'
|
$ref: '#/components/schemas/SetupKeyClear'
|
||||||
'400':
|
'400':
|
||||||
"$ref": "#/components/responses/bad_request"
|
"$ref": "#/components/responses/bad_request"
|
||||||
'401':
|
'401':
|
||||||
@@ -2018,6 +2010,32 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
|
delete:
|
||||||
|
summary: Delete a Setup Key
|
||||||
|
description: Delete a Setup Key
|
||||||
|
tags: [ Setup Keys ]
|
||||||
|
security:
|
||||||
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
|
parameters:
|
||||||
|
- in: path
|
||||||
|
name: keyId
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
description: The unique identifier of a setup key
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Delete status code
|
||||||
|
content: { }
|
||||||
|
'400':
|
||||||
|
"$ref": "#/components/responses/bad_request"
|
||||||
|
'401':
|
||||||
|
"$ref": "#/components/responses/requires_authentication"
|
||||||
|
'403':
|
||||||
|
"$ref": "#/components/responses/forbidden"
|
||||||
|
'500':
|
||||||
|
"$ref": "#/components/responses/internal_error"
|
||||||
/api/groups:
|
/api/groups:
|
||||||
get:
|
get:
|
||||||
summary: List all Groups
|
summary: List all Groups
|
||||||
|
|||||||
@@ -1062,7 +1062,94 @@ type SetupKey struct {
|
|||||||
// Id Setup Key ID
|
// Id Setup Key ID
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
|
|
||||||
// Key Setup Key value
|
// Key Setup Key as secret
|
||||||
|
Key string `json:"key"`
|
||||||
|
|
||||||
|
// LastUsed Setup key last usage date
|
||||||
|
LastUsed time.Time `json:"last_used"`
|
||||||
|
|
||||||
|
// Name Setup key name identifier
|
||||||
|
Name string `json:"name"`
|
||||||
|
|
||||||
|
// Revoked Setup key revocation status
|
||||||
|
Revoked bool `json:"revoked"`
|
||||||
|
|
||||||
|
// State Setup key status, "valid", "overused","expired" or "revoked"
|
||||||
|
State string `json:"state"`
|
||||||
|
|
||||||
|
// Type Setup key type, one-off for single time usage and reusable
|
||||||
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// UpdatedAt Setup key last update date
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
|
||||||
|
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||||
|
UsageLimit int `json:"usage_limit"`
|
||||||
|
|
||||||
|
// UsedTimes Usage count of setup key
|
||||||
|
UsedTimes int `json:"used_times"`
|
||||||
|
|
||||||
|
// Valid Setup key validity status
|
||||||
|
Valid bool `json:"valid"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupKeyBase defines model for SetupKeyBase.
|
||||||
|
type SetupKeyBase struct {
|
||||||
|
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||||
|
AutoGroups []string `json:"auto_groups"`
|
||||||
|
|
||||||
|
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||||
|
Ephemeral bool `json:"ephemeral"`
|
||||||
|
|
||||||
|
// Expires Setup Key expiration date
|
||||||
|
Expires time.Time `json:"expires"`
|
||||||
|
|
||||||
|
// Id Setup Key ID
|
||||||
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// LastUsed Setup key last usage date
|
||||||
|
LastUsed time.Time `json:"last_used"`
|
||||||
|
|
||||||
|
// Name Setup key name identifier
|
||||||
|
Name string `json:"name"`
|
||||||
|
|
||||||
|
// Revoked Setup key revocation status
|
||||||
|
Revoked bool `json:"revoked"`
|
||||||
|
|
||||||
|
// State Setup key status, "valid", "overused","expired" or "revoked"
|
||||||
|
State string `json:"state"`
|
||||||
|
|
||||||
|
// Type Setup key type, one-off for single time usage and reusable
|
||||||
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// UpdatedAt Setup key last update date
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
|
||||||
|
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||||
|
UsageLimit int `json:"usage_limit"`
|
||||||
|
|
||||||
|
// UsedTimes Usage count of setup key
|
||||||
|
UsedTimes int `json:"used_times"`
|
||||||
|
|
||||||
|
// Valid Setup key validity status
|
||||||
|
Valid bool `json:"valid"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupKeyClear defines model for SetupKeyClear.
|
||||||
|
type SetupKeyClear struct {
|
||||||
|
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||||
|
AutoGroups []string `json:"auto_groups"`
|
||||||
|
|
||||||
|
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||||
|
Ephemeral bool `json:"ephemeral"`
|
||||||
|
|
||||||
|
// Expires Setup Key expiration date
|
||||||
|
Expires time.Time `json:"expires"`
|
||||||
|
|
||||||
|
// Id Setup Key ID
|
||||||
|
Id string `json:"id"`
|
||||||
|
|
||||||
|
// Key Setup Key as plain text
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
|
|
||||||
// LastUsed Setup key last usage date
|
// LastUsed Setup key last usage date
|
||||||
@@ -1098,23 +1185,8 @@ type SetupKeyRequest struct {
|
|||||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||||
AutoGroups []string `json:"auto_groups"`
|
AutoGroups []string `json:"auto_groups"`
|
||||||
|
|
||||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
|
||||||
Ephemeral *bool `json:"ephemeral,omitempty"`
|
|
||||||
|
|
||||||
// ExpiresIn Expiration time in seconds
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
|
|
||||||
// Name Setup Key name
|
|
||||||
Name string `json:"name"`
|
|
||||||
|
|
||||||
// Revoked Setup key revocation status
|
// Revoked Setup key revocation status
|
||||||
Revoked bool `json:"revoked"`
|
Revoked bool `json:"revoked"`
|
||||||
|
|
||||||
// Type Setup key type, one-off for single time usage and reusable
|
|
||||||
Type string `json:"type"`
|
|
||||||
|
|
||||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
|
||||||
UsageLimit int `json:"usage_limit"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// User defines model for User.
|
// User defines model for User.
|
||||||
|
|||||||
@@ -141,6 +141,7 @@ func (apiHandler *apiHandler) addSetupKeysEndpoint() {
|
|||||||
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS")
|
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS")
|
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS")
|
||||||
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS")
|
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS")
|
||||||
|
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addPoliciesEndpoint() {
|
func (apiHandler *apiHandler) addPoliciesEndpoint() {
|
||||||
|
|||||||
@@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
dnsDomain := h.accountManager.GetDNSDomain()
|
dnsDomain := h.accountManager.GetDNSDomain()
|
||||||
|
|
||||||
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
for _, peer := range account.Peers {
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsMap := map[string]*nbgroup.Group{}
|
||||||
|
groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
|
for _, group := range groups {
|
||||||
|
groupsMap[group.ID] = group
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||||
|
for _, peer := range peers {
|
||||||
peerToReturn, err := h.checkPeerStatus(peer)
|
peerToReturn, err := h.checkPeerStatus(peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)
|
||||||
|
|
||||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||||
}
|
}
|
||||||
@@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
|
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
|
||||||
var groupsInfo []api.GroupMinimum
|
groupsInfo := []api.GroupMinimum{}
|
||||||
groupsChecked := make(map[string]struct{})
|
groupsChecked := make(map[string]struct{})
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
_, ok := groupsChecked[group.ID]
|
_, ok := groupsChecked[group.ID]
|
||||||
|
|||||||
@@ -13,12 +13,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
@@ -168,7 +169,6 @@ func TestGetPeers(t *testing.T) {
|
|||||||
peer := &nbpeer.Peer{
|
peer := &nbpeer.Peer{
|
||||||
ID: testPeerID,
|
ID: testPeerID,
|
||||||
Key: "key",
|
Key: "key",
|
||||||
SetupKey: "setupkey",
|
|
||||||
IP: net.ParseIP("100.64.0.1"),
|
IP: net.ParseIP("100.64.0.1"),
|
||||||
Status: &nbpeer.PeerStatus{Connected: true},
|
Status: &nbpeer.PeerStatus{Connected: true},
|
||||||
Name: "PeerName",
|
Name: "PeerName",
|
||||||
|
|||||||
@@ -6,10 +6,8 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
|
||||||
"github.com/rs/xid"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
@@ -122,21 +120,22 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
isUpdate := policyID != ""
|
policy := &server.Policy{
|
||||||
|
|
||||||
if policyID == "" {
|
|
||||||
policyID = xid.New().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
policy := server.Policy{
|
|
||||||
ID: policyID,
|
ID: policyID,
|
||||||
|
AccountID: accountID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Enabled: req.Enabled,
|
Enabled: req.Enabled,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
}
|
}
|
||||||
for _, rule := range req.Rules {
|
for _, rule := range req.Rules {
|
||||||
|
var ruleID string
|
||||||
|
if rule.Id != nil {
|
||||||
|
ruleID = *rule.Id
|
||||||
|
}
|
||||||
|
|
||||||
pr := server.PolicyRule{
|
pr := server.PolicyRule{
|
||||||
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
ID: ruleID,
|
||||||
|
PolicyID: policyID,
|
||||||
Name: rule.Name,
|
Name: rule.Name,
|
||||||
Destinations: rule.Destinations,
|
Destinations: rule.Destinations,
|
||||||
Sources: rule.Sources,
|
Sources: rule.Sources,
|
||||||
@@ -225,7 +224,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
|||||||
policy.SourcePostureChecks = *req.SourcePostureChecks
|
policy.SourcePostureChecks = *req.SourcePostureChecks
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
|
policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
|
||||||
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toPolicyResponse(allGroups, &policy)
|
resp := toPolicyResponse(allGroups, policy)
|
||||||
if len(resp.Rules) == 0 {
|
if len(resp.Rules) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
|||||||
}
|
}
|
||||||
return policy, nil
|
return policy, nil
|
||||||
},
|
},
|
||||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
|
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
|
||||||
if !strings.HasPrefix(policy.ID, "id-") {
|
if !strings.HasPrefix(policy.ID, "id-") {
|
||||||
policy.ID = "id-was-set"
|
policy.ID = "id-was-set"
|
||||||
policy.Rules[0].ID = "id-was-set"
|
policy.Rules[0].ID = "id-was-set"
|
||||||
}
|
}
|
||||||
return nil
|
return policy, nil
|
||||||
},
|
},
|
||||||
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
||||||
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
|
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
|
||||||
|
|||||||
@@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
|
postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
|
||||||
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
|||||||
}
|
}
|
||||||
return p, nil
|
return p, nil
|
||||||
},
|
},
|
||||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||||
postureChecks.ID = "postureCheck"
|
postureChecks.ID = "postureCheck"
|
||||||
testPostureChecks[postureChecks.ID] = postureChecks
|
testPostureChecks[postureChecks.ID] = postureChecks
|
||||||
|
|
||||||
if err := postureChecks.Validate(); err != nil {
|
if err := postureChecks.Validate(); err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return postureChecks, nil
|
||||||
},
|
},
|
||||||
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
|
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
|
||||||
_, ok := testPostureChecks[postureChecksID]
|
_, ok := testPostureChecks[postureChecksID]
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
if req.Peer == nil && req.PeerGroups == nil {
|
if req.Peer == nil && req.PeerGroups == nil {
|
||||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided")
|
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Peer != nil && req.PeerGroups != nil {
|
if req.Peer != nil && req.PeerGroups != nil {
|
||||||
|
|||||||
@@ -61,10 +61,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
|
|
||||||
expiresIn := time.Duration(req.ExpiresIn) * time.Second
|
expiresIn := time.Duration(req.ExpiresIn) * time.Second
|
||||||
|
|
||||||
day := time.Hour * 24
|
if expiresIn < 0 {
|
||||||
year := day * 365
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn can not be in the past"), w)
|
||||||
if expiresIn < day || expiresIn > year {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,6 +74,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
if req.Ephemeral != nil {
|
if req.Ephemeral != nil {
|
||||||
ephemeral = *req.Ephemeral
|
ephemeral = *req.Ephemeral
|
||||||
}
|
}
|
||||||
|
|
||||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||||
req.AutoGroups, req.UsageLimit, userID, ephemeral)
|
req.AutoGroups, req.UsageLimit, userID, ephemeral)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -83,7 +82,11 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeSuccess(r.Context(), w, setupKey)
|
apiSetupKeys := toResponseBody(setupKey)
|
||||||
|
// for the creation we need to send the plain key
|
||||||
|
apiSetupKeys.Key = setupKey.Key
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSetupKey is a GET request to get a SetupKey by ID
|
// GetSetupKey is a GET request to get a SetupKey by ID
|
||||||
@@ -98,7 +101,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
keyID := vars["keyId"]
|
keyID := vars["keyId"]
|
||||||
if len(keyID) == 0 {
|
if len(keyID) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,7 +126,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
keyID := vars["keyId"]
|
keyID := vars["keyId"]
|
||||||
if len(keyID) == 0 {
|
if len(keyID) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,11 +137,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Name == "" {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.AutoGroups == nil {
|
if req.AutoGroups == nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
||||||
return
|
return
|
||||||
@@ -147,7 +145,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
newKey := &server.SetupKey{}
|
newKey := &server.SetupKey{}
|
||||||
newKey.AutoGroups = req.AutoGroups
|
newKey.AutoGroups = req.AutoGroups
|
||||||
newKey.Revoked = req.Revoked
|
newKey.Revoked = req.Revoked
|
||||||
newKey.Name = req.Name
|
|
||||||
newKey.Id = keyID
|
newKey.Id = keyID
|
||||||
|
|
||||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||||
@@ -181,6 +178,30 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
|
|||||||
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
keyID := vars["keyId"]
|
||||||
|
if len(keyID) == 0 {
|
||||||
|
util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||||
|
}
|
||||||
|
|
||||||
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
|
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
@@ -206,7 +227,7 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey {
|
|||||||
|
|
||||||
return &api.SetupKey{
|
return &api.SetupKey{
|
||||||
Id: key.Id,
|
Id: key.Id,
|
||||||
Key: key.Key,
|
Key: key.KeySecret,
|
||||||
Name: key.Name,
|
Name: key.Name,
|
||||||
Expires: key.ExpiresAt,
|
Expires: key.ExpiresAt,
|
||||||
Type: string(key.Type),
|
Type: string(key.Type),
|
||||||
|
|||||||
@@ -67,6 +67,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
|||||||
ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) {
|
ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) {
|
||||||
return []*server.SetupKey{defaultKey}, nil
|
return []*server.SetupKey{defaultKey}, nil
|
||||||
},
|
},
|
||||||
|
|
||||||
|
DeleteSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) error {
|
||||||
|
if keyID == defaultKey.Id {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return status.Errorf(status.NotFound, "key %s not found", keyID)
|
||||||
|
},
|
||||||
},
|
},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||||
@@ -81,18 +88,21 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupKeysHandlers(t *testing.T) {
|
func TestSetupKeysHandlers(t *testing.T) {
|
||||||
defaultSetupKey := server.GenerateDefaultSetupKey()
|
defaultSetupKey, _ := server.GenerateDefaultSetupKey()
|
||||||
defaultSetupKey.Id = existingSetupKeyID
|
defaultSetupKey.Id = existingSetupKeyID
|
||||||
|
|
||||||
adminUser := server.NewAdminUser("test_user")
|
adminUser := server.NewAdminUser("test_user")
|
||||||
|
|
||||||
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
|
newSetupKey, plainKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
|
||||||
server.SetupKeyUnlimitedUsage, true)
|
server.SetupKeyUnlimitedUsage, true)
|
||||||
|
newSetupKey.Key = plainKey
|
||||||
updatedDefaultSetupKey := defaultSetupKey.Copy()
|
updatedDefaultSetupKey := defaultSetupKey.Copy()
|
||||||
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
|
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
|
||||||
updatedDefaultSetupKey.Name = updatedSetupKeyName
|
updatedDefaultSetupKey.Name = updatedSetupKeyName
|
||||||
updatedDefaultSetupKey.Revoked = true
|
updatedDefaultSetupKey.Revoked = true
|
||||||
|
|
||||||
|
expectedNewKey := toResponseBody(newSetupKey)
|
||||||
|
expectedNewKey.Key = plainKey
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
requestType string
|
requestType string
|
||||||
@@ -134,7 +144,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\",\"expires_in\":86400, \"ephemeral\":true}", newSetupKey.Name, newSetupKey.Type))),
|
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\",\"expires_in\":86400, \"ephemeral\":true}", newSetupKey.Name, newSetupKey.Type))),
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
expectedSetupKey: toResponseBody(newSetupKey),
|
expectedSetupKey: expectedNewKey,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Update Setup Key",
|
name: "Update Setup Key",
|
||||||
@@ -150,6 +160,14 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
expectedSetupKey: toResponseBody(updatedDefaultSetupKey),
|
expectedSetupKey: toResponseBody(updatedDefaultSetupKey),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Delete Setup Key",
|
||||||
|
requestType: http.MethodDelete,
|
||||||
|
requestPath: "/api/setup-keys/" + defaultSetupKey.Id,
|
||||||
|
requestBody: bytes.NewBuffer([]byte("")),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
|
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
|
||||||
@@ -164,6 +182,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS")
|
router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS")
|
||||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS")
|
router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS")
|
router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS")
|
||||||
|
router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
|
|||||||
@@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
|
|||||||
return am.Store.SaveAccount(ctx, a)
|
return am.Store.SaveAccount(ctx, a)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
|
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
|
||||||
if len(groups) == 0 {
|
if len(groupIDs) == 0 {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
accountsGroups, err := am.ListGroups(ctx, accountId)
|
|
||||||
if err != nil {
|
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
return false, err
|
for _, groupID := range groupIDs {
|
||||||
}
|
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||||
for _, group := range groups {
|
if err != nil {
|
||||||
var found bool
|
return err
|
||||||
for _, accountGroup := range accountsGroups {
|
|
||||||
if accountGroup.ID == group {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
return nil
|
||||||
return false, nil
|
})
|
||||||
}
|
if err != nil {
|
||||||
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
// IntegratedValidator interface exists to avoid the circle dependencies
|
// IntegratedValidator interface exists to avoid the circle dependencies
|
||||||
type IntegratedValidator interface {
|
type IntegratedValidator interface {
|
||||||
ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
|
ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
|
||||||
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
|
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
|
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
|
||||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
|
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
|
||||||
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)
|
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)
|
||||||
|
|||||||
@@ -453,8 +453,8 @@ func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtr
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
|
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||||
return update, nil
|
return update, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
||||||
|
|||||||
@@ -267,7 +267,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
peersSSHEnabled++
|
peersSSHEnabled++
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.SetupKey == "" {
|
if peer.UserID != "" {
|
||||||
userPeers++
|
userPeers++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,13 +2,16 @@ package migration
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
b64 "encoding/base64"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -205,3 +208,90 @@ func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fi
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) error {
|
||||||
|
oldColumnName := "key"
|
||||||
|
newColumnName := "key_secret"
|
||||||
|
|
||||||
|
var model T
|
||||||
|
|
||||||
|
if !db.Migrator().HasTable(&model) {
|
||||||
|
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := &gorm.Statement{DB: db}
|
||||||
|
err := stmt.Parse(&model)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse model: %w", err)
|
||||||
|
}
|
||||||
|
tableName := stmt.Schema.Table
|
||||||
|
|
||||||
|
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
if !tx.Migrator().HasColumn(&model, newColumnName) {
|
||||||
|
log.WithContext(ctx).Infof("Column %s does not exist in table %s, adding it", newColumnName, tableName)
|
||||||
|
if err := tx.Migrator().AddColumn(&model, newColumnName); err != nil {
|
||||||
|
return fmt.Errorf("add column %s: %w", newColumnName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var rows []map[string]any
|
||||||
|
if err := tx.Table(tableName).
|
||||||
|
Select("id", oldColumnName, newColumnName).
|
||||||
|
Where(newColumnName + " IS NULL OR " + newColumnName + " = ''").
|
||||||
|
Where("SUBSTR(" + oldColumnName + ", 9, 1) = '-'").
|
||||||
|
Find(&rows).Error; err != nil {
|
||||||
|
return fmt.Errorf("find rows with empty secret key and matching pattern: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rows) == 0 {
|
||||||
|
log.WithContext(ctx).Infof("No plain setup keys found in table %s, no migration needed", tableName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, row := range rows {
|
||||||
|
var plainKey string
|
||||||
|
if columnValue := row[oldColumnName]; columnValue != nil {
|
||||||
|
value, ok := columnValue.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("type assertion failed")
|
||||||
|
}
|
||||||
|
plainKey = value
|
||||||
|
}
|
||||||
|
|
||||||
|
secretKey := hiddenKey(plainKey, 4)
|
||||||
|
|
||||||
|
hashedKey := sha256.Sum256([]byte(plainKey))
|
||||||
|
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||||
|
|
||||||
|
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, secretKey).Error; err != nil {
|
||||||
|
return fmt.Errorf("update row with secret key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(oldColumnName, encodedHashedKey).Error; err != nil {
|
||||||
|
return fmt.Errorf("update row with hashed key: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Migration of plain setup key to hashed setup key completed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// hiddenKey returns the Key value hidden with "*" and a 5 character prefix.
|
||||||
|
// E.g., "831F6*******************************"
|
||||||
|
func hiddenKey(key string, length int) string {
|
||||||
|
prefix := key[0:5]
|
||||||
|
if length > utf8.RuneCountInString(key) {
|
||||||
|
length = utf8.RuneCountInString(key) - len(prefix)
|
||||||
|
}
|
||||||
|
return prefix + strings.Repeat("*", length)
|
||||||
|
}
|
||||||
|
|||||||
@@ -160,3 +160,72 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
|||||||
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
|
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
|
||||||
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged")
|
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
|
||||||
|
db := setupDatabase(t)
|
||||||
|
|
||||||
|
err := db.AutoMigrate(&server.SetupKey{})
|
||||||
|
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||||
|
|
||||||
|
err = db.Save(&server.SetupKey{
|
||||||
|
Id: "1",
|
||||||
|
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
|
||||||
|
}).Error
|
||||||
|
require.NoError(t, err, "Failed to insert setup key")
|
||||||
|
|
||||||
|
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
|
||||||
|
require.NoError(t, err, "Migration should not fail to migrate setup key")
|
||||||
|
|
||||||
|
var key server.SetupKey
|
||||||
|
err = db.Model(&server.SetupKey{}).First(&key).Error
|
||||||
|
assert.NoError(t, err, "Failed to fetch setup key")
|
||||||
|
|
||||||
|
assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret")
|
||||||
|
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.T) {
|
||||||
|
db := setupDatabase(t)
|
||||||
|
|
||||||
|
err := db.AutoMigrate(&server.SetupKey{})
|
||||||
|
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||||
|
|
||||||
|
err = db.Save(&server.SetupKey{
|
||||||
|
Id: "1",
|
||||||
|
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||||
|
KeySecret: "EEFDA****",
|
||||||
|
}).Error
|
||||||
|
require.NoError(t, err, "Failed to insert setup key")
|
||||||
|
|
||||||
|
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
|
||||||
|
require.NoError(t, err, "Migration should not fail to migrate setup key")
|
||||||
|
|
||||||
|
var key server.SetupKey
|
||||||
|
err = db.Model(&server.SetupKey{}).First(&key).Error
|
||||||
|
assert.NoError(t, err, "Failed to fetch setup key")
|
||||||
|
|
||||||
|
assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret")
|
||||||
|
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.T) {
|
||||||
|
db := setupDatabase(t)
|
||||||
|
|
||||||
|
err := db.AutoMigrate(&server.SetupKey{})
|
||||||
|
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||||
|
|
||||||
|
err = db.Save(&server.SetupKey{
|
||||||
|
Id: "1",
|
||||||
|
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
|
||||||
|
}).Error
|
||||||
|
require.NoError(t, err, "Failed to insert setup key")
|
||||||
|
|
||||||
|
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
|
||||||
|
require.NoError(t, err, "Migration should not fail to migrate setup key")
|
||||||
|
|
||||||
|
var key server.SetupKey
|
||||||
|
err = db.Model(&server.SetupKey{}).First(&key).Error
|
||||||
|
assert.NoError(t, err, "Failed to fetch setup key")
|
||||||
|
|
||||||
|
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")
|
||||||
|
}
|
||||||
|
|||||||
@@ -45,12 +45,11 @@ type MockAccountManager struct {
|
|||||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
|
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
|
||||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||||
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
|
|
||||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
||||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
|
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error)
|
||||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||||
@@ -97,7 +96,7 @@ type MockAccountManager struct {
|
|||||||
HasConnectedChannelFunc func(peerID string) bool
|
HasConnectedChannelFunc func(peerID string) bool
|
||||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||||
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
|
||||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||||
GetIdpManagerFunc func() idp.Manager
|
GetIdpManagerFunc func() idp.Manager
|
||||||
@@ -109,6 +108,14 @@ type MockAccountManager struct {
|
|||||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
|
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
|
||||||
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
|
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
|
||||||
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
|
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
|
||||||
|
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
||||||
|
if am.DeleteSetupKeyFunc != nil {
|
||||||
|
return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID)
|
||||||
|
}
|
||||||
|
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||||
@@ -346,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
|
|||||||
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
|
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListGroups mock implementation of ListGroups from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
|
|
||||||
if am.ListGroupsFunc != nil {
|
|
||||||
return am.ListGroupsFunc(ctx, accountID)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
|
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||||
if am.GroupAddPeerFunc != nil {
|
if am.GroupAddPeerFunc != nil {
|
||||||
@@ -387,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
||||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
|
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) {
|
||||||
if am.SavePolicyFunc != nil {
|
if am.SavePolicyFunc != nil {
|
||||||
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
|
return am.SavePolicyFunc(ctx, accountID, userID, policy)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
|
// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
|
||||||
@@ -731,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
|
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
|
||||||
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||||
if am.SavePostureChecksFunc != nil {
|
if am.SavePostureChecksFunc != nil {
|
||||||
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
|
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
|
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user