Compare commits

...

47 Commits

Author SHA1 Message Date
Pascal Fischer
261d1e094a add debug logging to getAccount 2025-07-04 17:46:12 +02:00
Pascal Fischer
bcab5cbbee fix migration check 2025-07-04 15:23:03 +02:00
Pascal Fischer
3931958499 disable avoid adding networks multiple times 2025-07-04 14:19:13 +02:00
Pascal Fischer
914e58ac75 Merge branch 'feature/net-266-groups-migration' into feature/net-297-network-migration 2025-07-04 14:18:18 +02:00
Pascal Fischer
7baeea3d9d disable avoid adding group peers multiple times 2025-07-04 14:17:34 +02:00
Pascal Fischer
64e618f1ad disable migration cleanup 2025-07-04 13:14:47 +02:00
Pascal Fischer
5802dcaf80 Merge branch 'feature/net-266-groups-migration' into feature/net-297-network-migration
# Conflicts:
#	management/server/migration/migration.go
#	management/server/peer_test.go
2025-07-04 13:09:01 +02:00
Pascal Fischer
b329397d06 Revert "use FullSaveAssociations for single group save"
This reverts commit 76b180a741.
2025-07-04 10:16:41 +02:00
Pascal Fischer
76b180a741 use FullSaveAssociations for single group save 2025-07-04 00:00:40 +02:00
Pascal Fischer
8db9065ed9 run account update buffer in goroutine 2025-07-03 20:36:35 +02:00
Pascal Fischer
1f79fc0728 ignore associations on create 2025-07-03 19:58:11 +02:00
Pascal Fischer
bdd0b1cf02 update get group by name 2025-07-03 19:24:12 +02:00
Pascal Fischer
e6ac248aee load peers on retrival by name 2025-07-03 19:11:18 +02:00
Pascal Fischer
376394f7f9 fix large batch 2025-07-03 19:00:15 +02:00
Pascal Fischer
542dbdb41c fix group save with removal of all peers 2025-07-03 18:50:17 +02:00
Pascal Fischer
982b9604ee cleanup 2025-07-03 18:41:49 +02:00
Pascal Fischer
f2990e2fbc cleanup 2025-07-03 18:10:00 +02:00
Pascal Fischer
dfb47d5545 fix tests 2025-07-03 17:20:21 +02:00
Pascal Fischer
8e0b8f20a2 fix tests 2025-07-03 16:56:47 +02:00
Pascal Fischer
8a42528664 change getGroupsByPeers 2025-07-03 16:24:00 +02:00
Pascal Fischer
a8cba921e1 fix tests and group copy 2025-07-03 16:04:33 +02:00
Pascal Fischer
fee36b0663 fix peer add if group not existing 2025-07-03 15:40:28 +02:00
Pascal Fischer
dfad334780 fix peer add if group not existing 2025-07-03 15:30:35 +02:00
Pascal Fischer
d25da87957 don't fail on conflict 2025-07-03 15:15:48 +02:00
Pascal Fischer
13213d954d fix save account 2025-07-03 14:49:12 +02:00
Pascal Fischer
6fb61c7cf5 remove group peers from extended store 2025-07-03 13:45:59 +02:00
Pascal Fischer
459db2ba4f fix log in test 2025-07-03 13:41:36 +02:00
Pascal Fischer
e78b7dd058 fis association replace 2025-07-03 13:35:18 +02:00
Pascal Fischer
7132642e4c cleanup tests 2025-07-03 13:23:14 +02:00
Pascal Fischer
22a944b157 cleanup tests 2025-07-03 13:15:56 +02:00
Pascal Fischer
005937ae77 Update management/server/migration/migration.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-03 12:56:39 +02:00
Pascal Fischer
5fab2d019a set lock level on update 2025-07-03 12:55:11 +02:00
Pascal Fischer
36155f8de1 disable delete on migration 2025-07-03 12:51:55 +02:00
Pascal Fischer
d06831dd2f add network migration 2025-07-03 11:20:16 +02:00
Pascal Fischer
e23282b92c add groups migration 2025-07-02 19:09:42 +02:00
Maycon Santos
57961afe95 [doc] Add forum link (#4093)
* Add forum link

* Add forum link
2025-07-02 18:40:07 +02:00
Pascal Fischer
22678bce7f [management] add uniqueness constraint for peer ip and label and optimize generation (#4042) 2025-07-02 18:13:10 +02:00
Maycon Santos
6c633497bc [management] fix network update test for delete policy (#4086)
when adding a peer we calculate the network map an account using backpressure functions and some updates might arrive around the time we are deleting a policy.

This change ensures we wait enough time for the updates from add peer to be sent and read before continuing with the test logic
2025-07-02 12:25:31 +02:00
Carlos Hernandez
6922826919 [client] Support fullstatus without probes (#4052) 2025-07-02 10:42:47 +02:00
Maycon Santos
56a1a75e3f [client] Support random wireguard port on client (#4085)
Adds support for using a random available WireGuard port when the user specifies port `0`.

- Updates `freePort` logic to bind to the requested port (including `0`) without falling back to the default.
- Removes default port assignment in the configuration path, allowing `0` to propagate.
- Adjusts tests to handle dynamically assigned ports when using `0`.
2025-07-02 09:01:02 +02:00
Ali Amer
d9402168ad [management] Add option to disable default all-to-all policy (#3970)
This PR introduces a new configuration option `DisableDefaultPolicy` that prevents the creation of the default all-to-all policy when new accounts are created. This is useful for automation scenarios where explicit policies are preferred.
### Key Changes:
- Added DisableDefaultPolicy flag to the management server config
- Modified account creation logic to respect this flag
- Updated all test cases to explicitly pass the flag (defaulting to false to maintain backward compatibility)
- Propagated the flag through the account manager initialization chain

### Testing:

- Verified default behavior remains unchanged when flag is false
- Confirmed no default policy is created when flag is true
- All existing tests pass with the new parameter
2025-07-02 02:41:59 +02:00
Krzysztof Nazarewski (kdn)
dbdef04b9e [misc] getting-started-with-zitadel.sh: drop unnecessary port 8080 (#4075) 2025-07-02 02:35:13 +02:00
Maycon Santos
29cbfe8467 [misc] update sign pipeline version to v0.0.20 (#4082) 2025-07-01 16:23:31 +02:00
Maycon Santos
6ce8643368 [client] Run login popup on goroutine (#4080) 2025-07-01 13:45:55 +02:00
Krzysztof Nazarewski (kdn)
07d1ad35fc [misc] start the service after installation on arch linux (#4071) 2025-06-30 12:02:03 +02:00
Krzysztof Nazarewski (kdn)
ef6cd36f1a [misc] fix arch install.sh error with empty temporary dependencies
handle empty var before calling removal command
2025-06-30 11:59:35 +02:00
Krzysztof Nazarewski (kdn)
c1c71b6d39 [client] improve adding route log message (#4034)
from:
  Adding route to 1.2.3.4/32 via invalid IP @ 10 (wt0)
to:
  Adding route to 1.2.3.4/32 via no-ip @ 10 (wt0)
2025-06-30 11:57:42 +02:00
54 changed files with 2362 additions and 2491 deletions

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.19"
SIGN_PIPE_VER: "v0.0.20"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -134,6 +134,7 @@ jobs:
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
run: |
set -x
@@ -180,6 +181,7 @@ jobs:
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
- name: Install modules
run: go mod tidy

View File

@@ -14,6 +14,9 @@
<br>
<a href="https://docs.netbird.io/slack-url">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<a href="https://forum.netbird.io">
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
@@ -29,13 +32,13 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
<br/>
</strong>
<br>
<a href="https://github.com/netbirdio/kubernetes-operator">
New: NetBird Kubernetes Operator
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
</a>
</p>

View File

@@ -120,7 +120,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}

View File

@@ -103,7 +103,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}

View File

@@ -319,10 +319,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
*input.WireguardPort, config.WgPort)
config.WgPort = *input.WireguardPort
updated = true
} else if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
log.Infof("using default Wireguard port %d", config.WgPort)
updated = true
}
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {

View File

@@ -17,7 +17,6 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
@@ -526,17 +525,13 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
func freePort(initPort int) (int, error) {
addr := net.UDPAddr{}
if initPort == 0 {
initPort = iface.DefaultWgPort
}
addr.Port = initPort
addr := net.UDPAddr{Port: initPort}
conn, err := net.ListenUDP("udp", &addr)
if err == nil {
returnPort := conn.LocalAddr().(*net.UDPAddr).Port
closeConnWithLog(conn)
return initPort, nil
return returnPort, nil
}
// if the port is already in use, ask the system for a free port

View File

@@ -13,10 +13,10 @@ func Test_freePort(t *testing.T) {
shouldMatch bool
}{
{
name: "not provided, fallback to default",
name: "when port is 0 use random port",
port: 0,
want: 51820,
shouldMatch: true,
want: 0,
shouldMatch: false,
},
{
name: "provided and available",
@@ -31,7 +31,7 @@ func Test_freePort(t *testing.T) {
shouldMatch: false,
},
}
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
if err != nil {
t.Errorf("freePort error = %v", err)
}
@@ -39,6 +39,14 @@ func Test_freePort(t *testing.T) {
_ = c1.Close()
}(c1)
if tests[1].port == c1.LocalAddr().(*net.UDPAddr).Port {
tests[1].port++
tests[1].want++
}
tests[2].port = c1.LocalAddr().(*net.UDPAddr).Port
tests[2].want = c1.LocalAddr().(*net.UDPAddr).Port
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -1476,7 +1476,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
permissionsManager := permissions.NewManager(store)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}

View File

@@ -28,7 +28,10 @@ func (n Nexthop) String() string {
if n.Intf == nil {
return n.IP.String()
}
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
if n.IP.IsValid() {
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
}
return fmt.Sprintf("no-ip @ %d (%s)", n.Intf.Index, n.Intf.Name)
}
type wgIface interface {

File diff suppressed because it is too large Load Diff

View File

@@ -158,6 +158,7 @@ message UpResponse {}
message StatusRequest{
bool getFullPeerStatus = 1;
bool shouldRunProbes = 2;
}
message StatusResponse{

View File

@@ -1,4 +1,8 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.29.3
// source: daemon.proto
package proto
@@ -11,8 +15,31 @@ import (
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
const (
DaemonService_Login_FullMethodName = "/daemon.DaemonService/Login"
DaemonService_WaitSSOLogin_FullMethodName = "/daemon.DaemonService/WaitSSOLogin"
DaemonService_Up_FullMethodName = "/daemon.DaemonService/Up"
DaemonService_Status_FullMethodName = "/daemon.DaemonService/Status"
DaemonService_Down_FullMethodName = "/daemon.DaemonService/Down"
DaemonService_GetConfig_FullMethodName = "/daemon.DaemonService/GetConfig"
DaemonService_ListNetworks_FullMethodName = "/daemon.DaemonService/ListNetworks"
DaemonService_SelectNetworks_FullMethodName = "/daemon.DaemonService/SelectNetworks"
DaemonService_DeselectNetworks_FullMethodName = "/daemon.DaemonService/DeselectNetworks"
DaemonService_ForwardingRules_FullMethodName = "/daemon.DaemonService/ForwardingRules"
DaemonService_DebugBundle_FullMethodName = "/daemon.DaemonService/DebugBundle"
DaemonService_GetLogLevel_FullMethodName = "/daemon.DaemonService/GetLogLevel"
DaemonService_SetLogLevel_FullMethodName = "/daemon.DaemonService/SetLogLevel"
DaemonService_ListStates_FullMethodName = "/daemon.DaemonService/ListStates"
DaemonService_CleanState_FullMethodName = "/daemon.DaemonService/CleanState"
DaemonService_DeleteState_FullMethodName = "/daemon.DaemonService/DeleteState"
DaemonService_SetNetworkMapPersistence_FullMethodName = "/daemon.DaemonService/SetNetworkMapPersistence"
DaemonService_TracePacket_FullMethodName = "/daemon.DaemonService/TracePacket"
DaemonService_SubscribeEvents_FullMethodName = "/daemon.DaemonService/SubscribeEvents"
DaemonService_GetEvents_FullMethodName = "/daemon.DaemonService/GetEvents"
)
// DaemonServiceClient is the client API for DaemonService service.
//
@@ -53,7 +80,7 @@ type DaemonServiceClient interface {
// SetNetworkMapPersistence enables or disables network map persistence
SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error)
SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error)
SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error)
GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error)
}
@@ -66,8 +93,9 @@ func NewDaemonServiceClient(cc grpc.ClientConnInterface) DaemonServiceClient {
}
func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(LoginResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/Login", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_Login_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -75,8 +103,9 @@ func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts
}
func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLoginRequest, opts ...grpc.CallOption) (*WaitSSOLoginResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(WaitSSOLoginResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitSSOLogin", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_WaitSSOLogin_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -84,8 +113,9 @@ func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLogin
}
func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(UpResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/Up", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_Up_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -93,8 +123,9 @@ func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grp
}
func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(StatusResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/Status", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_Status_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -102,8 +133,9 @@ func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opt
}
func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(DownResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/Down", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_Down_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -111,8 +143,9 @@ func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts ..
}
func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(GetConfigResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetConfig", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_GetConfig_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -120,8 +153,9 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques
}
func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ListNetworksResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListNetworks", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_ListNetworks_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -129,8 +163,9 @@ func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworks
}
func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(SelectNetworksResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectNetworks", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_SelectNetworks_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -138,8 +173,9 @@ func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetw
}
func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(SelectNetworksResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectNetworks", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_DeselectNetworks_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -147,8 +183,9 @@ func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNe
}
func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*ForwardingRulesResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ForwardingRulesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ForwardingRules", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_ForwardingRules_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -156,8 +193,9 @@ func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequ
}
func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(DebugBundleResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_DebugBundle_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -165,8 +203,9 @@ func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRe
}
func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(GetLogLevelResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetLogLevel", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_GetLogLevel_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -174,8 +213,9 @@ func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRe
}
func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(SetLogLevelResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_SetLogLevel_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -183,8 +223,9 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe
}
func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ListStatesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_ListStates_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -192,8 +233,9 @@ func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequ
}
func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(CleanStateResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_CleanState_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -201,8 +243,9 @@ func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequ
}
func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(DeleteStateResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_DeleteState_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -210,8 +253,9 @@ func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRe
}
func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(SetNetworkMapPersistenceResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_SetNetworkMapPersistence_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -219,20 +263,22 @@ func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *
}
func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(TracePacketResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/TracePacket", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_TracePacket_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) {
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], "/daemon.DaemonService/SubscribeEvents", opts...)
func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], DaemonService_SubscribeEvents_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &daemonServiceSubscribeEventsClient{stream}
x := &grpc.GenericClientStream[SubscribeRequest, SystemEvent]{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
@@ -242,26 +288,13 @@ func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *Subscribe
return x, nil
}
type DaemonService_SubscribeEventsClient interface {
Recv() (*SystemEvent, error)
grpc.ClientStream
}
type daemonServiceSubscribeEventsClient struct {
grpc.ClientStream
}
func (x *daemonServiceSubscribeEventsClient) Recv() (*SystemEvent, error) {
m := new(SystemEvent)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DaemonService_SubscribeEventsClient = grpc.ServerStreamingClient[SystemEvent]
func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(GetEventsResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetEvents", in, out, opts...)
err := c.cc.Invoke(ctx, DaemonService_GetEvents_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
@@ -270,7 +303,7 @@ func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsReques
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
// for forward compatibility.
type DaemonServiceServer interface {
// Login uses setup key to prepare configuration for the daemon.
Login(context.Context, *LoginRequest) (*LoginResponse, error)
@@ -307,14 +340,17 @@ type DaemonServiceServer interface {
// SetNetworkMapPersistence enables or disables network map persistence
SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error)
SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error
SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error
GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
// UnimplementedDaemonServiceServer must be embedded to have forward compatible implementations.
type UnimplementedDaemonServiceServer struct {
}
// UnimplementedDaemonServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedDaemonServiceServer struct{}
func (UnimplementedDaemonServiceServer) Login(context.Context, *LoginRequest) (*LoginResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Login not implemented")
@@ -370,13 +406,14 @@ func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context
func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented")
}
func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error {
func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error {
return status.Errorf(codes.Unimplemented, "method SubscribeEvents not implemented")
}
func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to DaemonServiceServer will
@@ -386,6 +423,13 @@ type UnsafeDaemonServiceServer interface {
}
func RegisterDaemonServiceServer(s grpc.ServiceRegistrar, srv DaemonServiceServer) {
// If the following call pancis, it indicates UnimplementedDaemonServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&DaemonService_ServiceDesc, srv)
}
@@ -399,7 +443,7 @@ func _DaemonService_Login_Handler(srv interface{}, ctx context.Context, dec func
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/Login",
FullMethod: DaemonService_Login_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).Login(ctx, req.(*LoginRequest))
@@ -417,7 +461,7 @@ func _DaemonService_WaitSSOLogin_Handler(srv interface{}, ctx context.Context, d
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/WaitSSOLogin",
FullMethod: DaemonService_WaitSSOLogin_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).WaitSSOLogin(ctx, req.(*WaitSSOLoginRequest))
@@ -435,7 +479,7 @@ func _DaemonService_Up_Handler(srv interface{}, ctx context.Context, dec func(in
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/Up",
FullMethod: DaemonService_Up_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).Up(ctx, req.(*UpRequest))
@@ -453,7 +497,7 @@ func _DaemonService_Status_Handler(srv interface{}, ctx context.Context, dec fun
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/Status",
FullMethod: DaemonService_Status_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).Status(ctx, req.(*StatusRequest))
@@ -471,7 +515,7 @@ func _DaemonService_Down_Handler(srv interface{}, ctx context.Context, dec func(
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/Down",
FullMethod: DaemonService_Down_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).Down(ctx, req.(*DownRequest))
@@ -489,7 +533,7 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/GetConfig",
FullMethod: DaemonService_GetConfig_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetConfig(ctx, req.(*GetConfigRequest))
@@ -507,7 +551,7 @@ func _DaemonService_ListNetworks_Handler(srv interface{}, ctx context.Context, d
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/ListNetworks",
FullMethod: DaemonService_ListNetworks_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest))
@@ -525,7 +569,7 @@ func _DaemonService_SelectNetworks_Handler(srv interface{}, ctx context.Context,
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SelectNetworks",
FullMethod: DaemonService_SelectNetworks_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest))
@@ -543,7 +587,7 @@ func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Contex
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DeselectNetworks",
FullMethod: DaemonService_DeselectNetworks_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest))
@@ -561,7 +605,7 @@ func _DaemonService_ForwardingRules_Handler(srv interface{}, ctx context.Context
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/ForwardingRules",
FullMethod: DaemonService_ForwardingRules_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ForwardingRules(ctx, req.(*EmptyRequest))
@@ -579,7 +623,7 @@ func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, de
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DebugBundle",
FullMethod: DaemonService_DebugBundle_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest))
@@ -597,7 +641,7 @@ func _DaemonService_GetLogLevel_Handler(srv interface{}, ctx context.Context, de
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/GetLogLevel",
FullMethod: DaemonService_GetLogLevel_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetLogLevel(ctx, req.(*GetLogLevelRequest))
@@ -615,7 +659,7 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SetLogLevel",
FullMethod: DaemonService_SetLogLevel_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest))
@@ -633,7 +677,7 @@ func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/ListStates",
FullMethod: DaemonService_ListStates_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest))
@@ -651,7 +695,7 @@ func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/CleanState",
FullMethod: DaemonService_CleanState_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest))
@@ -669,7 +713,7 @@ func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, de
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DeleteState",
FullMethod: DaemonService_DeleteState_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest))
@@ -687,7 +731,7 @@ func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx contex
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SetNetworkMapPersistence",
FullMethod: DaemonService_SetNetworkMapPersistence_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest))
@@ -705,7 +749,7 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/TracePacket",
FullMethod: DaemonService_TracePacket_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest))
@@ -718,21 +762,11 @@ func _DaemonService_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerS
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(DaemonServiceServer).SubscribeEvents(m, &daemonServiceSubscribeEventsServer{stream})
return srv.(DaemonServiceServer).SubscribeEvents(m, &grpc.GenericServerStream[SubscribeRequest, SystemEvent]{ServerStream: stream})
}
type DaemonService_SubscribeEventsServer interface {
Send(*SystemEvent) error
grpc.ServerStream
}
type daemonServiceSubscribeEventsServer struct {
grpc.ServerStream
}
func (x *daemonServiceSubscribeEventsServer) Send(m *SystemEvent) error {
return x.ServerStream.SendMsg(m)
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DaemonService_SubscribeEventsServer = grpc.ServerStreamingServer[SystemEvent]
func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetEventsRequest)
@@ -744,7 +778,7 @@ func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/GetEvents",
FullMethod: DaemonService_GetEvents_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetEvents(ctx, req.(*GetEventsRequest))

