mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
Compare commits
47 Commits
feature/li
...
feature/ne
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
261d1e094a | ||
|
|
bcab5cbbee | ||
|
|
3931958499 | ||
|
|
914e58ac75 | ||
|
|
7baeea3d9d | ||
|
|
64e618f1ad | ||
|
|
5802dcaf80 | ||
|
|
b329397d06 | ||
|
|
76b180a741 | ||
|
|
8db9065ed9 | ||
|
|
1f79fc0728 | ||
|
|
bdd0b1cf02 | ||
|
|
e6ac248aee | ||
|
|
376394f7f9 | ||
|
|
542dbdb41c | ||
|
|
982b9604ee | ||
|
|
f2990e2fbc | ||
|
|
dfb47d5545 | ||
|
|
8e0b8f20a2 | ||
|
|
8a42528664 | ||
|
|
a8cba921e1 | ||
|
|
fee36b0663 | ||
|
|
dfad334780 | ||
|
|
d25da87957 | ||
|
|
13213d954d | ||
|
|
6fb61c7cf5 | ||
|
|
459db2ba4f | ||
|
|
e78b7dd058 | ||
|
|
7132642e4c | ||
|
|
22a944b157 | ||
|
|
005937ae77 | ||
|
|
5fab2d019a | ||
|
|
36155f8de1 | ||
|
|
d06831dd2f | ||
|
|
e23282b92c | ||
|
|
57961afe95 | ||
|
|
22678bce7f | ||
|
|
6c633497bc | ||
|
|
6922826919 | ||
|
|
56a1a75e3f | ||
|
|
d9402168ad | ||
|
|
dbdef04b9e | ||
|
|
29cbfe8467 | ||
|
|
6ce8643368 | ||
|
|
07d1ad35fc | ||
|
|
ef6cd36f1a | ||
|
|
c1c71b6d39 |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.19"
|
||||
SIGN_PIPE_VER: "v0.0.20"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
@@ -134,6 +134,7 @@ jobs:
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
|
||||
run: |
|
||||
set -x
|
||||
@@ -180,6 +181,7 @@ jobs:
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
grep LoginFlag management.json | grep 0
|
||||
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -14,6 +14,9 @@
|
||||
<br>
|
||||
<a href="https://docs.netbird.io/slack-url">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<a href="https://forum.netbird.io">
|
||||
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
@@ -29,13 +32,13 @@
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<a href="https://github.com/netbirdio/kubernetes-operator">
|
||||
New: NetBird Kubernetes Operator
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
</a>
|
||||
</p>
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
@@ -103,7 +103,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -319,10 +319,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
*input.WireguardPort, config.WgPort)
|
||||
config.WgPort = *input.WireguardPort
|
||||
updated = true
|
||||
} else if config.WgPort == 0 {
|
||||
config.WgPort = iface.DefaultWgPort
|
||||
log.Infof("using default Wireguard port %d", config.WgPort)
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
@@ -526,17 +525,13 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
|
||||
|
||||
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
|
||||
func freePort(initPort int) (int, error) {
|
||||
addr := net.UDPAddr{}
|
||||
if initPort == 0 {
|
||||
initPort = iface.DefaultWgPort
|
||||
}
|
||||
|
||||
addr.Port = initPort
|
||||
addr := net.UDPAddr{Port: initPort}
|
||||
|
||||
conn, err := net.ListenUDP("udp", &addr)
|
||||
if err == nil {
|
||||
returnPort := conn.LocalAddr().(*net.UDPAddr).Port
|
||||
closeConnWithLog(conn)
|
||||
return initPort, nil
|
||||
return returnPort, nil
|
||||
}
|
||||
|
||||
// if the port is already in use, ask the system for a free port
|
||||
|
||||
@@ -13,10 +13,10 @@ func Test_freePort(t *testing.T) {
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "not provided, fallback to default",
|
||||
name: "when port is 0 use random port",
|
||||
port: 0,
|
||||
want: 51820,
|
||||
shouldMatch: true,
|
||||
want: 0,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "provided and available",
|
||||
@@ -31,7 +31,7 @@ func Test_freePort(t *testing.T) {
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
|
||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||
if err != nil {
|
||||
t.Errorf("freePort error = %v", err)
|
||||
}
|
||||
@@ -39,6 +39,14 @@ func Test_freePort(t *testing.T) {
|
||||
_ = c1.Close()
|
||||
}(c1)
|
||||
|
||||
if tests[1].port == c1.LocalAddr().(*net.UDPAddr).Port {
|
||||
tests[1].port++
|
||||
tests[1].want++
|
||||
}
|
||||
|
||||
tests[2].port = c1.LocalAddr().(*net.UDPAddr).Port
|
||||
tests[2].want = c1.LocalAddr().(*net.UDPAddr).Port
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
@@ -1476,7 +1476,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -28,7 +28,10 @@ func (n Nexthop) String() string {
|
||||
if n.Intf == nil {
|
||||
return n.IP.String()
|
||||
}
|
||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||
if n.IP.IsValid() {
|
||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||
}
|
||||
return fmt.Sprintf("no-ip @ %d (%s)", n.Intf.Index, n.Intf.Name)
|
||||
}
|
||||
|
||||
type wgIface interface {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -158,6 +158,7 @@ message UpResponse {}
|
||||
|
||||
message StatusRequest{
|
||||
bool getFullPeerStatus = 1;
|
||||
bool shouldRunProbes = 2;
|
||||
}
|
||||
|
||||
message StatusResponse{
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.5.1
|
||||
// - protoc v5.29.3
|
||||
// source: daemon.proto
|
||||
|
||||
package proto
|
||||
|
||||
@@ -11,8 +15,31 @@ import (
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
// Requires gRPC-Go v1.64.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
DaemonService_Login_FullMethodName = "/daemon.DaemonService/Login"
|
||||
DaemonService_WaitSSOLogin_FullMethodName = "/daemon.DaemonService/WaitSSOLogin"
|
||||
DaemonService_Up_FullMethodName = "/daemon.DaemonService/Up"
|
||||
DaemonService_Status_FullMethodName = "/daemon.DaemonService/Status"
|
||||
DaemonService_Down_FullMethodName = "/daemon.DaemonService/Down"
|
||||
DaemonService_GetConfig_FullMethodName = "/daemon.DaemonService/GetConfig"
|
||||
DaemonService_ListNetworks_FullMethodName = "/daemon.DaemonService/ListNetworks"
|
||||
DaemonService_SelectNetworks_FullMethodName = "/daemon.DaemonService/SelectNetworks"
|
||||
DaemonService_DeselectNetworks_FullMethodName = "/daemon.DaemonService/DeselectNetworks"
|
||||
DaemonService_ForwardingRules_FullMethodName = "/daemon.DaemonService/ForwardingRules"
|
||||
DaemonService_DebugBundle_FullMethodName = "/daemon.DaemonService/DebugBundle"
|
||||
DaemonService_GetLogLevel_FullMethodName = "/daemon.DaemonService/GetLogLevel"
|
||||
DaemonService_SetLogLevel_FullMethodName = "/daemon.DaemonService/SetLogLevel"
|
||||
DaemonService_ListStates_FullMethodName = "/daemon.DaemonService/ListStates"
|
||||
DaemonService_CleanState_FullMethodName = "/daemon.DaemonService/CleanState"
|
||||
DaemonService_DeleteState_FullMethodName = "/daemon.DaemonService/DeleteState"
|
||||
DaemonService_SetNetworkMapPersistence_FullMethodName = "/daemon.DaemonService/SetNetworkMapPersistence"
|
||||
DaemonService_TracePacket_FullMethodName = "/daemon.DaemonService/TracePacket"
|
||||
DaemonService_SubscribeEvents_FullMethodName = "/daemon.DaemonService/SubscribeEvents"
|
||||
DaemonService_GetEvents_FullMethodName = "/daemon.DaemonService/GetEvents"
|
||||
)
|
||||
|
||||
// DaemonServiceClient is the client API for DaemonService service.
|
||||
//
|
||||
@@ -53,7 +80,7 @@ type DaemonServiceClient interface {
|
||||
// SetNetworkMapPersistence enables or disables network map persistence
|
||||
SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
|
||||
TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error)
|
||||
SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error)
|
||||
SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error)
|
||||
GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error)
|
||||
}
|
||||
|
||||
@@ -66,8 +93,9 @@ func NewDaemonServiceClient(cc grpc.ClientConnInterface) DaemonServiceClient {
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(LoginResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/Login", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_Login_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -75,8 +103,9 @@ func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLoginRequest, opts ...grpc.CallOption) (*WaitSSOLoginResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(WaitSSOLoginResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitSSOLogin", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_WaitSSOLogin_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -84,8 +113,9 @@ func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLogin
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(UpResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/Up", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_Up_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -93,8 +123,9 @@ func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grp
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(StatusResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/Status", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_Status_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -102,8 +133,9 @@ func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opt
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(DownResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/Down", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_Down_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -111,8 +143,9 @@ func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts ..
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(GetConfigResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetConfig", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_GetConfig_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -120,8 +153,9 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(ListNetworksResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListNetworks", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_ListNetworks_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -129,8 +163,9 @@ func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworks
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(SelectNetworksResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectNetworks", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_SelectNetworks_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -138,8 +173,9 @@ func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetw
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(SelectNetworksResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectNetworks", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_DeselectNetworks_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -147,8 +183,9 @@ func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNe
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*ForwardingRulesResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(ForwardingRulesResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ForwardingRules", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_ForwardingRules_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -156,8 +193,9 @@ func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequ
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(DebugBundleResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_DebugBundle_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -165,8 +203,9 @@ func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRe
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(GetLogLevelResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetLogLevel", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_GetLogLevel_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -174,8 +213,9 @@ func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRe
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(SetLogLevelResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_SetLogLevel_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -183,8 +223,9 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(ListStatesResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_ListStates_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -192,8 +233,9 @@ func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequ
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(CleanStateResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_CleanState_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -201,8 +243,9 @@ func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequ
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(DeleteStateResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_DeleteState_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -210,8 +253,9 @@ func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRe
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(SetNetworkMapPersistenceResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_SetNetworkMapPersistence_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -219,20 +263,22 @@ func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(TracePacketResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/TracePacket", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_TracePacket_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], "/daemon.DaemonService/SubscribeEvents", opts...)
|
||||
func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], DaemonService_SubscribeEvents_FullMethodName, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &daemonServiceSubscribeEventsClient{stream}
|
||||
x := &grpc.GenericClientStream[SubscribeRequest, SystemEvent]{ClientStream: stream}
|
||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -242,26 +288,13 @@ func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *Subscribe
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type DaemonService_SubscribeEventsClient interface {
|
||||
Recv() (*SystemEvent, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type daemonServiceSubscribeEventsClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *daemonServiceSubscribeEventsClient) Recv() (*SystemEvent, error) {
|
||||
m := new(SystemEvent)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type DaemonService_SubscribeEventsClient = grpc.ServerStreamingClient[SystemEvent]
|
||||
|
||||
func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(GetEventsResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetEvents", in, out, opts...)
|
||||
err := c.cc.Invoke(ctx, DaemonService_GetEvents_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -270,7 +303,7 @@ func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsReques
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility
|
||||
// for forward compatibility.
|
||||
type DaemonServiceServer interface {
|
||||
// Login uses setup key to prepare configuration for the daemon.
|
||||
Login(context.Context, *LoginRequest) (*LoginResponse, error)
|
||||
@@ -307,14 +340,17 @@ type DaemonServiceServer interface {
|
||||
// SetNetworkMapPersistence enables or disables network map persistence
|
||||
SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
|
||||
TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error)
|
||||
SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error
|
||||
SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error
|
||||
GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
// UnimplementedDaemonServiceServer must be embedded to have forward compatible implementations.
|
||||
type UnimplementedDaemonServiceServer struct {
|
||||
}
|
||||
// UnimplementedDaemonServiceServer must be embedded to have
|
||||
// forward compatible implementations.
|
||||
//
|
||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||
// pointer dereference when methods are called.
|
||||
type UnimplementedDaemonServiceServer struct{}
|
||||
|
||||
func (UnimplementedDaemonServiceServer) Login(context.Context, *LoginRequest) (*LoginResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Login not implemented")
|
||||
@@ -370,13 +406,14 @@ func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context
|
||||
func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error {
|
||||
func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error {
|
||||
return status.Errorf(codes.Unimplemented, "method SubscribeEvents not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {}
|
||||
|
||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to DaemonServiceServer will
|
||||
@@ -386,6 +423,13 @@ type UnsafeDaemonServiceServer interface {
|
||||
}
|
||||
|
||||
func RegisterDaemonServiceServer(s grpc.ServiceRegistrar, srv DaemonServiceServer) {
|
||||
// If the following call pancis, it indicates UnimplementedDaemonServiceServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||
t.testEmbeddedByValue()
|
||||
}
|
||||
s.RegisterService(&DaemonService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
@@ -399,7 +443,7 @@ func _DaemonService_Login_Handler(srv interface{}, ctx context.Context, dec func
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/Login",
|
||||
FullMethod: DaemonService_Login_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).Login(ctx, req.(*LoginRequest))
|
||||
@@ -417,7 +461,7 @@ func _DaemonService_WaitSSOLogin_Handler(srv interface{}, ctx context.Context, d
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/WaitSSOLogin",
|
||||
FullMethod: DaemonService_WaitSSOLogin_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).WaitSSOLogin(ctx, req.(*WaitSSOLoginRequest))
|
||||
@@ -435,7 +479,7 @@ func _DaemonService_Up_Handler(srv interface{}, ctx context.Context, dec func(in
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/Up",
|
||||
FullMethod: DaemonService_Up_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).Up(ctx, req.(*UpRequest))
|
||||
@@ -453,7 +497,7 @@ func _DaemonService_Status_Handler(srv interface{}, ctx context.Context, dec fun
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/Status",
|
||||
FullMethod: DaemonService_Status_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).Status(ctx, req.(*StatusRequest))
|
||||
@@ -471,7 +515,7 @@ func _DaemonService_Down_Handler(srv interface{}, ctx context.Context, dec func(
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/Down",
|
||||
FullMethod: DaemonService_Down_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).Down(ctx, req.(*DownRequest))
|
||||
@@ -489,7 +533,7 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/GetConfig",
|
||||
FullMethod: DaemonService_GetConfig_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).GetConfig(ctx, req.(*GetConfigRequest))
|
||||
@@ -507,7 +551,7 @@ func _DaemonService_ListNetworks_Handler(srv interface{}, ctx context.Context, d
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/ListNetworks",
|
||||
FullMethod: DaemonService_ListNetworks_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest))
|
||||
@@ -525,7 +569,7 @@ func _DaemonService_SelectNetworks_Handler(srv interface{}, ctx context.Context,
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/SelectNetworks",
|
||||
FullMethod: DaemonService_SelectNetworks_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest))
|
||||
@@ -543,7 +587,7 @@ func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Contex
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/DeselectNetworks",
|
||||
FullMethod: DaemonService_DeselectNetworks_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest))
|
||||
@@ -561,7 +605,7 @@ func _DaemonService_ForwardingRules_Handler(srv interface{}, ctx context.Context
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/ForwardingRules",
|
||||
FullMethod: DaemonService_ForwardingRules_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).ForwardingRules(ctx, req.(*EmptyRequest))
|
||||
@@ -579,7 +623,7 @@ func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, de
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/DebugBundle",
|
||||
FullMethod: DaemonService_DebugBundle_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest))
|
||||
@@ -597,7 +641,7 @@ func _DaemonService_GetLogLevel_Handler(srv interface{}, ctx context.Context, de
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/GetLogLevel",
|
||||
FullMethod: DaemonService_GetLogLevel_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).GetLogLevel(ctx, req.(*GetLogLevelRequest))
|
||||
@@ -615,7 +659,7 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/SetLogLevel",
|
||||
FullMethod: DaemonService_SetLogLevel_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest))
|
||||
@@ -633,7 +677,7 @@ func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/ListStates",
|
||||
FullMethod: DaemonService_ListStates_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest))
|
||||
@@ -651,7 +695,7 @@ func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/CleanState",
|
||||
FullMethod: DaemonService_CleanState_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest))
|
||||
@@ -669,7 +713,7 @@ func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, de
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/DeleteState",
|
||||
FullMethod: DaemonService_DeleteState_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest))
|
||||
@@ -687,7 +731,7 @@ func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx contex
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/SetNetworkMapPersistence",
|
||||
FullMethod: DaemonService_SetNetworkMapPersistence_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest))
|
||||
@@ -705,7 +749,7 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/TracePacket",
|
||||
FullMethod: DaemonService_TracePacket_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest))
|
||||
@@ -718,21 +762,11 @@ func _DaemonService_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerS
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.(DaemonServiceServer).SubscribeEvents(m, &daemonServiceSubscribeEventsServer{stream})
|
||||
return srv.(DaemonServiceServer).SubscribeEvents(m, &grpc.GenericServerStream[SubscribeRequest, SystemEvent]{ServerStream: stream})
|
||||
}
|
||||
|
||||
type DaemonService_SubscribeEventsServer interface {
|
||||
Send(*SystemEvent) error
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type daemonServiceSubscribeEventsServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *daemonServiceSubscribeEventsServer) Send(m *SystemEvent) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type DaemonService_SubscribeEventsServer = grpc.ServerStreamingServer[SystemEvent]
|
||||
|
||||
func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(GetEventsRequest)
|
||||
@@ -744,7 +778,7 @@ func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/GetEvents",
|
||||
FullMethod: DaemonService_GetEvents_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).GetEvents(ctx, req.(*GetEventsRequest))
|
||||
|
||||
@@ -707,7 +707,9 @@ func (s *Server) Status(
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
if msg.GetFullPeerStatus {
|
||||
s.runProbes()
|
||||
if msg.ShouldRunProbes {
|
||||
s.runProbes()
|
||||
}
|
||||
|
||||
fullStatus := s.statusRecorder.GetFullStatus()
|
||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||
|
||||
@@ -206,7 +206,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -879,7 +879,7 @@ func (s *serviceClient) onUpdateAvailable() {
|
||||
func (s *serviceClient) onSessionExpire() {
|
||||
s.sendNotification = true
|
||||
if s.sendNotification {
|
||||
s.eventHandler.runSelfCommand(s.ctx, "login-url", "true")
|
||||
go s.eventHandler.runSelfCommand(s.ctx, "login-url", "true")
|
||||
s.sendNotification = false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAI
|
||||
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
|
||||
NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted}
|
||||
NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false}
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=${NETBIRD_MGMT_DISABLE_DEFAULT_POLICY:-false}
|
||||
|
||||
# Signal
|
||||
NETBIRD_SIGNAL_PROTOCOL="http"
|
||||
@@ -139,3 +140,4 @@ export NETBIRD_RELAY_PORT
|
||||
export NETBIRD_RELAY_ENDPOINT
|
||||
export NETBIRD_RELAY_AUTH_SECRET
|
||||
export NETBIRD_RELAY_TAG
|
||||
export NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
|
||||
|
||||
@@ -791,7 +791,6 @@ services:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
- '8080:8080'
|
||||
volumes:
|
||||
- netbird_caddy_data:/data
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||
|
||||
@@ -38,6 +38,7 @@
|
||||
"0.0.0.0/0"
|
||||
]
|
||||
},
|
||||
"DisableDefaultPolicy": $NETBIRD_MGMT_DISABLE_DEFAULT_POLICY,
|
||||
"Datadir": "",
|
||||
"DataStoreEncryptionKey": "$NETBIRD_DATASTORE_ENC_KEY",
|
||||
"StoreConfig": {
|
||||
|
||||
@@ -92,7 +92,8 @@ NETBIRD_LETSENCRYPT_EMAIL=""
|
||||
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
|
||||
# DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted
|
||||
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
|
||||
|
||||
# Disable default all-to-all policy for new accounts
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=false
|
||||
# -------------------------------------------
|
||||
# Relay settings
|
||||
# -------------------------------------------
|
||||
|
||||
@@ -29,3 +29,4 @@ NETBIRD_TURN_EXTERNAL_IP=1.2.3.4
|
||||
NETBIRD_RELAY_PORT=33445
|
||||
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true
|
||||
NETBIRD_AUTH_PKCE_LOGIN_FLAG=0
|
||||
NETBIRD_MGMT_DISABLE_DEFAULT_POLICY=$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY
|
||||
|
||||
@@ -100,7 +100,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
Return(true, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -215,7 +215,7 @@ var (
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
proxyController := integrations.NewController(store)
|
||||
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager)
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build default manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -102,6 +102,20 @@ type DefaultAccountManager struct {
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
|
||||
disableDefaultPolicy bool
|
||||
}
|
||||
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
switch {
|
||||
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
|
||||
strings.Contains(err.Error(), "Error 1062 (23000)"),
|
||||
strings.Contains(err.Error(), "UNIQUE constraint failed"):
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
||||
@@ -170,6 +184,7 @@ func BuildManager(
|
||||
proxyController port_forwarding.Controller,
|
||||
settingsManager settings.Manager,
|
||||
permissionsManager permissions.Manager,
|
||||
disableDefaultPolicy bool,
|
||||
) (*DefaultAccountManager, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -195,6 +210,7 @@ func BuildManager(
|
||||
proxyController: proxyController,
|
||||
settingsManager: settingsManager,
|
||||
permissionsManager: permissionsManager,
|
||||
disableDefaultPolicy: disableDefaultPolicy,
|
||||
}
|
||||
|
||||
am.startWarmup(ctx)
|
||||
@@ -543,7 +559,7 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
|
||||
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
|
||||
continue
|
||||
case statusErr.Type() == status.NotFound:
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain)
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy)
|
||||
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
|
||||
return newAccount, nil
|
||||
default:
|
||||
@@ -1657,25 +1673,6 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
|
||||
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
|
||||
}
|
||||
|
||||
labelMap := ConvertSliceToMap(existingLabels)
|
||||
newLabel, err := types.GetPeerHostLabel(peerHostName, labelMap)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get new host label: %w", err)
|
||||
}
|
||||
|
||||
if newLabel == "" {
|
||||
return "", fmt.Errorf("failed to get new host label: %w", err)
|
||||
}
|
||||
|
||||
return newLabel, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
|
||||
if err != nil {
|
||||
@@ -1688,10 +1685,10 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
|
||||
}
|
||||
|
||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := types.NewNetwork()
|
||||
network := types.NewNetwork(accountID)
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*types.User)
|
||||
routes := make(map[route.ID]*route.Route)
|
||||
@@ -1731,7 +1728,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
|
||||
},
|
||||
}
|
||||
|
||||
if err := acc.AddAllGroup(); err != nil {
|
||||
if err := acc.AddAllGroup(disableDefaultPolicy); err != nil {
|
||||
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
|
||||
}
|
||||
return acc
|
||||
@@ -1795,7 +1792,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
||||
continue
|
||||
}
|
||||
|
||||
network := types.NewNetwork()
|
||||
network := types.NewNetwork(accountId)
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*types.User)
|
||||
routes := make(map[route.ID]*route.Route)
|
||||
@@ -1833,7 +1830,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
||||
},
|
||||
}
|
||||
|
||||
if err := newAccount.AddAllGroup(); err != nil {
|
||||
if err := newAccount.AddAllGroup(am.disableDefaultPolicy); err != nil {
|
||||
return nil, false, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
|
||||
}
|
||||
|
||||
|
||||
@@ -373,7 +373,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", false)
|
||||
account.UpdateSettings(&testCase.accountSettings)
|
||||
account.Network = network
|
||||
account.Peers = testCase.peers
|
||||
@@ -398,7 +398,7 @@ func TestNewAccount(t *testing.T) {
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain, false)
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
@@ -640,7 +640,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain)
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain, false)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
@@ -793,7 +793,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) {
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1208,6 +1208,14 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
// Ensure that we do not receive an update message before the policy is deleted
|
||||
time.Sleep(time.Second)
|
||||
select {
|
||||
case <-updMsg:
|
||||
t.Logf("received addPeer update message before policy deletion")
|
||||
default:
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1232,9 +1240,10 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
AccountID: account.Id,
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
@@ -1664,9 +1673,10 @@ func TestAccount_Copy(t *testing.T) {
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
Peers: []string{"peer1"},
|
||||
Resources: []types.Resource{},
|
||||
ID: "group1",
|
||||
Peers: []string{"peer1"},
|
||||
Resources: []types.Resource{},
|
||||
GroupPeers: []types.GroupPeer{},
|
||||
},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
@@ -2608,6 +2618,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
@@ -2615,11 +2626,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
account := &types.Account{
|
||||
Id: "accountID",
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
|
||||
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
|
||||
"peer3": {ID: "peer3", Key: "key3", UserID: "user1"},
|
||||
"peer4": {ID: "peer4", Key: "key4", UserID: "user2"},
|
||||
"peer5": {ID: "peer5", Key: "key5", UserID: "user2"},
|
||||
"peer1": {ID: "peer1", Key: "key1", UserID: "user1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"},
|
||||
"peer2": {ID: "peer2", Key: "key2", UserID: "user1", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"},
|
||||
"peer3": {ID: "peer3", Key: "key3", UserID: "user1", IP: net.IP{3, 3, 3, 3}, DNSLabel: "peer3.domain.test"},
|
||||
"peer4": {ID: "peer4", Key: "key4", UserID: "user2", IP: net.IP{4, 4, 4, 4}, DNSLabel: "peer4.domain.test"},
|
||||
"peer5": {ID: "peer5", Key: "key5", UserID: "user2", IP: net.IP{5, 5, 5, 5}, DNSLabel: "peer5.domain.test"},
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}},
|
||||
@@ -2879,7 +2890,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3139,11 +3150,11 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 7, 20, 10, 80},
|
||||
{"Small", 50, 5, 7, 20, 5, 80},
|
||||
{"Medium", 500, 100, 5, 40, 30, 140},
|
||||
{"Large", 5000, 200, 80, 120, 140, 390},
|
||||
{"Small single", 50, 10, 7, 20, 10, 80},
|
||||
{"Medium single", 500, 10, 5, 40, 20, 85},
|
||||
{"Small single", 50, 10, 7, 20, 6, 80},
|
||||
{"Medium single", 500, 10, 5, 40, 15, 85},
|
||||
{"Large 5", 5000, 15, 80, 120, 80, 200},
|
||||
}
|
||||
|
||||
@@ -3335,11 +3346,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId}
|
||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
|
||||
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId}
|
||||
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
|
||||
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -217,7 +217,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -267,7 +267,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain, false)
|
||||
|
||||
account.Users[dnsRegularUserID] = &types.User{
|
||||
Id: dnsRegularUserID,
|
||||
|
||||
@@ -127,7 +127,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
}
|
||||
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "")
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "", false)
|
||||
|
||||
for i := 0; i < numberOfPeers; i++ {
|
||||
peerId := fmt.Sprintf("peer_%d", i)
|
||||
|
||||
@@ -265,20 +265,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.AddPeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -288,7 +278,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
return transaction.AddPeerToGroup(ctx, peerID, groupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -347,20 +337,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.RemovePeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -370,7 +350,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
return transaction.RemovePeerFromGroup(ctx, peerID, groupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -2,14 +2,19 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -18,8 +23,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
peer2 "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -369,7 +376,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
||||
Id: "example user",
|
||||
AutoGroups: []string{groupForUsers.ID},
|
||||
}
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account.Routes[routeResource.ID] = routeResource
|
||||
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||
@@ -733,3 +740,259 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AddPeerToGroup(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
acc, err := createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
err = manager.Store.AddPeerToGroup(context.Background(), strconv.Itoa(i), acc.GroupsG[0].ID)
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
|
||||
}
|
||||
|
||||
func Test_AddPeerToAll(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
err = manager.Store.AddPeerToAllGroup(context.Background(), accountID, strconv.Itoa(i))
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
|
||||
}
|
||||
|
||||
func Test_AddPeerAndAddToAll(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
peer := &peer2.Peer{
|
||||
ID: strconv.Itoa(i),
|
||||
AccountID: accountID,
|
||||
DNSLabel: "peer" + strconv.Itoa(i),
|
||||
IP: uint32ToIP(uint32(i)),
|
||||
}
|
||||
|
||||
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
|
||||
err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
}
|
||||
err = transaction.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("AddPeer failed for peer %d: %v", i, err)
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
|
||||
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
|
||||
}
|
||||
|
||||
func uint32ToIP(n uint32) net.IP {
|
||||
ip := make(net.IP, 4)
|
||||
binary.BigEndian.PutUint32(ip, n)
|
||||
return ip
|
||||
}
|
||||
|
||||
func Test_IncrementNetworkSerial(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
|
||||
err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("AddPeer failed for peer %d: %v", i, err)
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, int(account.Network.Serial), "Expected %d serial increases in account %s, got %d", totalPeers, accountID, account.Network.Serial)
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
handler := initAccountsTestData(t, &types.Account{
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
Network: types.NewNetwork(),
|
||||
Network: types.NewNetwork(accountID),
|
||||
Users: map[string]*types.User{
|
||||
adminUser.Id: adminUser,
|
||||
},
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
package testing_tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
@@ -138,7 +137,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
|
||||
userManager := users.NewManager(store)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -444,7 +444,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
|
||||
if err != nil {
|
||||
cleanup()
|
||||
|
||||
@@ -211,7 +211,7 @@ func startServer(
|
||||
port_forwarding.NewControllerMock(),
|
||||
settingsMockManager,
|
||||
permissionsManager,
|
||||
)
|
||||
false)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating an account manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -10,13 +10,25 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type LegacyAccountNetwork struct {
|
||||
AccountID string `gorm:"column:id"`
|
||||
Identifier string `gorm:"column:network_identifier"`
|
||||
Net net.IPNet `gorm:"column:network_net;serializer:json"`
|
||||
Dns string `gorm:"column:network_dns"`
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
// Used to synchronize state to the client apps.
|
||||
Serial uint64 `gorm:"column:network_serial"`
|
||||
}
|
||||
|
||||
func GetColumnName(db *gorm.DB, column string) string {
|
||||
if db.Name() == "mysql" {
|
||||
return fmt.Sprintf("`%s`", column)
|
||||
@@ -39,6 +51,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f
|
||||
return nil
|
||||
}
|
||||
|
||||
if !db.Migrator().HasColumn(&model, oldColumnName) {
|
||||
log.WithContext(ctx).Debugf("Column for %T does not exist, no migration needed", oldColumnName)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(model)
|
||||
if err != nil {
|
||||
@@ -373,3 +390,188 @@ func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error
|
||||
log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model)
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
|
||||
var model T
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
if err := stmt.Parse(&model); err != nil {
|
||||
return fmt.Errorf("failed to parse model schema: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
dialect := db.Dialector.Name()
|
||||
|
||||
var columnClause string
|
||||
if dialect == "mysql" {
|
||||
var withLength []string
|
||||
for _, col := range columns {
|
||||
if col == "ip" || col == "dns_label" {
|
||||
withLength = append(withLength, fmt.Sprintf("%s(64)", col))
|
||||
} else {
|
||||
withLength = append(withLength, col)
|
||||
}
|
||||
}
|
||||
columnClause = strings.Join(withLength, ", ")
|
||||
} else {
|
||||
columnClause = strings.Join(columns, ", ")
|
||||
}
|
||||
|
||||
createStmt := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, tableName, columnClause)
|
||||
if dialect == "postgres" || dialect == "sqlite" {
|
||||
createStmt = strings.Replace(createStmt, "CREATE UNIQUE INDEX", "CREATE UNIQUE INDEX IF NOT EXISTS", 1)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("executing index creation: %s", createStmt)
|
||||
if err := db.Exec(createStmt).Error; err != nil {
|
||||
return fmt.Errorf("failed to create index %s: %w", indexName, err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName string, mapperFunc func(id string, value string) any) error {
|
||||
var model T
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(&model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
|
||||
if !db.Migrator().HasColumn(&model, columnName) {
|
||||
log.WithContext(ctx).Debugf("column %s does not exist in table %s, no migration needed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
var rows []map[string]any
|
||||
if err := tx.Table(tableName).Select("id", columnName).Find(&rows).Error; err != nil {
|
||||
return fmt.Errorf("find rows: %w", err)
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
jsonValue, ok := row[columnName].(string)
|
||||
if !ok || jsonValue == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var data []string
|
||||
if err := json.Unmarshal([]byte(jsonValue), &data); err != nil {
|
||||
return fmt.Errorf("unmarshal json: %w", err)
|
||||
}
|
||||
|
||||
for _, value := range data {
|
||||
if err := tx.Clauses(clause.OnConflict{
|
||||
DoNothing: true, // this needs to be removed when the cleanup is enabled
|
||||
}).Create(
|
||||
mapperFunc(row["id"].(string), value),
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("failed to insert id %v: %w", row["id"], err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Todo: Enable this after we are sure that every thing works as expected and we do not need to rollback anymore
|
||||
// if err := tx.Migrator().DropColumn(&model, columnName); err != nil {
|
||||
// return fmt.Errorf("drop column %s: %w", columnName, err)
|
||||
// }
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateEmbeddedToTable[T any, S any, U any](ctx context.Context, db *gorm.DB, pkey string, mapperFunc func(obj S) *U) error {
|
||||
var model T
|
||||
|
||||
log.WithContext(ctx).Debugf("Migrating embedded fields from %T to separate table", model)
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(&model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
var legacyRows []S
|
||||
if err := tx.Table(tableName).Find(&legacyRows).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("Failed to read legacy accounts: %v", err)
|
||||
return fmt.Errorf("failed to read legacy accounts: %w", err)
|
||||
}
|
||||
|
||||
for _, row := range legacyRows {
|
||||
if err := tx.Clauses(clause.OnConflict{
|
||||
DoNothing: true, // this needs to be removed when the cleanup is enabled
|
||||
}).Create(
|
||||
mapperFunc(row),
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("failed to insert id %v: %w", row, err)
|
||||
}
|
||||
}
|
||||
|
||||
// cols, err := getColumnNamesFromStruct(new(S))
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to extract column names: %w", err)
|
||||
// }
|
||||
|
||||
// for _, col := range cols {
|
||||
// if col == pkey {
|
||||
// continue
|
||||
// }
|
||||
// if err := tx.Migrator().DropColumn(&model, col); err != nil {
|
||||
// return fmt.Errorf("failed to drop column %s: %w", col, err)
|
||||
// }
|
||||
// }
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("Migration of embedded fields %T from table %s into seperte table completed", new(S), tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getColumnNamesFromStruct[T any](model T) ([]string, error) {
|
||||
val := reflect.TypeOf(model)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
var cols []string
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := val.Field(i)
|
||||
if field.Name == "ID" {
|
||||
continue // skip primary key
|
||||
}
|
||||
tag := field.Tag.Get("gorm")
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
// Look for gorm:"column:..."
|
||||
for _, part := range strings.Split(tag, ";") {
|
||||
if strings.HasPrefix(part, "column:") {
|
||||
cols = append(cols, strings.TrimPrefix(part, "column:"))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
@@ -779,7 +779,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (store.Store, error) {
|
||||
@@ -848,7 +848,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
|
||||
userID := testUserID
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
|
||||
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
|
||||
|
||||
|
||||
@@ -15,13 +15,14 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -234,14 +235,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
if peer.Name != update.Name {
|
||||
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID)
|
||||
var newLabel string
|
||||
newLabel, err = getPeerIPDNSLabel(ctx, transaction, peer.IP, accountID, update.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newLabel, err := types.GetPeerHostLabel(update.Name, existingLabels)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
}
|
||||
|
||||
peer.Name = update.Name
|
||||
@@ -363,25 +360,20 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthUpdate, accountID, peerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get peer groups: %w", err)
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
group.RemovePeer(peerID)
|
||||
err = transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save group: %w", err)
|
||||
}
|
||||
if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
|
||||
return fmt.Errorf("failed to remove peer from groups: %w", err)
|
||||
}
|
||||
|
||||
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
|
||||
return err
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete peer: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -463,233 +455,247 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
upperKey := strings.ToUpper(setupKey)
|
||||
hashedKey := sha256.Sum256([]byte(upperKey))
|
||||
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
||||
var accountID string
|
||||
var err error
|
||||
addedByUser := false
|
||||
if len(userID) > 0 {
|
||||
addedByUser = true
|
||||
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
} else {
|
||||
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
addedByUser := len(userID) > 0
|
||||
|
||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
||||
// and the peer disconnects with a timeout and tries to register again.
|
||||
// We just check if this machine has been registered before and reject the second registration.
|
||||
// The connecting peer should be able to recover with a retry.
|
||||
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key)
|
||||
_, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
|
||||
if err == nil {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
|
||||
}
|
||||
|
||||
opEvent := &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
AccountID: accountID,
|
||||
}
|
||||
|
||||
var newPeer *nbpeer.Peer
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var setupKeyID string
|
||||
var setupKeyName string
|
||||
var ephemeral bool
|
||||
var groupsToAdd []string
|
||||
var allowExtraDNSLabels bool
|
||||
if addedByUser {
|
||||
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user groups: %w", err)
|
||||
}
|
||||
groupsToAdd = user.AutoGroups
|
||||
opEvent.InitiatorID = userID
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
} else {
|
||||
// Validate the setup key
|
||||
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get setup key: %w", err)
|
||||
}
|
||||
|
||||
if !sk.IsValid() {
|
||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
opEvent.InitiatorID = sk.Id
|
||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||
groupsToAdd = sk.AutoGroups
|
||||
ephemeral = sk.Ephemeral
|
||||
setupKeyID = sk.Id
|
||||
setupKeyName = sk.Name
|
||||
allowExtraDNSLabels = sk.AllowExtraDNSLabels
|
||||
|
||||
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
||||
}
|
||||
}
|
||||
|
||||
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
if err == nil && userdata != nil {
|
||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
|
||||
var setupKeyID string
|
||||
var setupKeyName string
|
||||
var ephemeral bool
|
||||
var groupsToAdd []string
|
||||
var allowExtraDNSLabels bool
|
||||
var accountID string
|
||||
if addedByUser {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
return nil, nil, nil, fmt.Errorf("failed to get user groups: %w", err)
|
||||
}
|
||||
|
||||
freeIP, err := getFreeIP(ctx, transaction, accountID)
|
||||
groupsToAdd = user.AutoGroups
|
||||
opEvent.InitiatorID = userID
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
accountID = user.AccountID
|
||||
} else {
|
||||
// Validate the setup key
|
||||
sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get free IP: %w", err)
|
||||
return nil, nil, nil, fmt.Errorf("failed to get setup key: %w", err)
|
||||
}
|
||||
|
||||
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
|
||||
// we will check key twice for early return
|
||||
if !sk.IsValid() {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
registrationTime := time.Now().UTC()
|
||||
newPeer = &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Key: peer.Key,
|
||||
IP: freeIP,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Meta.Hostname,
|
||||
DNSLabel: freeLabel,
|
||||
UserID: userID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
LastLogin: ®istrationTime,
|
||||
CreatedAt: registrationTime,
|
||||
LoginExpirationEnabled: addedByUser,
|
||||
Ephemeral: ephemeral,
|
||||
Location: peer.Location,
|
||||
InactivityExpirationEnabled: addedByUser,
|
||||
ExtraDNSLabels: peer.ExtraDNSLabels,
|
||||
AllowExtraDNSLabels: allowExtraDNSLabels,
|
||||
}
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
opEvent.InitiatorID = sk.Id
|
||||
opEvent.Activity = activity.PeerAddedWithSetupKey
|
||||
groupsToAdd = sk.AutoGroups
|
||||
ephemeral = sk.Ephemeral
|
||||
setupKeyID = sk.Id
|
||||
setupKeyName = sk.Name
|
||||
allowExtraDNSLabels = sk.AllowExtraDNSLabels
|
||||
accountID = sk.AccountID
|
||||
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||
if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 {
|
||||
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels")
|
||||
}
|
||||
}
|
||||
opEvent.AccountID = accountID
|
||||
|
||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
|
||||
} else {
|
||||
newPeer.Location.CountryCode = location.Country.ISOCode
|
||||
newPeer.Location.CityName = location.City.Names.En
|
||||
newPeer.Location.GeoNameID = location.City.GeonameID
|
||||
if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
if err == nil && userdata != nil {
|
||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
||||
|
||||
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||
}
|
||||
|
||||
if len(groupsToAdd) > 0 {
|
||||
for _, g := range groupsToAdd {
|
||||
err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if addedByUser {
|
||||
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
||||
}
|
||||
} else {
|
||||
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
||||
return nil
|
||||
})
|
||||
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
|
||||
}
|
||||
|
||||
registrationTime := time.Now().UTC()
|
||||
newPeer = &nbpeer.Peer{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Key: peer.Key,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Meta.Hostname,
|
||||
UserID: userID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
LastLogin: ®istrationTime,
|
||||
CreatedAt: registrationTime,
|
||||
LoginExpirationEnabled: addedByUser,
|
||||
Ephemeral: ephemeral,
|
||||
Location: peer.Location,
|
||||
InactivityExpirationEnabled: addedByUser,
|
||||
ExtraDNSLabels: peer.ExtraDNSLabels,
|
||||
AllowExtraDNSLabels: allowExtraDNSLabels,
|
||||
}
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
|
||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
|
||||
} else {
|
||||
newPeer.Location.CountryCode = location.Country.ISOCode
|
||||
newPeer.Location.CityName = location.City.Names.En
|
||||
newPeer.Location.GeoNameID = location.City.GeonameID
|
||||
}
|
||||
}
|
||||
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
||||
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed getting network: %w", err)
|
||||
}
|
||||
|
||||
maxAttempts := 10
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
var freeIP net.IP
|
||||
freeIP, err = types.AllocateRandomPeerIP(network.Net)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get free IP: %w", err)
|
||||
}
|
||||
|
||||
var freeLabel string
|
||||
freeLabel, err = getPeerIPDNSLabel(ctx, am.Store, freeIP, accountID, peer.Meta.Hostname)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
}
|
||||
|
||||
newPeer.DNSLabel = freeLabel
|
||||
newPeer.IP = freeIP
|
||||
|
||||
// unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
// defer func() {
|
||||
// if unlock != nil {
|
||||
// unlock()
|
||||
// }
|
||||
// }()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(groupsToAdd) > 0 {
|
||||
for _, g := range groupsToAdd {
|
||||
err = transaction.AddPeerToGroup(ctx, newPeer.ID, g)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||
}
|
||||
|
||||
if addedByUser {
|
||||
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
|
||||
}
|
||||
} else {
|
||||
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get setup key: %w", err)
|
||||
}
|
||||
|
||||
// we validate at the end to not block the setup key for too long
|
||||
if !sk.IsValid() {
|
||||
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment setup key usage: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
// unlock()
|
||||
// unlock = nil
|
||||
break
|
||||
}
|
||||
|
||||
if isUniqueConstraintError(err) {
|
||||
// unlock()
|
||||
// unlock = nil
|
||||
log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err)
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
|
||||
}
|
||||
|
||||
if newPeer == nil {
|
||||
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
||||
}
|
||||
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
|
||||
if !addedByUser {
|
||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
if updateAccountPeers {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
}
|
||||
|
||||
func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) {
|
||||
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID)
|
||||
func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID, peerHostName string) (string, error) {
|
||||
ip = ip.To4()
|
||||
|
||||
dnsName, err := nbdns.GetParsedDomainLabel(peerHostName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
|
||||
return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err)
|
||||
}
|
||||
|
||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID)
|
||||
_, err = tx.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, dnsName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting network: %w", err)
|
||||
//nolint:nilerr
|
||||
return dnsName, nil
|
||||
}
|
||||
|
||||
nextIp, err := types.AllocatePeerIP(network.Net, takenIps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
|
||||
}
|
||||
|
||||
return nextIp, nil
|
||||
return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil
|
||||
}
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
@@ -1000,7 +1006,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
}()
|
||||
|
||||
if isRequiresApproval {
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -1249,17 +1255,19 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
lock := mu.(*sync.Mutex)
|
||||
|
||||
if !lock.TryLock() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
|
||||
lock.Unlock()
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
lock := mu.(*sync.Mutex)
|
||||
|
||||
if !lock.TryLock() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
|
||||
lock.Unlock()
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}()
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1477,19 +1485,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
|
||||
return groupIDs, err
|
||||
}
|
||||
|
||||
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) {
|
||||
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingLabels := make(types.LookupMap)
|
||||
for _, label := range dnsLabels {
|
||||
existingLabels[label] = struct{}{}
|
||||
}
|
||||
return existingLabels, nil
|
||||
}
|
||||
|
||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||
// in an active DNS, route, or ACL configuration.
|
||||
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
|
||||
|
||||
@@ -20,14 +20,14 @@ type Peer struct {
|
||||
// WireGuard public key
|
||||
Key string `gorm:"index"`
|
||||
// IP address of the Peer
|
||||
IP net.IP `gorm:"serializer:json"`
|
||||
IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations)
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
// Name is peer's name (machine name)
|
||||
Name string
|
||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DNSLabel string
|
||||
DNSLabel string // uniqueness index per accountID (check migrations)
|
||||
// Status peer's management connection status
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
||||
// The user ID that registered the peer
|
||||
|
||||
@@ -10,7 +10,9 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -19,6 +21,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
@@ -480,7 +483,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: types.UserRoleUser,
|
||||
@@ -667,7 +670,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: testCase.role,
|
||||
@@ -737,7 +740,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
|
||||
adminUser := "account_creator"
|
||||
regularUser := "regular_user"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Users[regularUser] = &types.User{
|
||||
Id: regularUser,
|
||||
Role: types.UserRoleUser,
|
||||
@@ -1267,7 +1270,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1342,7 +1345,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1391,7 +1394,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
name: "Absent setup key",
|
||||
existingSetupKeyID: "AAAAAAAA-38F5-4553-B31E-DD66C696CEBB",
|
||||
expectAddPeerError: true,
|
||||
expectedErrorMsgSubstring: "failed adding new peer: account not found",
|
||||
expectedErrorMsgSubstring: "failed to get setup key: setup key not found",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1456,6 +1459,10 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
engine := os.Getenv("NETBIRD_STORE_ENGINE")
|
||||
if engine == "sqlite" || engine == "" {
|
||||
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
@@ -1477,7 +1484,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1546,7 +1553,7 @@ func Test_LoginPeer(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
@@ -1761,7 +1768,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg) //
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2052,15 +2059,19 @@ func Test_DeletePeer(t *testing.T) {
|
||||
// account with an admin and a regular user
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account.Peers = map[string]*nbpeer.Peer{
|
||||
"peer1": {
|
||||
ID: "peer1",
|
||||
AccountID: accountID,
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
DNSLabel: "peer1.test",
|
||||
},
|
||||
"peer2": {
|
||||
ID: "peer2",
|
||||
AccountID: accountID,
|
||||
IP: net.IP{2, 2, 2, 2},
|
||||
DNSLabel: "peer2.test",
|
||||
},
|
||||
}
|
||||
account.Groups = map[string]*types.Group{
|
||||
@@ -2090,3 +2101,139 @@ func Test_DeletePeer(t *testing.T) {
|
||||
assert.NotContains(t, group.Peers, "peer1")
|
||||
|
||||
}
|
||||
|
||||
func Test_IsUniqueConstraintError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
engine types.Engine
|
||||
}{
|
||||
{
|
||||
name: "PostgreSQL uniqueness error",
|
||||
engine: types.PostgresStoreEngine,
|
||||
},
|
||||
{
|
||||
name: "MySQL uniqueness error",
|
||||
engine: types.MysqlStoreEngine,
|
||||
},
|
||||
{
|
||||
name: "SQLite uniqueness error",
|
||||
engine: types.SqliteStoreEngine,
|
||||
},
|
||||
}
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
ID: "test-peer-id",
|
||||
AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
DNSLabel: "test-peer-dns-label",
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(tt.engine))
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
|
||||
result := isUniqueConstraintError(err)
|
||||
assert.True(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AddPeer(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
t.Setenv("NB_GET_ACCOUNT_BUFFER_INTERVAL", "300ms")
|
||||
t.Setenv("NB_PEER_UPDATE_BUFFER_INTERVAL", "300ms")
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatalf("error creating account: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", types.SetupKeyReusable, time.Hour, nil, 10000, userID, false, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 300
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
AccountID: accountID,
|
||||
Key: "key" + strconv.Itoa(i),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(i), GoOS: "linux"},
|
||||
}
|
||||
|
||||
<-start
|
||||
|
||||
_, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer)
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
|
||||
|
||||
seenIP := make(map[string]bool)
|
||||
for _, p := range account.Peers {
|
||||
ipStr := p.IP.String()
|
||||
if seenIP[ipStr] {
|
||||
t.Fatalf("Duplicate IP found in account %s: %s", accountID, ipStr)
|
||||
}
|
||||
seenIP[ipStr] = true
|
||||
}
|
||||
|
||||
seenLabel := make(map[string]bool)
|
||||
for _, p := range account.Peers {
|
||||
if seenLabel[p.DNSLabel] {
|
||||
t.Fatalf("Duplicate Label found in account %s: %s", accountID, p.DNSLabel)
|
||||
}
|
||||
seenLabel[p.DNSLabel] = true
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes)
|
||||
assert.Equal(t, uint64(totalPeers), account.Network.Serial)
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account.Users[admin.Id] = admin
|
||||
account.Users[user.Id] = user
|
||||
|
||||
|
||||
@@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
}
|
||||
|
||||
func createRouterStore(t *testing.T) (store.Store, error) {
|
||||
@@ -1305,7 +1305,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
|
||||
accountID := "testingAcc"
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -156,7 +156,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
||||
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migratePreAuto from a version that didn't support groups. Error: %v", err)
|
||||
// if the All group didn't exist we probably don't have routes to update
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -92,17 +92,20 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
||||
}
|
||||
|
||||
if err := migrate(ctx, db); err != nil {
|
||||
return nil, fmt.Errorf("migrate: %w", err)
|
||||
if err := migratePreAuto(ctx, db); err != nil {
|
||||
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
||||
}
|
||||
err = db.AutoMigrate(
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.Network{},
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migrate: %w", err)
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
}
|
||||
if err := migratePostAuto(ctx, db); err != nil {
|
||||
return nil, fmt.Errorf("migratePostAuto: %w", err)
|
||||
}
|
||||
|
||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
||||
@@ -183,6 +186,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
|
||||
generateAccountSQLTypes(account)
|
||||
|
||||
for _, group := range account.GroupsG {
|
||||
group.StoreGroupPeers()
|
||||
}
|
||||
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
||||
if result.Error != nil {
|
||||
@@ -244,7 +251,7 @@ func generateAccountSQLTypes(account *types.Account) {
|
||||
|
||||
for id, group := range account.Groups {
|
||||
group.ID = id
|
||||
account.GroupsG = append(account.GroupsG, *group)
|
||||
account.GroupsG = append(account.GroupsG, group)
|
||||
}
|
||||
|
||||
for id, route := range account.Routes {
|
||||
@@ -452,19 +459,40 @@ func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength,
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.
|
||||
Clauses(
|
||||
clause.Locking{Strength: string(lockStrength)},
|
||||
clause.OnConflict{
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
},
|
||||
).
|
||||
Create(&groups)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
||||
for _, g := range groups {
|
||||
g.StoreGroupPeers()
|
||||
}
|
||||
return nil
|
||||
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.
|
||||
Clauses(
|
||||
clause.Locking{Strength: string(lockStrength)},
|
||||
clause.OnConflict{
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
},
|
||||
).
|
||||
Create(&groups)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save groups to store")
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
if len(g.GroupPeers) == 0 {
|
||||
if err := tx.Where("group_id = ?", g.ID).Delete(&types.GroupPeer{}).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete group peers for group %s: %s", g.ID, err)
|
||||
return status.Errorf(status.Internal, "failed to delete group peers")
|
||||
}
|
||||
} else {
|
||||
if err := tx.Model(&g).Association("GroupPeers").Replace(g.GroupPeers); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to save group peers: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
||||
@@ -643,7 +671,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
|
||||
}
|
||||
|
||||
var groups []*types.Group
|
||||
result := tx.Find(&groups, accountIDCondition, accountID)
|
||||
result := tx.Preload(clause.Associations).Find(&groups, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
@@ -652,6 +680,10 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
|
||||
return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
g.LoadGroupPeers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
@@ -666,6 +698,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
|
||||
likePattern := `%"ID":"` + resourceID + `"%`
|
||||
|
||||
result := tx.
|
||||
Preload(clause.Associations).
|
||||
Where("resources LIKE ?", likePattern).
|
||||
Find(&groups)
|
||||
|
||||
@@ -676,6 +709,10 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
g.LoadGroupPeers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
@@ -735,8 +772,9 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
}()
|
||||
|
||||
var account types.Account
|
||||
result := s.db.Model(&account).
|
||||
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
||||
result := s.db.Session(&gorm.Session{Logger: logger.Default.LogMode(logger.Info)}).Model(&account).
|
||||
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
||||
Preload("GroupsG.GroupPeers"). // have to be specifies as this is nester reference
|
||||
Preload(clause.Associations).
|
||||
First(&account, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
@@ -750,7 +788,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
|
||||
for i, policy := range account.Policies {
|
||||
var rules []*types.PolicyRule
|
||||
err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
err := s.db.Session(&gorm.Session{Logger: logger.Default.LogMode(logger.Info)}).Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "rule not found")
|
||||
}
|
||||
@@ -781,6 +819,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
|
||||
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
|
||||
for _, group := range account.GroupsG {
|
||||
group.LoadGroupPeers()
|
||||
account.Groups[group.ID] = group.Copy()
|
||||
}
|
||||
account.GroupsG = nil
|
||||
@@ -967,7 +1006,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
||||
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string, dnsLabel string) ([]string, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
@@ -975,7 +1014,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
||||
|
||||
var labels []string
|
||||
result := tx.Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ?", accountID).
|
||||
Where("account_id = ? AND dns_label LIKE ?", accountID, dnsLabel+"%").
|
||||
Pluck("dns_label", &labels)
|
||||
|
||||
if result.Error != nil {
|
||||
@@ -995,14 +1034,14 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var accountNetwork types.AccountNetwork
|
||||
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
accountNetwork := types.Network{}
|
||||
if err := tx.Where(accountIDCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
|
||||
}
|
||||
return accountNetwork.Network, nil
|
||||
return &accountNetwork, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
@@ -1184,7 +1223,7 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data
|
||||
for _, account := range fileStore.GetAllAccounts(ctx) {
|
||||
_, err = account.GetGroupAll()
|
||||
if err != nil {
|
||||
if err := account.AddAllGroup(); err != nil {
|
||||
if err := account.AddAllGroup(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -1254,7 +1293,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewSetupKeyNotFoundError(key)
|
||||
return nil, status.Errorf(status.PreconditionFailed, "setup key not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
|
||||
@@ -1282,55 +1321,74 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
}
|
||||
|
||||
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
|
||||
var group types.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&group, "account_id = ? AND name = ?", accountID, "All")
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group 'All' not found for account")
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
var groupID string
|
||||
_ = s.db.Model(types.Group{}).
|
||||
Select("id").
|
||||
Where("account_id = ? AND name = ?", accountID, "All").
|
||||
Limit(1).
|
||||
Scan(&groupID)
|
||||
|
||||
if groupID == "" {
|
||||
return status.Errorf(status.NotFound, "group 'All' not found for account %s", accountID)
|
||||
}
|
||||
|
||||
for _, existingPeerID := range group.Peers {
|
||||
if existingPeerID == peerID {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
err := s.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(&types.GroupPeer{
|
||||
GroupID: groupID,
|
||||
PeerID: peerID,
|
||||
}).Error
|
||||
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error {
|
||||
var group types.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID).
|
||||
First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.NewGroupNotFoundError(groupID)
|
||||
}
|
||||
|
||||
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
|
||||
// AddPeerToGroup adds a peer to a group
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, peerID string, groupID string) error {
|
||||
peer := &types.GroupPeer{
|
||||
GroupID: groupID,
|
||||
PeerID: peerID,
|
||||
}
|
||||
|
||||
for _, existingPeerID := range group.Peers {
|
||||
if existingPeerID == peerId {
|
||||
return nil
|
||||
}
|
||||
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(peer).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to add peer %s to group %s: %v", peerID, groupID, err)
|
||||
return status.Errorf(status.Internal, "failed to add peer to group")
|
||||
}
|
||||
|
||||
group.Peers = append(group.Peers, peerId)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue updating group: %s", err)
|
||||
// RemovePeerFromGroup removes a peer from a group
|
||||
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
|
||||
err := s.db.WithContext(ctx).
|
||||
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err)
|
||||
return status.Errorf(status.Internal, "failed to remove peer from group")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeerFromAllGroups removes a peer from all groups
|
||||
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
|
||||
err := s.db.WithContext(ctx).
|
||||
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err)
|
||||
return status.Errorf(status.Internal, "failed to remove peer from all groups")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1398,12 +1456,19 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
var groups []*types.Group
|
||||
query := tx.
|
||||
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
|
||||
Joins("JOIN group_peers ON group_peers.group_id = groups.id").
|
||||
Where("group_peers.peer_id = ?", peerId).
|
||||
Preload(clause.Associations).
|
||||
Find(&groups)
|
||||
|
||||
if query.Error != nil {
|
||||
return nil, query.Error
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
group.LoadGroupPeers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
@@ -1452,7 +1517,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error {
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil {
|
||||
if err := s.db.Create(peer).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
||||
}
|
||||
|
||||
@@ -1580,7 +1645,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength,
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
Model(&types.Network{}).Where(accountIDCondition, accountId).Update("serial", gorm.Expr("serial + 1"))
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
||||
@@ -1689,7 +1754,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
var group *types.Group
|
||||
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
|
||||
result := tx.Preload(clause.Associations).First(&group, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewGroupNotFoundError(groupID)
|
||||
@@ -1698,15 +1763,14 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
||||
return nil, status.Errorf(status.Internal, "failed to get group from store")
|
||||
}
|
||||
|
||||
group.LoadGroupPeers()
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// GetGroupByName retrieves a group by name and account ID.
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var group types.Group
|
||||
|
||||
@@ -1714,16 +1778,14 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
// we may need to reconsider changing the types.
|
||||
query := tx.Preload(clause.Associations)
|
||||
|
||||
switch s.storeEngine {
|
||||
case types.PostgresStoreEngine:
|
||||
query = query.Order("json_array_length(peers::json) DESC")
|
||||
case types.MysqlStoreEngine:
|
||||
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
|
||||
default:
|
||||
query = query.Order("json_array_length(peers) DESC")
|
||||
}
|
||||
|
||||
result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
|
||||
result := query.
|
||||
Model(&types.Group{}).
|
||||
Joins("LEFT JOIN group_peers ON group_peers.group_id = groups.id").
|
||||
Where("groups.account_id = ? AND groups.name = ?", accountID, groupName).
|
||||
Group("groups.id").
|
||||
Order("COUNT(group_peers.peer_id) DESC").
|
||||
Limit(1).
|
||||
First(&group)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewGroupNotFoundError(groupName)
|
||||
@@ -1731,6 +1793,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
|
||||
}
|
||||
|
||||
group.LoadGroupPeers()
|
||||
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
@@ -1742,7 +1807,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
}
|
||||
|
||||
var groups []*types.Group
|
||||
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
result := tx.Preload(clause.Associations).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
|
||||
@@ -1750,6 +1815,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
groupsMap := make(map[string]*types.Group)
|
||||
for _, group := range groups {
|
||||
group.LoadGroupPeers()
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
@@ -1758,17 +1824,36 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
// SaveGroup saves a group to the store.
|
||||
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
|
||||
if group == nil {
|
||||
return status.Errorf(status.InvalidArgument, "group is nil")
|
||||
}
|
||||
|
||||
group = group.Copy()
|
||||
group.StoreGroupPeers()
|
||||
|
||||
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to save group to store")
|
||||
}
|
||||
|
||||
if len(group.GroupPeers) == 0 {
|
||||
if err := s.db.Where("group_id = ?", group.ID).Delete(&types.GroupPeer{}).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete group peers for group %s: %s", group.ID, err)
|
||||
return status.Errorf(status.Internal, "failed to delete group peers")
|
||||
}
|
||||
} else {
|
||||
if err := s.db.Model(&group).Association("GroupPeers").Replace(group.GroupPeers); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to save group peers: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroup deletes a group from the database.
|
||||
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Select(clause.Associations).
|
||||
Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
|
||||
@@ -1785,6 +1870,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength
|
||||
// DeleteGroups deletes groups from the database.
|
||||
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
|
||||
Select(clause.Associations).
|
||||
Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
|
||||
@@ -2546,6 +2632,27 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
|
||||
return &peer, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var peerID string
|
||||
result := tx.Model(&nbpeer.Peer{}).
|
||||
Select("id").
|
||||
// Where(" = ?", hostname).
|
||||
Where("account_id = ? AND dns_label = ?", accountID, hostname).
|
||||
Limit(1).
|
||||
Scan(&peerID)
|
||||
|
||||
if peerID == "" {
|
||||
return "", gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return peerID, result.Error
|
||||
}
|
||||
|
||||
func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) {
|
||||
var count int64
|
||||
result := s.db.Model(&types.Account{}).
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -630,7 +631,7 @@ func TestMigrate(t *testing.T) {
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
||||
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||
require.NoError(t, err, "Migration should not fail on empty db")
|
||||
|
||||
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
|
||||
@@ -685,10 +686,10 @@ func TestMigrate(t *testing.T) {
|
||||
err = store.(*SqlStore).db.Save(rt).Error
|
||||
require.NoError(t, err, "Failed to insert Gob data")
|
||||
|
||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
||||
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||
require.NoError(t, err, "Migration should not fail on gob populated db")
|
||||
|
||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
||||
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||
require.NoError(t, err, "Migration should not fail on migrated db")
|
||||
|
||||
err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error
|
||||
@@ -704,10 +705,10 @@ func TestMigrate(t *testing.T) {
|
||||
err = store.(*SqlStore).db.Save(nRT).Error
|
||||
require.NoError(t, err, "Failed to insert json nil slice data")
|
||||
|
||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
||||
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||
require.NoError(t, err, "Migration should not fail on json nil slice populated db")
|
||||
|
||||
err = migrate(context.Background(), store.(*SqlStore).db)
|
||||
err = migratePreAuto(context.Background(), store.(*SqlStore).db)
|
||||
require.NoError(t, err, "Migration should not fail on migrated db")
|
||||
|
||||
}
|
||||
@@ -950,6 +951,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer1",
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||
@@ -961,8 +963,9 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
assert.Equal(t, []net.IP{ip1}, takenIPs)
|
||||
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer2",
|
||||
ID: "peer1second",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer1-1",
|
||||
IP: net.IP{2, 2, 2, 2},
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||
@@ -972,49 +975,100 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
ip2 := net.IP{2, 2, 2, 2}.To16()
|
||||
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
|
||||
|
||||
}
|
||||
|
||||
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
peerHostname := "peer1"
|
||||
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{}, labels)
|
||||
|
||||
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{}, labels)
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer1",
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer1.domain.test",
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||
require.NoError(t, err)
|
||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"peer1"}, labels)
|
||||
|
||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"peer1.domain.test"}, labels)
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer1second",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer1-1",
|
||||
IP: net.IP{2, 2, 2, 2},
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer2",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer2.domain.test",
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||
require.NoError(t, err)
|
||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
|
||||
expected := []string{"peer1", "peer1-1"}
|
||||
sort.Strings(expected)
|
||||
sort.Strings(labels)
|
||||
assert.Equal(t, expected, labels)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AddPeerWithSameDnsLabel(t *testing.T) {
|
||||
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer1.domain.test",
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer1second",
|
||||
AccountID: existingAccountID,
|
||||
DNSLabel: "peer1.domain.test",
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AddPeerWithSameIP(t *testing.T) {
|
||||
runTestForAllEngines(t, "../testdata/extended-store.sql", func(t *testing.T, store Store) {
|
||||
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
_, err := store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
AccountID: existingAccountID,
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer1second",
|
||||
AccountID: existingAccountID,
|
||||
IP: net.IP{1, 1, 1, 1},
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlite_GetAccountNetwork(t *testing.T) {
|
||||
@@ -1286,10 +1340,12 @@ func TestSqlStore_SaveGroup(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
group := &types.Group{
|
||||
ID: "group-id",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
ID: "group-id",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
Resources: []types.Resource{},
|
||||
GroupPeers: []types.GroupPeer{},
|
||||
}
|
||||
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
require.NoError(t, err)
|
||||
@@ -1308,16 +1364,19 @@ func TestSqlStore_SaveGroups(t *testing.T) {
|
||||
|
||||
groups := []*types.Group{
|
||||
{
|
||||
ID: "group-1",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
ID: "group-1",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
Resources: []types.Resource{},
|
||||
GroupPeers: []types.GroupPeer{},
|
||||
},
|
||||
{
|
||||
ID: "group-2",
|
||||
AccountID: accountID,
|
||||
Issued: "integration",
|
||||
Peers: []string{"peer3", "peer4"},
|
||||
Resources: []types.Resource{},
|
||||
},
|
||||
}
|
||||
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
|
||||
@@ -2005,7 +2064,7 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := types.NewNetwork()
|
||||
network := types.NewNetwork(accountID)
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*types.User)
|
||||
routes := make(map[nbroute.ID]*nbroute.Route)
|
||||
@@ -2044,7 +2103,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
|
||||
},
|
||||
}
|
||||
|
||||
if err := acc.AddAllGroup(); err != nil {
|
||||
if err := acc.AddAllGroup(false); err != nil {
|
||||
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
|
||||
}
|
||||
return acc
|
||||
@@ -2452,7 +2511,7 @@ func TestSqlStore_AddPeerToGroup(t *testing.T) {
|
||||
require.NoError(t, err, "failed to get group")
|
||||
require.Len(t, group.Peers, 0, "group should have 0 peers")
|
||||
|
||||
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID)
|
||||
err = store.AddPeerToGroup(context.Background(), peerID, groupID)
|
||||
require.NoError(t, err, "failed to add peer to group")
|
||||
|
||||
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
@@ -2483,7 +2542,7 @@ func TestSqlStore_AddPeerToAllGroup(t *testing.T) {
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer)
|
||||
require.NoError(t, err, "failed to add peer to account")
|
||||
|
||||
err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID)
|
||||
err = store.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
|
||||
require.NoError(t, err, "failed to add peer to all group")
|
||||
|
||||
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
@@ -2569,7 +2628,7 @@ func TestSqlStore_GetPeerGroups(t *testing.T) {
|
||||
assert.Len(t, groups, 1)
|
||||
assert.Equal(t, groups[0].Name, "All")
|
||||
|
||||
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, "cfefqs706sqkneg59g4h")
|
||||
err = store.AddPeerToGroup(context.Background(), peerID, "cfefqs706sqkneg59g4h")
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID)
|
||||
|
||||
@@ -117,9 +117,11 @@ type Store interface {
|
||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, peerId string, groupID string) error
|
||||
RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error
|
||||
RemovePeerFromAllGroups(ctx context.Context, peerID string) error
|
||||
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
|
||||
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
|
||||
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
|
||||
@@ -193,6 +195,7 @@ type Store interface {
|
||||
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error
|
||||
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
|
||||
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
||||
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -234,9 +237,9 @@ func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) type
|
||||
if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) {
|
||||
log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile)
|
||||
|
||||
// Attempt to migrate from JSON store to SQLite
|
||||
// Attempt to migratePreAuto from JSON store to SQLite
|
||||
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed to migratePreAuto filestore to SQLite: %v", err)
|
||||
kind = types.FileStoreEngine
|
||||
}
|
||||
}
|
||||
@@ -280,9 +283,9 @@ func checkFileStoreEngine(kind types.Engine, dataDir string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrate migrates the SQLite database to the latest schema
|
||||
func migrate(ctx context.Context, db *gorm.DB) error {
|
||||
migrations := getMigrations(ctx)
|
||||
// migratePreAuto migrates the SQLite database to the latest schema
|
||||
func migratePreAuto(ctx context.Context, db *gorm.DB) error {
|
||||
migrations := getMigrationsPreAuto(ctx)
|
||||
|
||||
for _, m := range migrations {
|
||||
if err := m(db); err != nil {
|
||||
@@ -293,7 +296,7 @@ func migrate(ctx context.Context, db *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMigrations(ctx context.Context) []migrationFunc {
|
||||
func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
|
||||
return []migrationFunc{
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net")
|
||||
@@ -329,6 +332,47 @@ func getMigrations(ctx context.Context) []migrationFunc {
|
||||
return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id")
|
||||
},
|
||||
}
|
||||
} // migratePostAuto migrates the SQLite database to the latest schema
|
||||
func migratePostAuto(ctx context.Context, db *gorm.DB) error {
|
||||
migrations := getMigrationsPostAuto(ctx)
|
||||
|
||||
for _, m := range migrations {
|
||||
if err := m(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||
return []migrationFunc{
|
||||
func(db *gorm.DB) error {
|
||||
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_ip", "account_id", "ip")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateJsonToTable[types.Group](ctx, db, "peers", func(id, value string) any {
|
||||
return &types.GroupPeer{
|
||||
GroupID: id,
|
||||
PeerID: value,
|
||||
}
|
||||
})
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateEmbeddedToTable[types.Account, migration.LegacyAccountNetwork, types.Network](ctx, db, "id", func(obj migration.LegacyAccountNetwork) *types.Network {
|
||||
return &types.Network{
|
||||
AccountID: obj.AccountID,
|
||||
Identifier: obj.Identifier,
|
||||
Net: obj.Net,
|
||||
Serial: obj.Serial,
|
||||
Dns: obj.Dns,
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env.
|
||||
@@ -391,7 +435,7 @@ func addAllGroupToAccount(ctx context.Context, store Store) error {
|
||||
|
||||
_, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
if err := account.AddAllGroup(); err != nil {
|
||||
if err := account.AddAllGroup(false); err != nil {
|
||||
return err
|
||||
}
|
||||
shouldSave = true
|
||||
@@ -577,7 +621,7 @@ func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error {
|
||||
|
||||
sqliteStoreAccounts := len(store.GetAllAccounts(ctx))
|
||||
if fsStoreAccounts != sqliteStoreAccounts {
|
||||
return fmt.Errorf("failed to migrate accounts from file to sqlite. Expected accounts: %d, got: %d",
|
||||
return fmt.Errorf("failed to migratePreAuto accounts from file to sqlite. Expected accounts: %d, got: %d",
|
||||
fsStoreAccounts, sqliteStoreAccounts)
|
||||
}
|
||||
|
||||
|
||||
2
management/server/testdata/store.sql
vendored
2
management/server/testdata/store.sql
vendored
@@ -52,4 +52,4 @@ INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','D
|
||||
INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0);
|
||||
INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1');
|
||||
INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network');
|
||||
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','192.168.0.0','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);
|
||||
INSERT INTO peers VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','','','"192.168.0.0"','','','','','','','','','','','','','','','','','test','test','2023-01-01 00:00:00+00:00',0,0,0,'a23efe53-63fb-11ec-90d6-0242ac120003','',0,0,'2023-01-01 00:00:00+00:00','2023-01-01 00:00:00+00:00',0,'','','',0);
|
||||
|
||||
@@ -30,7 +30,7 @@ INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62
|
||||
INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
@@ -67,13 +67,13 @@ type Account struct {
|
||||
IsDomainPrimaryAccount bool
|
||||
SetupKeys map[string]*SetupKey `gorm:"-"`
|
||||
SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
Network *Network `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Peers map[string]*nbpeer.Peer `gorm:"-"`
|
||||
PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Users map[string]*User `gorm:"-"`
|
||||
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Groups map[string]*Group `gorm:"-"`
|
||||
GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
|
||||
Routes map[route.ID]*route.Route `gorm:"-"`
|
||||
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
@@ -1546,7 +1546,7 @@ func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[st
|
||||
}
|
||||
|
||||
// AddAllGroup to account object if it doesn't exist
|
||||
func (a *Account) AddAllGroup() error {
|
||||
func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
||||
if len(a.Groups) == 0 {
|
||||
allGroup := &Group{
|
||||
ID: xid.New().String(),
|
||||
@@ -1558,6 +1558,10 @@ func (a *Account) AddAllGroup() error {
|
||||
}
|
||||
a.Groups = map[string]*Group{allGroup.ID: allGroup}
|
||||
|
||||
if disableDefaultPolicy {
|
||||
return nil
|
||||
}
|
||||
|
||||
id := xid.New().String()
|
||||
|
||||
defaultPolicy := &Policy{
|
||||
|
||||
@@ -53,6 +53,9 @@ type Config struct {
|
||||
StoreConfig StoreConfig
|
||||
|
||||
ReverseProxy ReverseProxy
|
||||
|
||||
// disable default all-to-all policy
|
||||
DisableDefaultPolicy bool
|
||||
}
|
||||
|
||||
// GetAuthAudiences returns the audience from the http config and device authorization flow config
|
||||
|
||||
@@ -26,7 +26,8 @@ type Group struct {
|
||||
Issued string
|
||||
|
||||
// Peers list of the group
|
||||
Peers []string `gorm:"serializer:json"`
|
||||
Peers []string `gorm:"-"`
|
||||
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
|
||||
// Resources contains a list of resources in that group
|
||||
Resources []Resource `gorm:"serializer:json"`
|
||||
@@ -34,6 +35,29 @@ type Group struct {
|
||||
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
|
||||
}
|
||||
|
||||
type GroupPeer struct {
|
||||
GroupID string `gorm:"primaryKey"`
|
||||
PeerID string `gorm:"primaryKey"`
|
||||
}
|
||||
|
||||
func (g *Group) LoadGroupPeers() {
|
||||
g.Peers = make([]string, len(g.GroupPeers))
|
||||
for i, peer := range g.GroupPeers {
|
||||
g.Peers[i] = peer.PeerID
|
||||
}
|
||||
g.GroupPeers = []GroupPeer{}
|
||||
}
|
||||
func (g *Group) StoreGroupPeers() {
|
||||
g.GroupPeers = make([]GroupPeer, len(g.Peers))
|
||||
for i, peer := range g.Peers {
|
||||
g.GroupPeers[i] = GroupPeer{
|
||||
GroupID: g.ID,
|
||||
PeerID: peer,
|
||||
}
|
||||
}
|
||||
g.Peers = []string{}
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to the group
|
||||
func (g *Group) EventMeta() map[string]any {
|
||||
return map[string]any{"name": g.Name}
|
||||
@@ -46,13 +70,16 @@ func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]an
|
||||
func (g *Group) Copy() *Group {
|
||||
group := &Group{
|
||||
ID: g.ID,
|
||||
AccountID: g.AccountID,
|
||||
Name: g.Name,
|
||||
Issued: g.Issued,
|
||||
Peers: make([]string, len(g.Peers)),
|
||||
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
|
||||
Resources: make([]Resource, len(g.Resources)),
|
||||
IntegrationReference: g.IntegrationReference,
|
||||
}
|
||||
copy(group.Peers, g.Peers)
|
||||
copy(group.GroupPeers, g.GroupPeers)
|
||||
copy(group.Resources, g.Resources)
|
||||
return group
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -106,7 +107,8 @@ func ipToBytes(ip net.IP) []byte {
|
||||
}
|
||||
|
||||
type Network struct {
|
||||
Identifier string `json:"id"`
|
||||
AccountID string `gorm:"primaryKey"`
|
||||
Identifier string `gorm:"index"`
|
||||
Net net.IPNet `gorm:"serializer:json"`
|
||||
Dns string
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
@@ -116,9 +118,13 @@ type Network struct {
|
||||
Mu sync.Mutex `json:"-" gorm:"-"`
|
||||
}
|
||||
|
||||
func (*Network) TableName() string {
|
||||
return "account_networks"
|
||||
}
|
||||
|
||||
// NewNetwork creates a new Network initializing it with a Serial=0
|
||||
// It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets)
|
||||
func NewNetwork() *Network {
|
||||
func NewNetwork(accountID string) *Network {
|
||||
|
||||
n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize)
|
||||
sub, _ := n.Subnet(SubnetSize)
|
||||
@@ -128,6 +134,7 @@ func NewNetwork() *Network {
|
||||
intn := r.Intn(len(sub))
|
||||
|
||||
return &Network{
|
||||
AccountID: accountID,
|
||||
Identifier: xid.New().String(),
|
||||
Net: sub[intn].IPNet,
|
||||
Dns: "",
|
||||
@@ -150,6 +157,7 @@ func (n *Network) CurrentSerial() uint64 {
|
||||
|
||||
func (n *Network) Copy() *Network {
|
||||
return &Network{
|
||||
AccountID: n.AccountID,
|
||||
Identifier: n.Identifier,
|
||||
Net: n.Net,
|
||||
Dns: n.Dns,
|
||||
@@ -161,24 +169,65 @@ func (n *Network) Copy() *Network {
|
||||
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps
|
||||
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
|
||||
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
|
||||
takenIPMap := make(map[string]struct{})
|
||||
takenIPMap[ipNet.IP.String()] = struct{}{}
|
||||
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
|
||||
totalIPs := uint32(1 << SubnetSize)
|
||||
|
||||
taken := make(map[uint32]struct{}, len(takenIps)+1)
|
||||
taken[baseIP] = struct{}{} // reserve network IP
|
||||
taken[baseIP+totalIPs-1] = struct{}{} // reserve broadcast IP
|
||||
|
||||
for _, ip := range takenIps {
|
||||
takenIPMap[ip.String()] = struct{}{}
|
||||
taken[ipToUint32(ip)] = struct{}{}
|
||||
}
|
||||
|
||||
ips, _ := generateIPs(&ipNet, takenIPMap)
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
maxAttempts := (int(totalIPs) - len(taken)) / 100
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
offset := uint32(rng.Intn(int(totalIPs-2))) + 1
|
||||
candidate := baseIP + offset
|
||||
if _, exists := taken[candidate]; !exists {
|
||||
return uint32ToIP(candidate), nil
|
||||
}
|
||||
}
|
||||
|
||||
// pick a random IP
|
||||
s := rand.NewSource(time.Now().Unix())
|
||||
r := rand.New(s)
|
||||
intn := r.Intn(len(ips))
|
||||
for offset := uint32(1); offset < totalIPs-1; offset++ {
|
||||
candidate := baseIP + offset
|
||||
if _, exists := taken[candidate]; !exists {
|
||||
return uint32ToIP(candidate), nil
|
||||
}
|
||||
}
|
||||
|
||||
return ips[intn], nil
|
||||
return nil, status.Errorf(status.PreconditionFailed, "network %s is out of IPs", ipNet.String())
|
||||
}
|
||||
|
||||
func AllocateRandomPeerIP(ipNet net.IPNet) (net.IP, error) {
|
||||
baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask))
|
||||
|
||||
ones, bits := ipNet.Mask.Size()
|
||||
hostBits := bits - ones
|
||||
|
||||
totalIPs := uint32(1 << hostBits)
|
||||
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
offset := uint32(rng.Intn(int(totalIPs-2))) + 1
|
||||
|
||||
candidate := baseIP + offset
|
||||
return uint32ToIP(candidate), nil
|
||||
}
|
||||
|
||||
func ipToUint32(ip net.IP) uint32 {
|
||||
ip = ip.To4()
|
||||
if len(ip) < 4 {
|
||||
return 0
|
||||
}
|
||||
return binary.BigEndian.Uint32(ip)
|
||||
}
|
||||
|
||||
func uint32ToIP(n uint32) net.IP {
|
||||
ip := make(net.IP, 4)
|
||||
binary.BigEndian.PutUint32(ip, n)
|
||||
return ip
|
||||
}
|
||||
|
||||
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func TestNewNetwork(t *testing.T) {
|
||||
network := NewNetwork()
|
||||
network := NewNetwork("accountID")
|
||||
|
||||
// generated net should be a subnet of a larger 100.64.0.0/10 net
|
||||
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 192, 0, 0}}
|
||||
|
||||
@@ -35,7 +35,7 @@ type SetupKey struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
Key string
|
||||
KeySecret string
|
||||
KeySecret string `gorm:"index"`
|
||||
Name string
|
||||
Type SetupKeyType
|
||||
CreatedAt time.Time
|
||||
|
||||
@@ -56,7 +56,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = s.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -103,7 +103,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockTargetUserId] = &types.User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: false,
|
||||
@@ -131,7 +131,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockTargetUserId] = &types.User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: true,
|
||||
@@ -163,7 +163,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -188,7 +188,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -213,7 +213,7 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
@@ -256,7 +256,7 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
@@ -296,7 +296,7 @@ func TestUser_GetAllPATs(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
@@ -406,7 +406,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -453,7 +453,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -501,7 +501,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -532,7 +532,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -639,7 +639,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockServiceUserID] = tt.serviceUser
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
@@ -678,7 +678,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -705,7 +705,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &types.User{
|
||||
@@ -792,7 +792,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &types.User{
|
||||
@@ -952,7 +952,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -988,7 +988,7 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users["normal_user1"] = types.NewRegularUser("normal_user1")
|
||||
account.Users["normal_user2"] = types.NewRegularUser("normal_user2")
|
||||
|
||||
@@ -1030,7 +1030,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
externalUser := &types.User{
|
||||
Id: "externalUser",
|
||||
Role: types.UserRoleUser,
|
||||
@@ -1098,7 +1098,7 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockServiceUserID] = &types.User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
@@ -1132,7 +1132,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users[mockServiceUserID] = &types.User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
@@ -1499,7 +1499,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "")
|
||||
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "", false)
|
||||
targetId := "user2"
|
||||
account1.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
@@ -1508,7 +1508,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account1))
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "")
|
||||
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "", false)
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account2))
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
@@ -1535,7 +1535,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "")
|
||||
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "", false)
|
||||
account1.Settings.RegularUsersViewBlocked = false
|
||||
account1.Users["blocked-user"] = &types.User{
|
||||
Id: "blocked-user",
|
||||
@@ -1557,7 +1557,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, store.SaveAccount(context.Background(), account1))
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "")
|
||||
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "", false)
|
||||
account2.Users["settings-blocked-user"] = &types.User{
|
||||
Id: "settings-blocked-user",
|
||||
Role: types.UserRoleUser,
|
||||
|
||||
@@ -130,7 +130,7 @@ repo_gpgcheck=1
|
||||
EOF
|
||||
}
|
||||
|
||||
add_aur_repo() {
|
||||
install_aur_package() {
|
||||
INSTALL_PKGS="git base-devel go"
|
||||
REMOVE_PKGS=""
|
||||
|
||||
@@ -154,8 +154,10 @@ add_aur_repo() {
|
||||
cd netbird-ui && makepkg -sri --noconfirm
|
||||
fi
|
||||
|
||||
# Clean up the installed packages
|
||||
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
|
||||
if [ -n "$REMOVE_PKGS" ]; then
|
||||
# Clean up the installed packages
|
||||
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
|
||||
fi
|
||||
}
|
||||
|
||||
prepare_tun_module() {
|
||||
@@ -277,7 +279,9 @@ install_netbird() {
|
||||
;;
|
||||
pacman)
|
||||
${SUDO} pacman -Syy
|
||||
add_aur_repo
|
||||
install_aur_package
|
||||
# in-line with the docs at https://wiki.archlinux.org/title/Netbird
|
||||
${SUDO} systemctl enable --now netbird@main.service
|
||||
;;
|
||||
pkg)
|
||||
# Check if the package is already installed
|
||||
@@ -494,4 +498,4 @@ case "$UPDATE_FLAG" in
|
||||
;;
|
||||
*)
|
||||
install_netbird
|
||||
esac
|
||||
esac
|
||||
|
||||
Reference in New Issue
Block a user