[management] fix extend call and move config to types (#3575)

This PR fixes configuration inconsistencies and updates the store engine type usage throughout the management code. Key changes include:
- Replacing outdated server.Config references with types.Config and updating related flag variables (e.g. types.MgmtConfigPath).
- Converting engine constants (SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine) to use types.Engine for consistent type–safety.
- Adjusting various test and migration code paths to correctly reference the new configuration and engine types.
This commit is contained in:
Maycon Santos
2025-03-27 13:04:50 +01:00
committed by GitHub
parent fceb3ca392
commit a4f04f5570
25 changed files with 237 additions and 169 deletions

View File

@@ -2793,15 +2793,15 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
})
}
type TB interface {
Cleanup(func())
Helper()
TempDir() string
Errorf(format string, args ...interface{})
Fatalf(format string, args ...interface{})
}
//type TB interface {
// Cleanup(func())
// Helper()
// TempDir() string
// Errorf(format string, args ...interface{})
// Fatalf(format string, args ...interface{})
//}
func createManager(t TB) (*DefaultAccountManager, error) {
func createManager(t testing.TB) (*DefaultAccountManager, error) {
t.Helper()
store, err := createStore(t)
@@ -2836,7 +2836,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
return manager, nil
}
func createStore(t TB) (store.Store, error) {
func createStore(t testing.TB) (store.Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/account"
@@ -40,7 +41,7 @@ type GRPCServer struct {
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager
config *Config
config *types.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
@@ -51,7 +52,7 @@ type GRPCServer struct {
// NewServer creates a new Management server
func NewServer(
ctx context.Context,
config *Config,
config *types.Config,
accountManager account.Manager,
settingsManager settings.Manager,
peersUpdateManager *PeersUpdateManager,
@@ -530,24 +531,24 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
return userID, nil
}
func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
func ToResponseProto(configProto types.Protocol) proto.HostConfig_Protocol {
switch configProto {
case UDP:
case types.UDP:
return proto.HostConfig_UDP
case DTLS:
case types.DTLS:
return proto.HostConfig_DTLS
case HTTP:
case types.HTTP:
return proto.HostConfig_HTTP
case HTTPS:
case types.HTTPS:
return proto.HostConfig_HTTPS
case TCP:
case types.TCP:
return proto.HostConfig_TCP
default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
}
}
func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
func toNetbirdConfig(config *types.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil {
return nil
}
@@ -610,8 +611,6 @@ func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token,
Relay: relayCfg,
}
integrationsConfig.ExtendNetBirdConfig(nbConfig, extraSettings)
return nbConfig
}
@@ -626,10 +625,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dns
}
}
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse {
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse {
response := &proto.SyncResponse{
NetbirdConfig: toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
@@ -638,6 +636,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
integrationsConfig.ExtendNetBirdConfig(peer.ID, nbConfig, extraSettings)
response.NetbirdConfig = nbConfig
response.NetworkMap.PeerConfig = response.PeerConfig
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
@@ -754,7 +756,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
return nil, status.Error(codes.InvalidArgument, errMSG)
}
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(NONE) {
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(types.NONE) {
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
}

View File

@@ -93,21 +93,21 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
func Test_SyncProtocol(t *testing.T) {
dir := t.TempDir()
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*types.Host{{
Proto: "udp",
URI: "stun:stun.netbird.io:3468",
}},
TURNConfig: &TURNConfig{
TURNConfig: &types.TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*Host{{
Turns: []*types.Host{{
Proto: "udp",
URI: "turn:stun.netbird.io:3468",
}},
},
Signal: &Host{
Signal: &types.Host{
Proto: "http",
URI: "signal.netbird.io:10000",
},
@@ -330,7 +330,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testCases := []struct {
name string
inputFlow *DeviceAuthorizationFlow
inputFlow *types.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string
@@ -345,9 +345,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
},
{
name: "Testing Invalid Device Flow Provider Config",
inputFlow: &DeviceAuthorizationFlow{
inputFlow: &types.DeviceAuthorizationFlow{
Provider: "NoNe",
ProviderConfig: ProviderConfig{
ProviderConfig: types.ProviderConfig{
ClientID: "test",
},
},
@@ -356,9 +356,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
},
{
name: "Testing Full Device Flow Config",
inputFlow: &DeviceAuthorizationFlow{
inputFlow: &types.DeviceAuthorizationFlow{
Provider: "hosted",
ProviderConfig: ProviderConfig{
ProviderConfig: types.ProviderConfig{
ClientID: "test",
},
},
@@ -379,7 +379,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &GRPCServer{
wgKey: testingServerKey,
config: &Config{
config: &types.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
@@ -410,7 +410,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}
}
func startManagementForTest(t *testing.T, testFile string, config *Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
func startManagementForTest(t *testing.T, testFile string, config *types.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
@@ -506,21 +506,21 @@ func testSyncStatusRace(t *testing.T) {
t.Skip()
dir := t.TempDir()
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*types.Host{{
Proto: "udp",
URI: "stun:stun.netbird.io:3468",
}},
TURNConfig: &TURNConfig{
TURNConfig: &types.TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*Host{{
Turns: []*types.Host{{
Proto: "udp",
URI: "turn:stun.netbird.io:3468",
}},
},
Signal: &Host{
Signal: &types.Host{
Proto: "http",
URI: "signal.netbird.io:10000",
},
@@ -678,21 +678,21 @@ func Test_LoginPerformance(t *testing.T) {
t.Helper()
dir := t.TempDir()
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*types.Host{{
Proto: "udp",
URI: "stun:stun.netbird.io:3468",
}},
TURNConfig: &TURNConfig{
TURNConfig: &types.TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*Host{{
Turns: []*types.Host{{
Proto: "udp",
URI: "turn:stun.netbird.io:3468",
}},
},
Signal: &Host{
Signal: &types.Host{
Proto: "http",
URI: "signal.netbird.io:10000",
},

View File

@@ -58,7 +58,7 @@ func setupTest(t *testing.T) *testSuite {
t.Fatalf("failed to create temp directory: %v", err)
}
config := &server.Config{}
config := &types.Config{}
_, err = util.ReadJson("testdata/management.json", config)
if err != nil {
t.Fatalf("failed to read management.json: %v", err)
@@ -156,7 +156,7 @@ func createRawClient(t *testing.T, addr string) (mgmtProto.ManagementServiceClie
func startServer(
t *testing.T,
config *server.Config,
config *types.Config,
dataDir string,
testFile string,
) (*grpc.Server, net.Listener) {
@@ -300,6 +300,10 @@ func TestSyncNewPeerConfiguration(t *testing.T) {
Protocol: mgmtProto.HostConfig_UDP,
}
expectedRelayHost := &mgmtProto.RelayConfig{
Urls: []string{"rel://test.com:3535"},
}
assert.NotNil(t, resp.NetbirdConfig)
assert.Equal(t, resp.NetbirdConfig.Signal, expectedSignalConfig)
assert.Contains(t, resp.NetbirdConfig.Stuns, expectedStunsConfig)
@@ -307,6 +311,8 @@ func TestSyncNewPeerConfiguration(t *testing.T) {
actualTURN := resp.NetbirdConfig.Turns[0]
assert.Greater(t, len(actualTURN.User), 0)
assert.Equal(t, actualTURN.HostConfig, expectedTRUNHost)
assert.Equal(t, len(resp.NetbirdConfig.Relay.Urls), 1)
assert.Equal(t, resp.NetbirdConfig.Relay.Urls, expectedRelayHost.Urls)
assert.Equal(t, len(resp.NetworkMap.OfflinePeers), 0)
}

View File

@@ -15,7 +15,6 @@ import (
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
nbversion "github.com/netbirdio/netbird/version"
)
@@ -49,7 +48,7 @@ type properties map[string]interface{}
// DataSource metric data source
type DataSource interface {
GetAllAccounts(ctx context.Context) []*types.Account
GetStoreEngine() store.Engine
GetStoreEngine() types.Engine
}
// ConnManager peer connection manager that holds state for current active connections

View File

@@ -10,7 +10,6 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@@ -205,8 +204,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
}
// GetStoreEngine returns FileStoreEngine
func (mockDatasource) GetStoreEngine() store.Engine {
return store.FileStoreEngine
func (mockDatasource) GetStoreEngine() types.Engine {
return types.FileStoreEngine
}
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
@@ -304,7 +303,7 @@ func TestGenerateProperties(t *testing.T) {
t.Errorf("expected 2 user_peers, got %d", properties["user_peers"])
}
if properties["store_engine"] != store.FileStoreEngine {
if properties["store_engine"] != types.FileStoreEngine {
t.Errorf("expected JsonFile, got %s", properties["store_engine"])
}

View File

@@ -1213,7 +1213,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return
}
update := toSyncResponse(ctx, &Config{}, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting)
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}(peer)
}
@@ -1282,7 +1282,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return
}
update := toSyncResponse(ctx, &Config{}, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}

View File

@@ -723,7 +723,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
}
}
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) {
b.Helper()
manager, err := createManager(b)
@@ -998,6 +998,53 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}
}
func TestUpdateAccountPeers(t *testing.T) {
testCases := []struct {
name string
peers int
groups int
}{
{"Small", 50, 1},
{"Medium", 500, 1},
{"Large", 1000, 1},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
if err != nil {
t.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
t.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
manager.UpdateAccountPeers(ctx, account.Id)
for _, channel := range peerChannels {
update := <-channel
assert.Nil(t, update.Update.NetbirdConfig)
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
}
})
}
}
func TestToSyncResponse(t *testing.T) {
_, ipnet, err := net.ParseCIDR("192.168.1.0/24")
if err != nil {
@@ -1008,16 +1055,16 @@ func TestToSyncResponse(t *testing.T) {
t.Fatal(err)
}
config := &Config{
Signal: &Host{
config := &types.Config{
Signal: &types.Host{
Proto: "https",
URI: "signal.uri",
Username: "",
Password: "",
},
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
TURNConfig: &TURNConfig{
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
Stuns: []*types.Host{{URI: "stun.uri", Proto: types.UDP}},
TURNConfig: &types.TURNConfig{
Turns: []*types.Host{{URI: "turn.uri", Proto: types.UDP, Username: "turn-user", Password: "turn-pass"}},
},
}
peer := &nbpeer.Peer{

View File

@@ -260,6 +260,6 @@ func (s *FileStore) Close(ctx context.Context) error {
}
// GetStoreEngine returns FileStoreEngine
func (s *FileStore) GetStoreEngine() Engine {
return FileStoreEngine
func (s *FileStore) GetStoreEngine() types.Engine {
return types.FileStoreEngine
}

View File

@@ -55,7 +55,7 @@ type SqlStore struct {
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
storeEngine Engine
storeEngine types.Engine
}
type installation struct {
@@ -66,7 +66,7 @@ type installation struct {
type migrationFunc func(*gorm.DB) error
// NewSqlStore creates a new SqlStore instance.
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics telemetry.AppMetrics) (*SqlStore, error) {
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, metrics telemetry.AppMetrics) (*SqlStore, error) {
sql, err := db.DB()
if err != nil {
return nil, err
@@ -77,7 +77,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t
conns = runtime.NumCPU()
}
if storeEngine == SqliteStoreEngine {
if storeEngine == types.SqliteStoreEngine {
if err == nil {
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
}
@@ -105,7 +105,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t
}
func GetKeyQueryCondition(s *SqlStore) string {
if s.storeEngine == MysqlStoreEngine {
if s.storeEngine == types.MysqlStoreEngine {
return mysqlKeyQueryCondition
}
return keyQueryCondition
@@ -970,7 +970,7 @@ func (s *SqlStore) Close(_ context.Context) error {
}
// GetStoreEngine returns underlying store engine
func (s *SqlStore) GetStoreEngine() Engine {
func (s *SqlStore) GetStoreEngine() types.Engine {
return s.storeEngine
}
@@ -988,7 +988,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
return nil, err
}
return NewSqlStore(ctx, db, SqliteStoreEngine, metrics)
return NewSqlStore(ctx, db, types.SqliteStoreEngine, metrics)
}
// NewPostgresqlStore creates a new Postgres store.
@@ -998,7 +998,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
return nil, err
}
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics)
return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics)
}
// NewMysqlStore creates a new MySQL store.
@@ -1008,7 +1008,7 @@ func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics
return nil, err
}
return NewSqlStore(ctx, db, MysqlStoreEngine, metrics)
return NewSqlStore(ctx, db, types.MysqlStoreEngine, metrics)
}
func getGormConfig() *gorm.Config {
@@ -1517,9 +1517,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
switch s.storeEngine {
case PostgresStoreEngine:
case types.PostgresStoreEngine:
query = query.Order("json_array_length(peers::json) DESC")
case MysqlStoreEngine:
case types.MysqlStoreEngine:
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
default:
query = query.Order("json_array_length(peers) DESC")

View File

@@ -297,7 +297,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@@ -628,7 +628,7 @@ func TestMigrate(t *testing.T) {
}
// TODO: figure out why this fails on postgres
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
@@ -737,7 +737,7 @@ func TestPostgresql_NewStore(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@@ -752,7 +752,7 @@ func TestPostgresql_SaveAccount(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@@ -825,7 +825,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@@ -901,7 +901,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@@ -921,7 +921,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@@ -935,7 +935,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
}
func TestSqlite_GetTakenIPs(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
defer cleanup()
if err != nil {
@@ -980,7 +980,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
}
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
if err != nil {
return
@@ -1022,7 +1022,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
}
func TestSqlite_GetAccountNetwork(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@@ -1045,7 +1045,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) {
}
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@@ -1070,7 +1070,7 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
}
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@@ -1106,7 +1106,7 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
}
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@@ -1211,7 +1211,7 @@ func TestSqlite_GetGroupByName(t *testing.T) {
}
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@@ -1227,7 +1227,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
}
func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)

View File

@@ -164,7 +164,7 @@ type Store interface {
Close(ctx context.Context) error
// GetStoreEngine should return Engine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() Engine
GetStoreEngine() types.Engine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
@@ -187,44 +187,37 @@ type Store interface {
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
}
type Engine string
const (
FileStoreEngine Engine = "jsonfile"
SqliteStoreEngine Engine = "sqlite"
PostgresStoreEngine Engine = "postgres"
MysqlStoreEngine Engine = "mysql"
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
)
var supportedEngines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine}
var supportedEngines = []types.Engine{types.SqliteStoreEngine, types.PostgresStoreEngine, types.MysqlStoreEngine}
func getStoreEngineFromEnv() Engine {
func getStoreEngineFromEnv() types.Engine {
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file.
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
if !ok {
return ""
}
value := Engine(strings.ToLower(kind))
value := types.Engine(strings.ToLower(kind))
if slices.Contains(supportedEngines, value) {
return value
}
return SqliteStoreEngine
return types.SqliteStoreEngine
}
// getStoreEngine determines the store engine to use.
// If no engine is specified, it attempts to retrieve it from the environment.
// If still not specified, it defaults to using SQLite.
// Additionally, it handles the migration from a JSON store file to SQLite if applicable.
func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) types.Engine {
if kind == "" {
kind = getStoreEngineFromEnv()
if kind == "" {
kind = SqliteStoreEngine
kind = types.SqliteStoreEngine
// Migrate if it is the first run with a JSON file existing and no SQLite file present
jsonStoreFile := filepath.Join(dataDir, storeFileName)
@@ -236,7 +229,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
// Attempt to migrate from JSON store to SQLite
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err)
kind = FileStoreEngine
kind = types.FileStoreEngine
}
}
}
@@ -246,7 +239,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
}
// NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics
func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
func NewStore(ctx context.Context, kind types.Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
kind = getStoreEngine(ctx, dataDir, kind)
if err := checkFileStoreEngine(kind, dataDir); err != nil {
@@ -254,13 +247,13 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr
}
switch kind {
case SqliteStoreEngine:
case types.SqliteStoreEngine:
log.WithContext(ctx).Info("using SQLite store engine")
return NewSqliteStore(ctx, dataDir, metrics)
case PostgresStoreEngine:
case types.PostgresStoreEngine:
log.WithContext(ctx).Info("using Postgres store engine")
return newPostgresStore(ctx, metrics)
case MysqlStoreEngine:
case types.MysqlStoreEngine:
log.WithContext(ctx).Info("using MySQL store engine")
return newMysqlStore(ctx, metrics)
default:
@@ -268,12 +261,12 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr
}
}
func checkFileStoreEngine(kind Engine, dataDir string) error {
if kind == FileStoreEngine {
func checkFileStoreEngine(kind types.Engine, dataDir string) error {
if kind == types.FileStoreEngine {
storeFile := filepath.Join(dataDir, storeFileName)
if util.FileExists(storeFile) {
return fmt.Errorf("%s is not supported. Please refer to the documentation for migrating to SQLite: "+
"https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", FileStoreEngine)
"https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", types.FileStoreEngine)
}
}
return nil
@@ -326,7 +319,7 @@ func getMigrations(ctx context.Context) []migrationFunc {
func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
kind := getStoreEngineFromEnv()
if kind == "" {
kind = SqliteStoreEngine
kind = types.SqliteStoreEngine
}
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
@@ -348,7 +341,7 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
}
}
store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil)
store, err := NewSqlStore(ctx, db, types.SqliteStoreEngine, nil)
if err != nil {
return nil, nil, fmt.Errorf("failed to create test store: %v", err)
}
@@ -394,13 +387,13 @@ func addAllGroupToAccount(ctx context.Context, store Store) error {
return nil
}
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) {
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine) (Store, func(), error) {
var cleanup func()
var err error
switch kind {
case PostgresStoreEngine:
case types.PostgresStoreEngine:
store, cleanup, err = newReusedPostgresStore(ctx, store, kind)
case MysqlStoreEngine:
case types.MysqlStoreEngine:
store, cleanup, err = newReusedMysqlStore(ctx, store, kind)
default:
cleanup = func() {
@@ -419,7 +412,7 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store
return store, closeConnection, nil
}
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" {
var err error
_, err = testutil.CreatePostgresTestContainer()
@@ -451,7 +444,7 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (
return store, cleanup, nil
}
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" {
var err error
_, err = testutil.CreateMysqlTestContainer()
@@ -483,7 +476,7 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*Sq
return store, cleanup, nil
}
func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) {
func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) {
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil {
@@ -493,9 +486,9 @@ func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), err
var err error
cleanup := func() {
switch engine {
case PostgresStoreEngine:
case types.PostgresStoreEngine:
err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error
case MysqlStoreEngine:
case types.MysqlStoreEngine:
// err = killMySQLConnections(dsn, dbName)
err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
}

View File

@@ -20,6 +20,11 @@
"Secret": "c29tZV9wYXNzd29yZA==",
"TimeBasedCredentials": true
},
"Relay":{
"Addresses":["rel://test.com:3535"],
"CredentialsTTL":"2h",
"Secret":"netbird"
},
"Signal": {
"Proto": "http",
"URI": "signal.netbird.io:10000",

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
@@ -32,8 +33,8 @@ type SecretsManager interface {
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct {
mux sync.Mutex
turnCfg *TURNConfig
relayCfg *Relay
turnCfg *types.TURNConfig
relayCfg *types.Relay
turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager
@@ -44,7 +45,7 @@ type TimeBasedAuthSecretsManager struct {
type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayCfg *Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager {
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager,
turnCfg: turnCfg,
@@ -221,7 +222,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
}
}
m.extendNetbirdConfig(ctx, accountID, update)
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
@@ -245,17 +246,17 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
},
}
m.extendNetbirdConfig(ctx, accountID, update)
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
}
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, accountID string, update *proto.SyncResponse) {
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
extraSettings, err := m.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
}
integrationsConfig.ExtendNetBirdConfig(update.NetbirdConfig, extraSettings)
integrationsConfig.ExtendNetBirdConfig(peerID, update.NetbirdConfig, extraSettings)
}

View File

@@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/util"
)
var TurnTestHost = &Host{
Proto: UDP,
var TurnTestHost = &types.Host{
Proto: types.UDP,
URI: "turn:turn.netbird.io:77777",
Username: "username",
Password: "",
@@ -31,7 +31,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
secret := "some_secret"
peersManager := NewPeersUpdateManager(nil)
rc := &Relay{
rc := &types.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
@@ -41,10 +41,10 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager)
@@ -81,7 +81,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer)
rc := &Relay{
rc := &types.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
@@ -92,10 +92,10 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager)
@@ -184,7 +184,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
peersManager := NewPeersUpdateManager(nil)
peer := "some_peer"
rc := &Relay{
rc := &types.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
@@ -194,10 +194,10 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager)

View File

@@ -1,10 +1,9 @@
package server
package types
import (
"net/netip"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
@@ -30,6 +29,8 @@ const (
DefaultDeviceAuthFlowScope string = "openid"
)
var MgmtConfigPath string
// Config of the Management service
type Config struct {
Stuns []*Host
@@ -76,6 +77,7 @@ type TURNConfig struct {
Turns []*Host
}
// Relay configuration type
type Relay struct {
Addresses []string
CredentialsTTL util.Duration
@@ -156,7 +158,7 @@ type ProviderConfig struct {
// StoreConfig contains Store configuration
type StoreConfig struct {
Engine store.Engine
Engine Engine
}
// ReverseProxy contains reverse proxy configuration in front of management.

View File

@@ -0,0 +1,10 @@
package types
type Engine string
const (
PostgresStoreEngine Engine = "postgres"
FileStoreEngine Engine = "jsonfile"
SqliteStoreEngine Engine = "sqlite"
MysqlStoreEngine Engine = "mysql"
)