View File

@@ -707,7 +707,9 @@ func (s *Server) Status(
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
if msg.GetFullPeerStatus {
s.runProbes()
if msg.ShouldRunProbes {
s.runProbes()
}
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)

View File

@@ -206,7 +206,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}

View File

@@ -879,7 +879,7 @@ func (s *serviceClient) onUpdateAvailable() {
func (s *serviceClient) onSessionExpire() {
s.sendNotification = true
if s.sendNotification {
s.eventHandler.runSelfCommand(s.ctx, "login-url", "true")
go s.eventHandler.runSelfCommand(s.ctx, "login-url", "true")
s.sendNotification = false
}
}

View File

@@ -15,6 +15,7 @@ NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAI
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted}
NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false}
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=${NETBIRD_MGMT_DISABLE_DEFAULT_POLICY:-false}
# Signal
NETBIRD_SIGNAL_PROTOCOL="http"
@@ -139,3 +140,4 @@ export NETBIRD_RELAY_PORT
export NETBIRD_RELAY_ENDPOINT
export NETBIRD_RELAY_AUTH_SECRET
export NETBIRD_RELAY_TAG
export NETBIRD_MGMT_DISABLE_DEFAULT_POLICY

View File

@@ -791,7 +791,6 @@ services:
- '443:443'
- '443:443/udp'
- '80:80'
- '8080:8080'
volumes:
- netbird_caddy_data:/data
- ./Caddyfile:/etc/caddy/Caddyfile

View File

@@ -38,6 +38,7 @@
"0.0.0.0/0"
]
},
"DisableDefaultPolicy": $NETBIRD_MGMT_DISABLE_DEFAULT_POLICY,
"Datadir": "",
"DataStoreEncryptionKey": "$NETBIRD_DATASTORE_ENC_KEY",
"StoreConfig": {

View File

@@ -92,7 +92,8 @@ NETBIRD_LETSENCRYPT_EMAIL=""
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
# DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
# Disable default all-to-all policy for new accounts
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=false
# -------------------------------------------
# Relay settings
# -------------------------------------------

View File

@@ -29,3 +29,4 @@ NETBIRD_TURN_EXTERNAL_IP=1.2.3.4
NETBIRD_RELAY_PORT=33445
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true
NETBIRD_AUTH_PKCE_LOGIN_FLAG=0
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY

View File

@@ -100,7 +100,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
Return(true, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}

View File

@@ -215,7 +215,7 @@ var (
peersManager := peers.NewManager(store, permissionsManager)
proxyController := integrations.NewController(store)
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager)
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
}

View File

@@ -102,6 +102,20 @@ type DefaultAccountManager struct {
accountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
disableDefaultPolicy bool
}
func isUniqueConstraintError(err error) bool {
switch {
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
strings.Contains(err.Error(), "Error 1062 (23000)"),
strings.Contains(err.Error(), "UNIQUE constraint failed"):
return true
default:
return false
}
}
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
@@ -170,6 +184,7 @@ func BuildManager(
proxyController port_forwarding.Controller,
settingsManager settings.Manager,
permissionsManager permissions.Manager,
disableDefaultPolicy bool,
) (*DefaultAccountManager, error) {
start := time.Now()
defer func() {
@@ -195,6 +210,7 @@ func BuildManager(
proxyController: proxyController,
settingsManager: settingsManager,
permissionsManager: permissionsManager,
disableDefaultPolicy: disableDefaultPolicy,
}
am.startWarmup(ctx)
@@ -543,7 +559,7 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
continue
case statusErr.Type() == status.NotFound:
newAccount := newAccountWithId(ctx, accountId, userID, domain)
newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy)
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
return newAccount, nil
default:
@@ -1657,25 +1673,6 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
return false, nil
}
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
}
labelMap := ConvertSliceToMap(existingLabels)
newLabel, err := types.GetPeerHostLabel(peerHostName, labelMap)
if err != nil {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
if newLabel == "" {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
return newLabel, nil
}
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {
@@ -1688,10 +1685,10 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
}
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
log.WithContext(ctx).Debugf("creating new account")
network := types.NewNetwork()
network := types.NewNetwork(accountID)
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[route.ID]*route.Route)
@@ -1731,7 +1728,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
},
}
if err := acc.AddAllGroup(); err != nil {
if err := acc.AddAllGroup(disableDefaultPolicy); err != nil {
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
}
return acc
@@ -1795,7 +1792,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
continue
}
network := types.NewNetwork()
network := types.NewNetwork(accountId)
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[route.ID]*route.Route)
@@ -1833,7 +1830,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
},
}
if err := newAccount.AddAllGroup(); err != nil {
if err := newAccount.AddAllGroup(am.disableDefaultPolicy); err != nil {
return nil, false, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
}

View File

@@ -373,7 +373,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
for _, testCase := range tt {
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", false)
account.UpdateSettings(&testCase.accountSettings)
account.Network = network
account.Peers = testCase.peers
@@ -398,7 +398,7 @@ func TestNewAccount(t *testing.T) {
domain := "netbird.io"
userId := "account_creator"
accountID := "account_id"
account := newAccountWithId(context.Background(), accountID, userId, domain)
account := newAccountWithId(context.Background(), accountID, userId, domain, false)
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
}
@@ -640,7 +640,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
userId := "user-id"
domain := "test.domain"
_ = newAccountWithId(context.Background(), "", userId, domain)
_ = newAccountWithId(context.Background(), "", userId, domain, false)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
@@ -793,7 +793,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
}
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) {
account := newAccountWithId(context.Background(), accountID, userID, domain)
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
@@ -1208,6 +1208,14 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
// Ensure that we do not receive an update message before the policy is deleted
time.Sleep(time.Second)
select {
case <-updMsg:
t.Logf("received addPeer update message before policy deletion")
default:
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
@@ -1232,9 +1240,10 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
group := types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
AccountID: account.Id,
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
t.Errorf("save group: %v", err)
@@ -1664,9 +1673,10 @@ func TestAccount_Copy(t *testing.T) {
},
Groups: map[string]*types.Group{
"group1": {
ID: "group1",
Peers: []string{"peer1"},
Resources: []types.Resource{},
ID: "group1",
Peers: []string{"peer1"},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
},
},
Policies: []*types.Policy{
@@ -2608,6 +2618,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
}
func TestAccount_SetJWTGroups(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -2615,11 +2626,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
account := &types.Account{
Id: "accountID",
Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
"peer3": {ID: "peer3", Key: "key3", UserID: "user1"},
"peer4": {ID: "peer4", Key: "key4", UserID: "user2"},
"peer5": {ID: "peer5", Key: "key5", UserID: "user2"},
"peer1": {ID: "peer1", Key: "key1", UserID: "user1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"},
"peer3": {ID: "peer3", Key: "key3", UserID: "user1", IP: net.IP{3, 3, 3, 3}, DNSLabel: "peer3.domain.test"},
"peer4": {ID: "peer4", Key: "key4", UserID: "user2", IP: net.IP{4, 4, 4, 4}, DNSLabel: "peer4.domain.test"},
"peer5": {ID: "peer5", Key: "key5", UserID: "user2", IP: net.IP{5, 5, 5, 5}, DNSLabel: "peer5.domain.test"},
},
Groups: map[string]*types.Group{
"group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}},
@@ -2879,7 +2890,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, err
}
@@ -3139,11 +3150,11 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 7, 20, 10, 80},
{"Small", 50, 5, 7, 20, 5, 80},
{"Medium", 500, 100, 5, 40, 30, 140},
{"Large", 5000, 200, 80, 120, 140, 390},
{"Small single", 50, 10, 7, 20, 10, 80},
{"Medium single", 500, 10, 5, 40, 20, 85},
{"Small single", 50, 10, 7, 20, 6, 80},
{"Medium single", 500, 10, 5, 40, 15, 85},
{"Large 5", 5000, 15, 80, 120, 80, 200},
}
@@ -3335,11 +3346,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
require.NoError(t, err)
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId}
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1)
require.NoError(t, err)
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId}
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2)
require.NoError(t, err)

View File

@@ -217,7 +217,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createDNSStore(t *testing.T) (store.Store, error) {
@@ -267,7 +267,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
domain := "example.com"
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain, false)
account.Users[dnsRegularUserID] = &types.User{
Id: dnsRegularUserID,

View File

@@ -127,7 +127,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
}
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
store.account = newAccountWithId(context.Background(), "my account", "", "")
store.account = newAccountWithId(context.Background(), "my account", "", "", false)
for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i)

View File

@@ -265,20 +265,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var group *types.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.AddPeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
@@ -288,7 +278,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err
}
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
return transaction.AddPeerToGroup(ctx, peerID, groupID)
})
if err != nil {
return err
@@ -347,20 +337,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var group *types.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.RemovePeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
@@ -370,7 +350,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
return transaction.RemovePeerFromGroup(ctx, peerID, groupID)
})
if err != nil {
return err

View File

@@ -2,14 +2,19 @@ package server
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/groups"
@@ -18,8 +23,10 @@ import (
"github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@@ -369,7 +376,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
Id: "example user",
AutoGroups: []string{groupForUsers.ID},
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
@@ -733,3 +740,259 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}
})
}
func Test_AddPeerToGroup(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
acc, err := createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
err = manager.Store.AddPeerToGroup(context.Background(), strconv.Itoa(i), acc.GroupsG[0].ID)
if err != nil {
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
}
func Test_AddPeerToAll(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
err = manager.Store.AddPeerToAllGroup(context.Background(), accountID, strconv.Itoa(i))
if err != nil {
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
}
func Test_AddPeerAndAddToAll(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
peer := &peer2.Peer{
ID: strconv.Itoa(i),
AccountID: accountID,
DNSLabel: "peer" + strconv.Itoa(i),
IP: uint32ToIP(uint32(i)),
}
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
if err != nil {
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
}
err = transaction.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
if err != nil {
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
}
return nil
})
if err != nil {
t.Errorf("AddPeer failed for peer %d: %v", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
}
func uint32ToIP(n uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, n)
return ip
}
func Test_IncrementNetworkSerial(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get account %s: %v", accountID, err)
}
return nil
})
if err != nil {
t.Errorf("AddPeer failed for peer %d: %v", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, int(account.Network.Serial), "Expected %d serial increases in account %s, got %d", totalPeers, accountID, account.Network.Serial)
}

View File

@@ -69,7 +69,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
handler := initAccountsTestData(t, &types.Account{
Id: accountID,
Domain: "hotmail.com",
Network: types.NewNetwork(),
Network: types.NewNetwork(accountID),
Users: map[string]*types.User{
adminUser.Id: adminUser,
},

View File

@@ -1,5 +1,4 @@
package testing_tools
import (
"bytes"
"context"
@@ -138,7 +137,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager)
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}

View File

@@ -444,7 +444,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
permissionsManager := permissions.NewManager(store)
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
cleanup()

View File

@@ -211,7 +211,7 @@ func startServer(
port_forwarding.NewControllerMock(),
settingsMockManager,
permissionsManager,
)
false)
if err != nil {
t.Fatalf("failed creating an account manager: %v", err)
}

View File

@@ -10,13 +10,25 @@ import (
"errors"
"fmt"
"net"
"reflect"
"strings"
"unicode/utf8"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type LegacyAccountNetwork struct {
AccountID string `gorm:"column:id"`
Identifier string `gorm:"column:network_identifier"`
Net net.IPNet `gorm:"column:network_net;serializer:json"`
Dns string `gorm:"column:network_dns"`
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps.
Serial uint64 `gorm:"column:network_serial"`
}
func GetColumnName(db *gorm.DB, column string) string {
if db.Name() == "mysql" {
return fmt.Sprintf("`%s`", column)
@@ -39,6 +51,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f
return nil
}
if !db.Migrator().HasColumn(&model, oldColumnName) {
log.WithContext(ctx).Debugf("Column for %T does not exist, no migration needed", oldColumnName)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(model)
if err != nil {
@@ -373,3 +390,188 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error
log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model)
return nil
}
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
var model T
stmt := &gorm.Statement{DB: db}
if err := stmt.Parse(&model); err != nil {
return fmt.Errorf("failed to parse model schema: %w", err)
}
tableName := stmt.Schema.Table
dialect := db.Dialector.Name()
var columnClause string
if dialect == "mysql" {
var withLength []string
for _, col := range columns {
if col == "ip" || col == "dns_label" {
withLength = append(withLength, fmt.Sprintf("%s(64)", col))
} else {
withLength = append(withLength, col)
}
}
columnClause = strings.Join(withLength, ", ")
} else {
columnClause = strings.Join(columns, ", ")
}
createStmt := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, tableName, columnClause)
if dialect == "postgres" || dialect == "sqlite" {
createStmt = strings.Replace(createStmt, "CREATE UNIQUE INDEX", "CREATE UNIQUE INDEX IF NOT EXISTS", 1)
}
log.WithContext(ctx).Infof("executing index creation: %s", createStmt)
if err := db.Exec(createStmt).Error; err != nil {
return fmt.Errorf("failed to create index %s: %w", indexName, err)
}
log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName)
return nil
}
func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName string, mapperFunc func(id string, value string) any) error {
var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
if !db.Migrator().HasColumn(&model, columnName) {
log.WithContext(ctx).Debugf("column %s does not exist in table %s, no migration needed", columnName, tableName)
return nil
}
if err := db.Transaction(func(tx *gorm.DB) error {
var rows []map[string]any
if err := tx.Table(tableName).Select("id", columnName).Find(&rows).Error; err != nil {
return fmt.Errorf("find rows: %w", err)
}
for _, row := range rows {
jsonValue, ok := row[columnName].(string)
if !ok || jsonValue == "" {
continue
}
var data []string
if err := json.Unmarshal([]byte(jsonValue), &data); err != nil {
return fmt.Errorf("unmarshal json: %w", err)
}
for _, value := range data {
if err := tx.Clauses(clause.OnConflict{
DoNothing: true, // this needs to be removed when the cleanup is enabled
}).Create(
mapperFunc(row["id"].(string), value),
).Error; err != nil {
return fmt.Errorf("failed to insert id %v: %w", row["id"], err)
}
}
}
// Todo: Enable this after we are sure that every thing works as expected and we do not need to rollback anymore
// if err := tx.Migrator().DropColumn(&model, columnName); err != nil {
// return fmt.Errorf("drop column %s: %w", columnName, err)
// }
return nil
}); err != nil {
return err
}
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
return nil
}
func MigrateEmbeddedToTable[T any, S any, U any](ctx context.Context, db *gorm.DB, pkey string, mapperFunc func(obj S) *U) error {
var model T
log.WithContext(ctx).Debugf("Migrating embedded fields from %T to separate table", model)
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
if err := db.Transaction(func(tx *gorm.DB) error {
var legacyRows []S
if err := tx.Table(tableName).Find(&legacyRows).Error; err != nil {
log.WithContext(ctx).Errorf("Failed to read legacy accounts: %v", err)
return fmt.Errorf("failed to read legacy accounts: %w", err)
}
for _, row := range legacyRows {
if err := tx.Clauses(clause.OnConflict{
DoNothing: true, // this needs to be removed when the cleanup is enabled
}).Create(
mapperFunc(row),
).Error; err != nil {
return fmt.Errorf("failed to insert id %v: %w", row, err)
}
}
// cols, err := getColumnNamesFromStruct(new(S))
// if err != nil {
// return fmt.Errorf("failed to extract column names: %w", err)
// }
// for _, col := range cols {
// if col == pkey {
// continue
// }
// if err := tx.Migrator().DropColumn(&model, col); err != nil {
// return fmt.Errorf("failed to drop column %s: %w", col, err)
// }
// }
return nil
}); err != nil {
return err
}
log.WithContext(ctx).Infof("Migration of embedded fields %T from table %s into seperte table completed", new(S), tableName)
return nil
}
func getColumnNamesFromStruct[T any](model T) ([]string, error) {
val := reflect.TypeOf(model)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
var cols []string
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
if field.Name == "ID" {
continue // skip primary key
}
tag := field.Tag.Get("gorm")
if tag == "" {
continue
}
// Look for gorm:"column:..."
for _, part := range strings.Split(tag, ";") {
if strings.HasPrefix(part, "column:") {
cols = append(cols, strings.TrimPrefix(part, "column:"))
break
}
}
}
return cols, nil
}

View File

@@ -779,7 +779,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createNSStore(t *testing.T) (store.Store, error) {
@@ -848,7 +848,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
userID := testUserID
domain := "example.com"
account := newAccountWithId(context.Background(), accountID, userID, domain)
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup

View File

@@ -15,13 +15,14 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/idp"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -234,14 +235,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
if peer.Name != update.Name {
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID)
var newLabel string
newLabel, err = getPeerIPDNSLabel(ctx, transaction, peer.IP, accountID, update.Name)
if err != nil {
return err
}
newLabel, err := types.GetPeerHostLabel(update.Name, existingLabels)
if err != nil {
return err
return fmt.Errorf("failed to get free DNS label: %w", err)
}
peer.Name = update.Name
@@ -363,25 +360,20 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("failed to get peer groups: %w", err)
}
for _, group := range groups {
group.RemovePeer(peerID)
err = transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
if err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
return fmt.Errorf("failed to remove peer from groups: %w", err)
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
return err
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return err
@@ -463,233 +455,247 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
upperKey := strings.ToUpper(setupKey)
hashedKey := sha256.Sum256([]byte(upperKey))
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
var accountID string
var err error
addedByUser := false
if len(userID) > 0 {
addedByUser = true
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
} else {
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
}
if err != nil {
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() {
if unlock != nil {
unlock()
}
}()
addedByUser := len(userID) > 0
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key)
_, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
opEvent := &activity.Event{
Timestamp: time.Now().UTC(),
AccountID: accountID,
}
var newPeer *nbpeer.Peer
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var setupKeyID string
var setupKeyName string
var ephemeral bool
var groupsToAdd []string
var allowExtraDNSLabels bool
if addedByUser {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID)
if err != nil {
return fmt.Errorf("failed to get user groups: %w", err)
}
groupsToAdd = user.AutoGroups
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
} else {
// Validate the setup key
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
if err != nil {
return fmt.Errorf("failed to get setup key: %w", err)
}
if !sk.IsValid() {
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
opEvent.InitiatorID = sk.Id
opEvent.Activity = activity.PeerAddedWithSetupKey
groupsToAdd = sk.AutoGroups
ephemeral = sk.Ephemeral
setupKeyID = sk.Id
setupKeyName = sk.Name
allowExtraDNSLabels = sk.AllowExtraDNSLabels
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
}
}
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
if am.idpManager != nil {
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
var setupKeyID string
var setupKeyName string
var ephemeral bool
var groupsToAdd []string
var allowExtraDNSLabels bool
var accountID string
if addedByUser {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return fmt.Errorf("failed to get free DNS label: %w", err)
return nil, nil, nil, fmt.Errorf("failed to get user groups: %w", err)
}
freeIP, err := getFreeIP(ctx, transaction, accountID)
groupsToAdd = user.AutoGroups
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
accountID = user.AccountID
} else {
// Validate the setup key
sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey)
if err != nil {
return fmt.Errorf("failed to get free IP: %w", err)
return nil, nil, nil, fmt.Errorf("failed to get setup key: %w", err)
}
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
return status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
// we will check key twice for early return
if !sk.IsValid() {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
registrationTime := time.Now().UTC()
newPeer = &nbpeer.Peer{
ID: xid.New().String(),
AccountID: accountID,
Key: peer.Key,
IP: freeIP,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: freeLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: &registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
InactivityExpirationEnabled: addedByUser,
ExtraDNSLabels: peer.ExtraDNSLabels,
AllowExtraDNSLabels: allowExtraDNSLabels,
}
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
opEvent.InitiatorID = sk.Id
opEvent.Activity = activity.PeerAddedWithSetupKey
groupsToAdd = sk.AutoGroups
ephemeral = sk.Ephemeral
setupKeyID = sk.Id
setupKeyName = sk.Name
allowExtraDNSLabels = sk.AllowExtraDNSLabels
accountID = sk.AccountID
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
}
}
opEvent.AccountID = accountID
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
if am.idpManager != nil {
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if len(groupsToAdd) > 0 {
for _, g := range groupsToAdd {
err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g)
if err != nil {
return err
}
}
}
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
if err != nil {
return fmt.Errorf("failed to add peer to account: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
}
} else {
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
if err != nil {
return fmt.Errorf("failed to increment setup key usage: %w", err)
}
}
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID)
if err != nil {
return err
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil
})
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
}
registrationTime := time.Now().UTC()
newPeer = &nbpeer.Peer{
ID: xid.New().String(),
AccountID: accountID,
Key: peer.Key,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: &registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
InactivityExpirationEnabled: addedByUser,
ExtraDNSLabels: peer.ExtraDNSLabels,
AllowExtraDNSLabels: allowExtraDNSLabels,
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get account settings: %w", err)
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed getting network: %w", err)
}
maxAttempts := 10
for attempt := 1; attempt <= maxAttempts; attempt++ {
var freeIP net.IP
freeIP, err = types.AllocateRandomPeerIP(network.Net)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free IP: %w", err)
}
var freeLabel string
freeLabel, err = getPeerIPDNSLabel(ctx, am.Store, freeIP, accountID, peer.Meta.Hostname)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
}
newPeer.DNSLabel = freeLabel
newPeer.IP = freeIP
// unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
// defer func() {
// if unlock != nil {
// unlock()
// }
// }()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
if err != nil {
return err
}
if len(groupsToAdd) > 0 {
for _, g := range groupsToAdd {
err = transaction.AddPeerToGroup(ctx, newPeer.ID, g)
if err != nil {
return err
}
}
}
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
}
} else {
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
if err != nil {
return fmt.Errorf("failed to get setup key: %w", err)
}
// we validate at the end to not block the setup key for too long
if !sk.IsValid() {
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
if err != nil {
return fmt.Errorf("failed to increment setup key usage: %w", err)
}
}
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil
})
if err == nil {
// unlock()
// unlock = nil
break
}
if isUniqueConstraintError(err) {
// unlock()
// unlock = nil
log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err)
continue
}
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
}
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
}
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
}
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
unlock()
unlock = nil
if updateAccountPeers {
am.BufferUpdateAccountPeers(ctx, accountID)
}
am.BufferUpdateAccountPeers(ctx, accountID)
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
}
func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) {
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID)
func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID, peerHostName string) (string, error) {
ip = ip.To4()
dnsName, err := nbdns.GetParsedDomainLabel(peerHostName)
if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err)
}
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID)
_, err = tx.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, dnsName)
if err != nil {
return nil, fmt.Errorf("failed getting network: %w", err)
//nolint:nilerr
return dnsName, nil
}
nextIp, err := types.AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
}
return nextIp, nil
return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
@@ -1000,7 +1006,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
}()
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -1249,17 +1255,19 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
lock := mu.(*sync.Mutex)
if !lock.TryLock() {
return
}
go func() {
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
lock.Unlock()
am.UpdateAccountPeers(ctx, accountID)
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
lock := mu.(*sync.Mutex)
if !lock.TryLock() {
return
}
go func() {
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
lock.Unlock()
am.UpdateAccountPeers(ctx, accountID)
}()
}()
}
@@ -1477,19 +1485,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
return groupIDs, err
}
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) {
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
existingLabels := make(types.LookupMap)
for _, label := range dnsLabels {
existingLabels[label] = struct{}{}
}
return existingLabels, nil
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {

View File

@@ -20,14 +20,14 @@ type Peer struct {
// WireGuard public key
Key string `gorm:"index"`
// IP address of the Peer
IP net.IP `gorm:"serializer:json"`
IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations)
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
// Name is peer's name (machine name)
Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string
DNSLabel string // uniqueness index per accountID (check migrations)
// Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
// The user ID that registered the peer

View File

@@ -10,7 +10,9 @@ import (
"net/netip"
"os"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
@@ -19,6 +21,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -480,7 +483,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
accountID := "test_account"
adminUser := "account_creator"
someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
account.Users[someUser] = &types.User{
Id: someUser,
Role: types.UserRoleUser,
@@ -667,7 +670,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
accountID := "test_account"
adminUser := "account_creator"
someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
account.Users[someUser] = &types.User{
Id: someUser,
Role: testCase.role,
@@ -737,7 +740,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
adminUser := "account_creator"
regularUser := "regular_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
account.Users[regularUser] = &types.User{
Id: regularUser,
Role: types.UserRoleUser,
@@ -1267,7 +1270,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1342,7 +1345,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1391,7 +1394,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
name: "Absent setup key",
existingSetupKeyID: "AAAAAAAA-38F5-4553-B31E-DD66C696CEBB",
expectAddPeerError: true,
expectedErrorMsgSubstring: "failed adding new peer: account not found",
expectedErrorMsgSubstring: "failed to get setup key: setup key not found",
},
}
@@ -1456,6 +1459,10 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
}
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
engine := os.Getenv("NETBIRD_STORE_ENGINE")
if engine == "sqlite" || engine == "" {
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
}
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
@@ -1477,7 +1484,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1546,7 +1553,7 @@ func Test_LoginPeer(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1761,7 +1768,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg) //
close(done)
}()
@@ -2052,15 +2059,19 @@ func Test_DeletePeer(t *testing.T) {
// account with an admin and a regular user
accountID := "test_account"
adminUser := "account_creator"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
account.Peers = map[string]*nbpeer.Peer{
"peer1": {
ID: "peer1",
AccountID: accountID,
IP: net.IP{1, 1, 1, 1},
DNSLabel: "peer1.test",
},
"peer2": {
ID: "peer2",
AccountID: accountID,
IP: net.IP{2, 2, 2, 2},
DNSLabel: "peer2.test",
},
}
account.Groups = map[string]*types.Group{
@@ -2090,3 +2101,139 @@ func Test_DeletePeer(t *testing.T) {
assert.NotContains(t, group.Peers, "peer1")
}
func Test_IsUniqueConstraintError(t *testing.T) {
tests := []struct {
name string
engine types.Engine
}{
{
name: "PostgreSQL uniqueness error",
engine: types.PostgresStoreEngine,
},
{
name: "MySQL uniqueness error",
engine: types.MysqlStoreEngine,
},
{
name: "SQLite uniqueness error",
engine: types.SqliteStoreEngine,
},
}
peer := &nbpeer.Peer{
ID: "test-peer-id",
AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
DNSLabel: "test-peer-dns-label",
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(tt.engine))
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
assert.NoError(t, err)
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
result := isUniqueConstraintError(err)
assert.True(t, result)
})
}
}
func Test_AddPeer(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
t.Setenv("NB_GET_ACCOUNT_BUFFER_INTERVAL", "300ms")
t.Setenv("NB_PEER_UPDATE_BUFFER_INTERVAL", "300ms")
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatalf("error creating account: %v", err)
return
}
setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", types.SetupKeyReusable, time.Hour, nil, 10000, userID, false, false)
if err != nil {
t.Fatal("error creating setup key")
return
}
const totalPeers = 300
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
newPeer := &nbpeer.Peer{
AccountID: accountID,
Key: "key" + strconv.Itoa(i),
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(i), GoOS: "linux"},
}
<-start
_, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer)
if err != nil {
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
seenIP := make(map[string]bool)
for _, p := range account.Peers {
ipStr := p.IP.String()
if seenIP[ipStr] {
t.Fatalf("Duplicate IP found in account %s: %s", accountID, ipStr)
}
seenIP[ipStr] = true
}
seenLabel := make(map[string]bool)
for _, p := range account.Peers {
if seenLabel[p.DNSLabel] {
t.Fatalf("Duplicate Label found in account %s: %s", accountID, p.DNSLabel)
}
seenLabel[p.DNSLabel] = true
}
assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes)
assert.Equal(t, uint64(totalPeers), account.Network.Serial)
}

View File

@@ -106,7 +106,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
Role: types.UserRoleUser,
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
account.Users[admin.Id] = admin
account.Users[user.Id] = user

View File

@@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createRouterStore(t *testing.T) (store.Store, error) {
@@ -1305,7 +1305,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
accountID := "testingAcc"
domain := "example.com"
account := newAccountWithId(context.Background(), accountID, userID, domain)
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err

View File

@@ -156,7 +156,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
allGroup, err := account.GetGroupAll()
if err != nil {
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migratePreAuto from a version that didn't support groups. Error: %v", err)
// if the All group didn't exist we probably don't have routes to update
continue
}

View File

@@ -92,17 +92,20 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
}
if err := migrate(ctx, db); err != nil {
return nil, fmt.Errorf("migrate: %w", err)
if err := migratePreAuto(ctx, db); err != nil {
return nil, fmt.Errorf("migratePreAuto: %w", err)
}
err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.Network{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
)
if err != nil {
return nil, fmt.Errorf("auto migrate: %w", err)
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
}
if err := migratePostAuto(ctx, db); err != nil {
return nil, fmt.Errorf("migratePostAuto: %w", err)
}
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
@@ -183,6 +186,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
generateAccountSQLTypes(account)
for _, group := range account.GroupsG {
group.StoreGroupPeers()
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
@@ -244,7 +251,7 @@ func generateAccountSQLTypes(account *types.Account) {
for id, group := range account.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
account.GroupsG = append(account.GroupsG, group)
}
for id, route := range account.Routes {
@@ -452,19 +459,40 @@ func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength,
return nil
}
result := s.db.
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Create(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
for _, g := range groups {
g.StoreGroupPeers()
}
return nil
return s.db.Transaction(func(tx *gorm.DB) error {
result := tx.
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Create(&groups)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save groups to store")
}
for _, g := range groups {
if len(g.GroupPeers) == 0 {
if err := tx.Where("group_id = ?", g.ID).Delete(&types.GroupPeer{}).Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group peers for group %s: %s", g.ID, err)
return status.Errorf(status.Internal, "failed to delete group peers")
}
} else {
if err := tx.Model(&g).Association("GroupPeers").Replace(g.GroupPeers); err != nil {
return status.Errorf(status.Internal, "failed to save group peers: %s", err)
}
}
}
return nil
})
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -643,7 +671,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
}
var groups []*types.Group
result := tx.Find(&groups, accountIDCondition, accountID)
result := tx.Preload(clause.Associations).Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -652,6 +680,10 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
}
for _, g := range groups {
g.LoadGroupPeers()
}
return groups, nil
}
@@ -666,6 +698,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
likePattern := `%"ID":"` + resourceID + `"%`
result := tx.
Preload(clause.Associations).
Where("resources LIKE ?", likePattern).
Find(&groups)
@@ -676,6 +709,10 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
return nil, result.Error
}
for _, g := range groups {
g.LoadGroupPeers()
}
return groups, nil
}
@@ -735,8 +772,9 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
}()
var account types.Account
result := s.db.Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
result := s.db.Session(&gorm.Session{Logger: logger.Default.LogMode(logger.Info)}).Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload("GroupsG.GroupPeers"). // have to be specifies as this is nester reference
Preload(clause.Associations).
First(&account, idQueryCondition, accountID)
if result.Error != nil {
@@ -750,7 +788,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
for i, policy := range account.Policies {
var rules []*types.PolicyRule
err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
err := s.db.Session(&gorm.Session{Logger: logger.Default.LogMode(logger.Info)}).Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
if err != nil {
return nil, status.Errorf(status.NotFound, "rule not found")
}
@@ -781,6 +819,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for _, group := range account.GroupsG {
group.LoadGroupPeers()
account.Groups[group.ID] = group.Copy()
}
account.GroupsG = nil
@@ -967,7 +1006,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
return ips, nil
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string, dnsLabel string) ([]string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
@@ -975,7 +1014,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
var labels []string
result := tx.Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Where("account_id = ? AND dns_label LIKE ?", accountID, dnsLabel+"%").
Pluck("dns_label", &labels)
if result.Error != nil {
@@ -995,14 +1034,14 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountNetwork types.AccountNetwork
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
accountNetwork := types.Network{}
if err := tx.Where(accountIDCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
}
return accountNetwork.Network, nil
return &accountNetwork, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
@@ -1184,7 +1223,7 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data
for _, account := range fileStore.GetAllAccounts(ctx) {
_, err = account.GetGroupAll()
if err != nil {
if err := account.AddAllGroup(); err != nil {
if err := account.AddAllGroup(false); err != nil {
return nil, err
}
}
@@ -1254,7 +1293,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewSetupKeyNotFoundError(key)
return nil, status.Errorf(status.PreconditionFailed, "setup key not found")
}
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
@@ -1282,55 +1321,74 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
}
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
var group types.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&group, "account_id = ? AND name = ?", accountID, "All")
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var groupID string
_ = s.db.Model(types.Group{}).
Select("id").
Where("account_id = ? AND name = ?", accountID, "All").
Limit(1).
Scan(&groupID)
if groupID == "" {
return status.Errorf(status.NotFound, "group 'All' not found for account %s", accountID)
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
err := s.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true,
}).Create(&types.GroupPeer{
GroupID: groupID,
PeerID: peerID,
}).Error
group.Peers = append(group.Peers, peerID)
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
if err != nil {
return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err)
}
return nil
}
// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error {
var group types.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID).
First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewGroupNotFoundError(groupID)
}
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
// AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, peerID string, groupID string) error {
peer := &types.GroupPeer{
GroupID: groupID,
PeerID: peerID,
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true,
}).Create(peer).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to add peer %s to group %s: %v", peerID, groupID, err)
return status.Errorf(status.Internal, "failed to add peer to group")
}
group.Peers = append(group.Peers, peerId)
return nil
}
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group: %s", err)
// RemovePeerFromGroup removes a peer from a group
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
err := s.db.WithContext(ctx).
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err)
return status.Errorf(status.Internal, "failed to remove peer from group")
}
return nil
}
// RemovePeerFromAllGroups removes a peer from all groups
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
err := s.db.WithContext(ctx).
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err)
return status.Errorf(status.Internal, "failed to remove peer from all groups")
}
return nil
@@ -1398,12 +1456,19 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
var groups []*types.Group
query := tx.
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
Joins("JOIN group_peers ON group_peers.group_id = groups.id").
Where("group_peers.peer_id = ?", peerId).
Preload(clause.Associations).
Find(&groups)
if query.Error != nil {
return nil, query.Error
}
for _, group := range groups {
group.LoadGroupPeers()
}
return groups, nil
}
@@ -1452,7 +1517,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error {
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil {
if err := s.db.Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -1580,7 +1645,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength,
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
Model(&types.Network{}).Where(accountIDCondition, accountId).Update("serial", gorm.Expr("serial + 1"))
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -1689,7 +1754,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
}
var group *types.Group
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
result := tx.Preload(clause.Associations).First(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewGroupNotFoundError(groupID)
@@ -1698,15 +1763,14 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
return nil, status.Errorf(status.Internal, "failed to get group from store")
}
group.LoadGroupPeers()
return group, nil
}
// GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var group types.Group
@@ -1714,16 +1778,14 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// we may need to reconsider changing the types.
query := tx.Preload(clause.Associations)
switch s.storeEngine {
case types.PostgresStoreEngine:
query = query.Order("json_array_length(peers::json) DESC")
case types.MysqlStoreEngine:
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
default:
query = query.Order("json_array_length(peers) DESC")
}
result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
result := query.
Model(&types.Group{}).
Joins("LEFT JOIN group_peers ON group_peers.group_id = groups.id").
Where("groups.account_id = ? AND groups.name = ?", accountID, groupName).
Group("groups.id").
Order("COUNT(group_peers.peer_id) DESC").
Limit(1).
First(&group)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewGroupNotFoundError(groupName)
@@ -1731,6 +1793,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
}
group.LoadGroupPeers()
return &group, nil
}
@@ -1742,7 +1807,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
}
var groups []*types.Group
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
result := tx.Preload(clause.Associations).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
@@ -1750,6 +1815,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
groupsMap := make(map[string]*types.Group)
for _, group := range groups {
group.LoadGroupPeers()
groupsMap[group.ID] = group
}
@@ -1758,17 +1824,36 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
if group == nil {
return status.Errorf(status.InvalidArgument, "group is nil")
}
group = group.Copy()
group.StoreGroupPeers()
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
return status.Errorf(status.Internal, "failed to save group to store")
}
if len(group.GroupPeers) == 0 {
if err := s.db.Where("group_id = ?", group.ID).Delete(&types.GroupPeer{}).Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group peers for group %s: %s", group.ID, err)
return status.Errorf(status.Internal, "failed to delete group peers")
}
} else {
if err := s.db.Model(&group).Association("GroupPeers").Replace(group.GroupPeers); err != nil {
return status.Errorf(status.Internal, "failed to save group peers: %s", err)
}
}
return nil
}
// DeleteGroup deletes a group from the database.
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Select(clause.Associations).
Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
@@ -1785,6 +1870,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength
// DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
Select(clause.Associations).
Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
@@ -2546,6 +2632,27 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
return &peer, nil
}
func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) {
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peerID string
result := tx.Model(&nbpeer.Peer{}).
Select("id").
// Where(" = ?", hostname).
Where("account_id = ? AND dns_label = ?", accountID, hostname).
Limit(1).
Scan(&peerID)
if peerID == "" {
return "", gorm.ErrRecordNotFound
}
return peerID, result.Error
}
func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) {
var count int64
result := s.db.Model(&types.Account{}).

View File

@@ -10,6 +10,7 @@ import (
"net/netip"
"os"
"runtime"
"sort"
"sync"
"testing"
"time"
@@ -630,7 +631,7 @@ func TestMigrate(t *testing.T) {
t.Cleanup(cleanUp)
assert.NoError(t, err)
err = migrate(context.Background(), store.(*SqlStore).db)
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on empty db")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
@@ -685,10 +686,10 @@ func TestMigrate(t *testing.T) {
err = store.(*SqlStore).db.Save(rt).Error
require.NoError(t, err, "Failed to insert Gob data")
err = migrate(context.Background(), store.(*SqlStore).db)
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on gob populated db")
err = migrate(context.Background(), store.(*SqlStore).db)
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on migrated db")
err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error
@@ -704,10 +705,10 @@ func TestMigrate(t *testing.T) {
err = store.(*SqlStore).db.Save(nRT).Error
require.NoError(t, err, "Failed to insert json nil slice data")
err = migrate(context.Background(), store.(*SqlStore).db)
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on json nil slice populated db")
err = migrate(context.Background(), store.(*SqlStore).db)
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on migrated db")
}
@@ -950,6 +951,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
DNSLabel: "peer1",
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
@@ -961,8 +963,9 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
assert.Equal(t, []net.IP{ip1}, takenIPs)
peer2 := &nbpeer.Peer{
ID: "peer2",
ID: "peer1second",
AccountID: existingAccountID,
DNSLabel: "peer1-1",
IP: net.IP{2, 2, 2, 2},
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
@@ -972,49 +975,100 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
require.NoError(t, err)
ip2 := net.IP{2, 2, 2, 2}.To16()
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
}
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
if err != nil {
return
}
t.Cleanup(cleanup)
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
peerHostname := "peer1"
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
DNSLabel: "peer1",
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
require.NoError(t, err)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
DNSLabel: "peer1.domain.test",
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
require.NoError(t, err)
assert.Equal(t, []string{"peer1"}, labels)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test"}, labels)
peer2 := &nbpeer.Peer{
ID: "peer1second",
AccountID: existingAccountID,
DNSLabel: "peer1-1",
IP: net.IP{2, 2, 2, 2},
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
require.NoError(t, err)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
DNSLabel: "peer2.domain.test",
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
expected := []string{"peer1", "peer1-1"}
sort.Strings(expected)
sort.Strings(labels)
assert.Equal(t, expected, labels)
})
}
func Test_AddPeerWithSameDnsLabel(t *testing.T) {
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
DNSLabel: "peer1.domain.test",
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
require.NoError(t, err)
peer2 := &nbpeer.Peer{
ID: "peer1second",
AccountID: existingAccountID,
DNSLabel: "peer1.domain.test",
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
require.Error(t, err)
})
}
func Test_AddPeerWithSameIP(t *testing.T) {
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
require.NoError(t, err)
peer2 := &nbpeer.Peer{
ID: "peer1second",
AccountID: existingAccountID,
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
require.Error(t, err)
})
}
func TestSqlite_GetAccountNetwork(t *testing.T) {
@@ -1286,10 +1340,12 @@ func TestSqlStore_SaveGroup(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group := &types.Group{
ID: "group-id",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
ID: "group-id",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
}
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
require.NoError(t, err)
@@ -1308,16 +1364,19 @@ func TestSqlStore_SaveGroups(t *testing.T) {
groups := []*types.Group{
{
ID: "group-1",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
ID: "group-1",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
},
{
ID: "group-2",
AccountID: accountID,
Issued: "integration",
Peers: []string{"peer3", "peer4"},
Resources: []types.Resource{},
},
}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
@@ -2005,7 +2064,7 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
log.WithContext(ctx).Debugf("creating new account")
network := types.NewNetwork()
network := types.NewNetwork(accountID)
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[nbroute.ID]*nbroute.Route)
@@ -2044,7 +2103,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
},
}
if err := acc.AddAllGroup(); err != nil {
if err := acc.AddAllGroup(false); err != nil {
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
}
return acc
@@ -2452,7 +2511,7 @@ func TestSqlStore_AddPeerToGroup(t *testing.T) {
require.NoError(t, err, "failed to get group")
require.Len(t, group.Peers, 0, "group should have 0 peers")
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID)
err = store.AddPeerToGroup(context.Background(), peerID, groupID)
require.NoError(t, err, "failed to add peer to group")
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
@@ -2483,7 +2542,7 @@ func TestSqlStore_AddPeerToAllGroup(t *testing.T) {
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer)
require.NoError(t, err, "failed to add peer to account")
err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID)
err = store.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
require.NoError(t, err, "failed to add peer to all group")
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
@@ -2569,7 +2628,7 @@ func TestSqlStore_GetPeerGroups(t *testing.T) {
assert.Len(t, groups, 1)
assert.Equal(t, groups[0].Name, "All")
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, "cfefqs706sqkneg59g4h")
err = store.AddPeerToGroup(context.Background(), peerID, "cfefqs706sqkneg59g4h")
require.NoError(t, err)
groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID)

View File

@@ -117,9 +117,11 @@ type Store interface {
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, peerId string, groupID string) error
RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error
RemovePeerFromAllGroups(ctx context.Context, peerID string) error
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
@@ -193,6 +195,7 @@ type Store interface {
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
}
const (
@@ -234,9 +237,9 @@ func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) type
if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) {
log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile)
// Attempt to migrate from JSON store to SQLite
// Attempt to migratePreAuto from JSON store to SQLite
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err)
log.WithContext(ctx).Errorf("failed to migratePreAuto filestore to SQLite: %v", err)
kind = types.FileStoreEngine
}
}
@@ -280,9 +283,9 @@ func checkFileStoreEngine(kind types.Engine, dataDir string) error {
return nil
}
// migrate migrates the SQLite database to the latest schema
func migrate(ctx context.Context, db *gorm.DB) error {
migrations := getMigrations(ctx)
// migratePreAuto migrates the SQLite database to the latest schema
func migratePreAuto(ctx context.Context, db *gorm.DB) error {
migrations := getMigrationsPreAuto(ctx)
for _, m := range migrations {
if err := m(db); err != nil {
@@ -293,7 +296,7 @@ func migrate(ctx context.Context, db *gorm.DB) error {
return nil
}
func getMigrations(ctx context.Context) []migrationFunc {
func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
return []migrationFunc{
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net")
@@ -329,6 +332,47 @@ func getMigrations(ctx context.Context) []migrationFunc {
return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id")
},
}
} // migratePostAuto migrates the SQLite database to the latest schema
func migratePostAuto(ctx context.Context, db *gorm.DB) error {
migrations := getMigrationsPostAuto(ctx)
for _, m := range migrations {
if err := m(db); err != nil {
return err
}
}
return nil
}
func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
return []migrationFunc{
func(db *gorm.DB) error {
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_ip", "account_id", "ip")
},
func(db *gorm.DB) error {
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label")
},
func(db *gorm.DB) error {
return migration.MigrateJsonToTable[types.Group](ctx, db, "peers", func(id, value string) any {
return &types.GroupPeer{
GroupID: id,
PeerID: value,
}
})
},
func(db *gorm.DB) error {
return migration.MigrateEmbeddedToTable[types.Account, migration.LegacyAccountNetwork, types.Network](ctx, db, "id", func(obj migration.LegacyAccountNetwork) *types.Network {
return &types.Network{
AccountID: obj.AccountID,
Identifier: obj.Identifier,
Net: obj.Net,
Serial: obj.Serial,
Dns: obj.Dns,
}
})
},
}
}
// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env.
@@ -391,7 +435,7 @@ func addAllGroupToAccount(ctx context.Context, store Store) error {
_, err := account.GetGroupAll()
if err != nil {
if err := account.AddAllGroup(); err != nil {
if err := account.AddAllGroup(false); err != nil {
return err
}
shouldSave = true
@@ -577,7 +621,7 @@ func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error {
sqliteStoreAccounts := len(store.GetAllAccounts(ctx))
if fsStoreAccounts != sqliteStoreAccounts {
return fmt.Errorf("failed to migrate accounts from file to sqlite. Expected accounts: %d, got: %d",
return fmt.Errorf("failed to migratePreAuto accounts from file to sqlite. Expected accounts: %d, got: %d",
fsStoreAccounts, sqliteStoreAccounts)
}

View File

@@ -52,4 +52,4 @@ INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','D
INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0);
INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1');
INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network');
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','192.168.0.0','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);

View File

@@ -30,7 +30,7 @@ INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO installations VALUES(1,'');

View File

@@ -67,13 +67,13 @@ type Account struct {
IsDomainPrimaryAccount bool
SetupKeys map[string]*SetupKey `gorm:"-"`
SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"`
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
Network *Network `json:"-" gorm:"foreignKey:AccountID;references:id"`
Peers map[string]*nbpeer.Peer `gorm:"-"`
PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
Users map[string]*User `gorm:"-"`
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
Groups map[string]*Group `gorm:"-"`
GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
Routes map[route.ID]*route.Route `gorm:"-"`
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
@@ -1546,7 +1546,7 @@ func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[st
}
// AddAllGroup to account object if it doesn't exist
func (a *Account) AddAllGroup() error {
func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
if len(a.Groups) == 0 {
allGroup := &Group{
ID: xid.New().String(),
@@ -1558,6 +1558,10 @@ func (a *Account) AddAllGroup() error {
}
a.Groups = map[string]*Group{allGroup.ID: allGroup}
if disableDefaultPolicy {
return nil
}
id := xid.New().String()
defaultPolicy := &Policy{

View File

@@ -53,6 +53,9 @@ type Config struct {
StoreConfig StoreConfig
ReverseProxy ReverseProxy
// disable default all-to-all policy
DisableDefaultPolicy bool
}
// GetAuthAudiences returns the audience from the http config and device authorization flow config

View File

@@ -26,7 +26,8 @@ type Group struct {
Issued string
// Peers list of the group
Peers []string `gorm:"serializer:json"`
Peers []string `gorm:"-"`
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
// Resources contains a list of resources in that group
Resources []Resource `gorm:"serializer:json"`
@@ -34,6 +35,29 @@ type Group struct {
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
}
type GroupPeer struct {
GroupID string `gorm:"primaryKey"`
PeerID string `gorm:"primaryKey"`
}
func (g *Group) LoadGroupPeers() {
g.Peers = make([]string, len(g.GroupPeers))
for i, peer := range g.GroupPeers {
g.Peers[i] = peer.PeerID
}
g.GroupPeers = []GroupPeer{}
}
func (g *Group) StoreGroupPeers() {
g.GroupPeers = make([]GroupPeer, len(g.Peers))
for i, peer := range g.Peers {
g.GroupPeers[i] = GroupPeer{
GroupID: g.ID,
PeerID: peer,
}
}
g.Peers = []string{}
}
// EventMeta returns activity event meta related to the group
func (g *Group) EventMeta() map[string]any {
return map[string]any{"name": g.Name}
@@ -46,13 +70,16 @@ func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]an
func (g *Group) Copy() *Group {
group := &Group{
ID: g.ID,
AccountID: g.AccountID,
Name: g.Name,
Issued: g.Issued,
Peers: make([]string, len(g.Peers)),
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
Resources: make([]Resource, len(g.Resources)),
IntegrationReference: g.IntegrationReference,
}
copy(group.Peers, g.Peers)
copy(group.GroupPeers, g.GroupPeers)
copy(group.Resources, g.Resources)
return group
}

View File

@@ -1,6 +1,7 @@
package types
import (
"encoding/binary"
"math/rand"
"net"
"sync"
@@ -106,7 +107,8 @@ func ipToBytes(ip net.IP) []byte {
}
type Network struct {
Identifier string `json:"id"`
AccountID string `gorm:"primaryKey"`
Identifier string `gorm:"index"`
Net net.IPNet `gorm:"serializer:json"`
Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
@@ -116,9 +118,13 @@ type Network struct {
Mu sync.Mutex `json:"-" gorm:"-"`
}
func (*Network) TableName() string {
return "account_networks"
}
// NewNetwork creates a new Network initializing it with a Serial=0
// It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets)
func NewNetwork() *Network {
func NewNetwork(accountID string) *Network {
n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize)
sub, _ := n.Subnet(SubnetSize)
@@ -128,6 +134,7 @@ func NewNetwork() *Network {
intn := r.Intn(len(sub))
return &Network{
AccountID: accountID,
Identifier: xid.New().String(),
Net: sub[intn].IPNet,
Dns: "",
@@ -150,6 +157,7 @@ func (n *Network) CurrentSerial() uint64 {
func (n *Network) Copy() *Network {
return &Network{
AccountID: n.AccountID,
Identifier: n.Identifier,
Net: n.Net,
Dns: n.Dns,
@@ -161,24 +169,65 @@ func (n *Network) Copy() *Network {
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
takenIPMap := make(map[string]struct{})
takenIPMap[ipNet.IP.String()] = struct{}{}
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
totalIPs := uint32(1 << SubnetSize)
taken := make(map[uint32]struct{}, len(takenIps)+1)
taken[baseIP] = struct{}{} // reserve network IP
taken[baseIP+totalIPs-1] = struct{}{} // reserve broadcast IP
for _, ip := range takenIps {
takenIPMap[ip.String()] = struct{}{}
taken[ipToUint32(ip)] = struct{}{}
}
ips, _ := generateIPs(&ipNet, takenIPMap)
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
maxAttempts := (int(totalIPs) - len(taken)) / 100
if len(ips) == 0 {
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
for i := 0; i < maxAttempts; i++ {
offset := uint32(rng.Intn(int(totalIPs-2))) + 1
candidate := baseIP + offset
if _, exists := taken[candidate]; !exists {
return uint32ToIP(candidate), nil
}
}
// pick a random IP
s := rand.NewSource(time.Now().Unix())
r := rand.New(s)
intn := r.Intn(len(ips))
for offset := uint32(1); offset < totalIPs-1; offset++ {
candidate := baseIP + offset
if _, exists := taken[candidate]; !exists {
return uint32ToIP(candidate), nil
}
}
return ips[intn], nil
return nil, status.Errorf(status.PreconditionFailed, "network %s is out of IPs", ipNet.String())
}
func AllocateRandomPeerIP(ipNet net.IPNet) (net.IP, error) {
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
ones, bits := ipNet.Mask.Size()
hostBits := bits - ones
totalIPs := uint32(1 << hostBits)
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
offset := uint32(rng.Intn(int(totalIPs-2))) + 1
candidate := baseIP + offset
return uint32ToIP(candidate), nil
}
func ipToUint32(ip net.IP) uint32 {
ip = ip.To4()
if len(ip) < 4 {
return 0
}
return binary.BigEndian.Uint32(ip)
}
func uint32ToIP(n uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, n)
return ip
}
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list

View File

@@ -8,7 +8,7 @@ import (
)
func TestNewNetwork(t *testing.T) {
network := NewNetwork()
network := NewNetwork("accountID")
// generated net should be a subnet of a larger 100.64.0.0/10 net
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 192, 0, 0}}

View File

@@ -35,7 +35,7 @@ type SetupKey struct {
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
Key string
KeySecret string
KeySecret string `gorm:"index"`
Name string
Type SetupKeyType
CreatedAt time.Time

View File

@@ -56,7 +56,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
err = s.SaveAccount(context.Background(), account)
if err != nil {
@@ -103,7 +103,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account.Users[mockTargetUserId] = &types.User{
Id: mockTargetUserId,
IsServiceUser: false,
@@ -131,7 +131,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account.Users[mockTargetUserId] = &types.User{
Id: mockTargetUserId,
IsServiceUser: true,
@@ -163,7 +163,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -188,7 +188,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -213,7 +213,7 @@ func TestUser_DeletePAT(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account.Users[mockUserID] = &types.User{
Id: mockUserID,
PATs: map[string]*types.PersonalAccessToken{
@@ -256,7 +256,7 @@ func TestUser_GetPAT(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account.Users[mockUserID] = &types.User{
Id: mockUserID,
AccountID: mockAccountID,
@@ -296,7 +296,7 @@ func TestUser_GetAllPATs(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account.Users[mockUserID] = &types.User{
Id: mockUserID,
AccountID: mockAccountID,
@@ -406,7 +406,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -453,7 +453,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -501,7 +501,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -532,7 +532,7 @@ func TestUser_InviteNewUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -639,7 +639,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account.Users[mockServiceUserID] = tt.serviceUser
err = store.SaveAccount(context.Background(), account)
@@ -678,7 +678,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -705,7 +705,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
targetId := "user2"
account.Users[targetId] = &types.User{
@@ -792,7 +792,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
targetId := "user2"
account.Users[targetId] = &types.User{
@@ -952,7 +952,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -988,7 +988,7 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account.Users["normal_user1"] = types.NewRegularUser("normal_user1")
account.Users["normal_user2"] = types.NewRegularUser("normal_user2")
@@ -1030,7 +1030,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
externalUser := &types.User{
Id: "externalUser",
Role: types.UserRoleUser,
@@ -1098,7 +1098,7 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account.Users[mockServiceUserID] = &types.User{
Id: mockServiceUserID,
Role: "user",
@@ -1132,7 +1132,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account.Users[mockServiceUserID] = &types.User{
Id: mockServiceUserID,
Role: "user",
@@ -1499,7 +1499,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
}
t.Cleanup(cleanup)
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "")
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "", false)
targetId := "user2"
account1.Users[targetId] = &types.User{
Id: targetId,
@@ -1508,7 +1508,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
}
require.NoError(t, s.SaveAccount(context.Background(), account1))
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "")
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "", false)
require.NoError(t, s.SaveAccount(context.Background(), account2))
permissionsManager := permissions.NewManager(s)
@@ -1535,7 +1535,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
}
t.Cleanup(cleanup)
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "")
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "", false)
account1.Settings.RegularUsersViewBlocked = false
account1.Users["blocked-user"] = &types.User{
Id: "blocked-user",
@@ -1557,7 +1557,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
}
require.NoError(t, store.SaveAccount(context.Background(), account1))
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "")
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "", false)
account2.Users["settings-blocked-user"] = &types.User{
Id: "settings-blocked-user",
Role: types.UserRoleUser,

View File

@@ -130,7 +130,7 @@ repo_gpgcheck=1
EOF
}
add_aur_repo() {
install_aur_package() {
INSTALL_PKGS="git base-devel go"
REMOVE_PKGS=""
@@ -154,8 +154,10 @@ add_aur_repo() {
cd netbird-ui && makepkg -sri --noconfirm
fi
# Clean up the installed packages
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
if [ -n "$REMOVE_PKGS" ]; then
# Clean up the installed packages
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
fi
}
prepare_tun_module() {
@@ -277,7 +279,9 @@ install_netbird() {
;;
pacman)
${SUDO} pacman -Syy
add_aur_repo
install_aur_package
# in-line with the docs at https://wiki.archlinux.org/title/Netbird
${SUDO} systemctl enable --now netbird@main.service
;;
pkg)
# Check if the package is already installed
@@ -494,4 +498,4 @@ case "$UPDATE_FLAG" in
;;
*)
install_netbird
esac
esac