From 3ce3ccc39a604af4b2134694c3efcfd6fbc569a1 Mon Sep 17 00:00:00 2001 From: Givi Khojanashvili Date: Sat, 21 May 2022 17:21:39 +0400 Subject: [PATCH] Add rules for ACL (#306) Add rules HTTP endpoint for frontend - CRUD operations. Add Default rule - allow all. Send network map to peers based on rules. --- client/cmd/testutil.go | 5 +- client/internal/engine_test.go | 5 +- management/client/client_test.go | 88 ++++---- management/cmd/management.go | 12 +- management/server/account.go | 82 ++++++- management/server/account_test.go | 15 +- management/server/file_store.go | 153 +++++++++++-- management/server/group.go | 8 + management/server/grpcserver.go | 31 +-- management/server/http/handler/groups.go | 74 +++++- management/server/http/handler/rules.go | 211 ++++++++++++++++++ management/server/http/handler/rules_test.go | 211 ++++++++++++++++++ management/server/http/server.go | 7 + management/server/management_proto_test.go | 28 +-- management/server/management_test.go | 33 +-- management/server/mock_server/account_mock.go | 35 ++- management/server/peer.go | 128 ++++++++--- management/server/peer_test.go | 150 ++++++++++++- management/server/rule.go | 107 +++++++++ management/server/store.go | 3 + management/server/user.go | 1 + 21 files changed, 1197 insertions(+), 190 deletions(-) create mode 100644 management/server/http/handler/rules.go create mode 100644 management/server/http/handler/rules_test.go create mode 100644 management/server/rule.go diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index 9a766b7d1..f4e01782b 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -68,7 +68,10 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste } peersUpdateManager := mgmt.NewPeersUpdateManager() - accountManager := mgmt.NewManager(store, peersUpdateManager, nil) + accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil) + if err != nil { + t.Fatal(err) + } turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager) if err != nil { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 434224bd8..fe43f9f79 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -455,7 +455,10 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } peersUpdateManager := server.NewPeersUpdateManager() - accountManager := server.NewManager(store, peersUpdateManager, nil) + accountManager, err := server.BuildManager(store, peersUpdateManager, nil) + if err != nil { + return nil, err + } turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager) if err != nil { diff --git a/management/client/client_test.go b/management/client/client_test.go index 2e83d8988..8591d7eaf 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -28,7 +28,6 @@ import ( const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" func startManagement(t *testing.T) (*grpc.Server, net.Listener) { - level, _ := log.ParseLevel("debug") log.SetLevel(level) @@ -56,7 +55,10 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { } peersUpdateManager := mgmt.NewPeersUpdateManager() - accountManager := mgmt.NewManager(store, peersUpdateManager, nil) + accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil) + if err != nil { + t.Fatal(err) + } turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager) if err != nil { @@ -256,6 +258,7 @@ func TestClient_Sync(t *testing.T) { } if len(resp.GetRemotePeers()) != 1 { t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers())) + return } if resp.GetRemotePeersIsEmpty() == true { t.Error("expecting RemotePeers property to be false, got true") @@ -295,37 +298,36 @@ func Test_SystemMetaDataFromClient(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - mgmtMockServer.LoginFunc = - func(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { - peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey()) - if err != nil { - log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey) - return nil, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", msg.WgPubKey) - } - - loginReq := &proto.LoginRequest{} - err = encryption.DecryptMessage(peerKey, serverKey, msg.Body, loginReq) - if err != nil { - log.Fatal(err) - } - - actualMeta = loginReq.GetMeta() - actualValidKey = loginReq.GetSetupKey() - wg.Done() - - loginResp := &proto.LoginResponse{} - encryptedResp, err := encryption.EncryptMessage(peerKey, serverKey, loginResp) - if err != nil { - return nil, err - } - - return &mgmtProto.EncryptedMessage{ - WgPubKey: serverKey.PublicKey().String(), - Body: encryptedResp, - Version: 0, - }, nil + mgmtMockServer.LoginFunc = func(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey()) + if err != nil { + log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey) + return nil, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", msg.WgPubKey) } + loginReq := &proto.LoginRequest{} + err = encryption.DecryptMessage(peerKey, serverKey, msg.Body, loginReq) + if err != nil { + log.Fatal(err) + } + + actualMeta = loginReq.GetMeta() + actualValidKey = loginReq.GetSetupKey() + wg.Done() + + loginResp := &proto.LoginResponse{} + encryptedResp, err := encryption.EncryptMessage(peerKey, serverKey, loginResp) + if err != nil { + return nil, err + } + + return &mgmtProto.EncryptedMessage{ + WgPubKey: serverKey.PublicKey().String(), + Body: encryptedResp, + Version: 0, + }, nil + } + info := system.GetInfo() _, err = testClient.Register(*key, ValidKey, "", info) if err != nil { @@ -370,21 +372,19 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) { ProviderConfig: &proto.ProviderConfig{ClientID: "client"}, } - mgmtMockServer.GetDeviceAuthorizationFlowFunc = - func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*proto.EncryptedMessage, error) { - - encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo) - if err != nil { - return nil, err - } - - return &mgmtProto.EncryptedMessage{ - WgPubKey: serverKey.PublicKey().String(), - Body: encryptedResp, - Version: 0, - }, nil + mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*proto.EncryptedMessage, error) { + encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo) + if err != nil { + return nil, err } + return &mgmtProto.EncryptedMessage{ + WgPubKey: serverKey.PublicKey().String(), + Body: encryptedResp, + Version: 0, + }, nil + } + flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey) if err != nil { t.Error("error while retrieving device auth flow information") diff --git a/management/cmd/management.go b/management/cmd/management.go index 7178b65b7..c870cf648 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -108,20 +108,23 @@ var ( } } - accountManager := server.NewManager(store, peersUpdateManager, idpManager) + accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager) + if err != nil { + log.Fatalln("failed build default manager: ", err) + } var opts []grpc.ServerOption var httpServer *http.Server if config.HttpConfig.LetsEncryptDomain != "" { - //automatically generate a new certificate with Let's Encrypt + // automatically generate a new certificate with Let's Encrypt certManager := encryption.CreateCertManager(config.Datadir, config.HttpConfig.LetsEncryptDomain) transportCredentials := credentials.NewTLS(certManager.TLSConfig()) opts = append(opts, grpc.Creds(transportCredentials)) httpServer = http.NewHttpsServer(config.HttpConfig, certManager, accountManager) } else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" { - //use provided certificate + // use provided certificate tlsConfig, err := loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey) if err != nil { log.Fatal("cannot load TLS credentials: ", err) @@ -130,7 +133,7 @@ var ( opts = append(opts, grpc.Creds(transportCredentials)) httpServer = http.NewHttpsServerWithTLSConfig(config.HttpConfig, tlsConfig, accountManager) } else { - //start server without SSL + // start server without SSL httpServer = http.NewHttpServer(config.HttpConfig, accountManager) } @@ -309,5 +312,4 @@ func init() { mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") rootCmd.MarkFlagRequired("config") //nolint - } diff --git a/management/server/account.go b/management/server/account.go index 1e211b919..ba6248175 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -24,7 +24,12 @@ const ( type AccountManager interface { GetOrCreateAccountByUser(userId, domain string) (*Account, error) GetAccountByUser(userId string) (*Account, error) - AddSetupKey(accountId string, keyName string, keyType SetupKeyType, expiresIn *util.Duration) (*SetupKey, error) + AddSetupKey( + accountId string, + keyName string, + keyType SetupKeyType, + expiresIn *util.Duration, + ) (*SetupKey, error) RevokeSetupKey(accountId string, keyId string) (*SetupKey, error) RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error) GetAccountById(accountId string) (*Account, error) @@ -47,6 +52,10 @@ type AccountManager interface { GroupAddPeer(accountId, groupID, peerKey string) error GroupDeletePeer(accountId, groupID, peerKey string) error GroupListPeers(accountId, groupID string) ([]*Peer, error) + GetRule(accountId, ruleID string) (*Rule, error) + SaveRule(accountID string, rule *Rule) error + DeleteRule(accountId, ruleID string) error + ListRules(accountId string) ([]*Rule, error) } type DefaultAccountManager struct { @@ -70,6 +79,7 @@ type Account struct { Peers map[string]*Peer Users map[string]*User Groups map[string]*Group + Rules map[string]*Rule } type UserInfo struct { @@ -101,6 +111,16 @@ func (a *Account) Copy() *Account { setupKeys[id] = key.Copy() } + groups := map[string]*Group{} + for id, group := range a.Groups { + groups[id] = group.Copy() + } + + rules := map[string]*Rule{} + for id, rule := range a.Rules { + rules[id] = rule.Copy() + } + return &Account{ Id: a.Id, CreatedBy: a.CreatedBy, @@ -108,17 +128,43 @@ func (a *Account) Copy() *Account { Network: a.Network.Copy(), Peers: peers, Users: users, + Groups: groups, + Rules: rules, } } -// NewManager creates a new DefaultAccountManager with a provided Store -func NewManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager) *DefaultAccountManager { - return &DefaultAccountManager{ +func (a *Account) GetGroupAll() (*Group, error) { + for _, g := range a.Groups { + if g.Name == "All" { + return g, nil + } + } + return nil, fmt.Errorf("no group ALL found") +} + +// BuildManager creates a new DefaultAccountManager with a provided Store +func BuildManager( + store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, +) (*DefaultAccountManager, error) { + dam := &DefaultAccountManager{ Store: store, mux: sync.Mutex{}, peersUpdateManager: peersUpdateManager, idpManager: idpManager, } + + // if account has not default account + // we build 'all' group and add all peers into it + // also we create default rule with source an destination + // groups 'all' + for _, account := range store.GetAllAccounts() { + dam.addAllGroup(account) + if err := store.SaveAccount(account); err != nil { + return nil, err + } + } + + return dam, nil } // AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account @@ -223,7 +269,9 @@ func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, err // GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and // user id doesn't have an account associated with it, one account is created -func (am *DefaultAccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) { +func (am *DefaultAccountManager) GetAccountByUserOrAccountId( + userId, accountId, domain string, +) (*Account, error) { if accountId != "" { return am.GetAccountById(accountId) } else if userId != "" { @@ -490,6 +538,8 @@ func (am *DefaultAccountManager) AddAccount(accountId, userId, domain string) (* func (am *DefaultAccountManager) createAccount(accountId, userId, domain string) (*Account, error) { account := newAccountWithId(accountId, userId, domain) + am.addAllGroup(account) + err := am.Store.SaveAccount(account) if err != nil { return nil, status.Errorf(codes.Internal, "failed creating account") @@ -498,6 +548,28 @@ func (am *DefaultAccountManager) createAccount(accountId, userId, domain string) return account, nil } +// addAllGroup to account object it it doesn't exists +func (am *DefaultAccountManager) addAllGroup(account *Account) { + if len(account.Groups) == 0 { + allGroup := &Group{ + ID: xid.New().String(), + Name: "All", + } + for _, peer := range account.Peers { + allGroup.Peers = append(allGroup.Peers, peer.Key) + } + account.Groups = map[string]*Group{allGroup.ID: allGroup} + + defaultRule := &Rule{ + ID: xid.New().String(), + Name: "Default", + Source: []string{allGroup.ID}, + Destination: []string{allGroup.ID}, + } + account.Rules = map[string]*Rule{defaultRule.ID: defaultRule} + } +} + // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id func newAccountWithId(accountId, userId, domain string) *Account { log.Debugf("creating new account") diff --git a/management/server/account_test.go b/management/server/account_test.go index e4da9d5ac..d63391a37 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -37,7 +37,6 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { } func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { - type initUserParams jwtclaims.AuthorizationClaims type test struct { @@ -165,7 +164,6 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { } for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} { t.Run(testCase.name, func(t *testing.T) { - manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") @@ -346,7 +344,6 @@ func TestAccountManager_AccountExists(t *testing.T) { if !*exists { t.Errorf("expected account to exist after creation, got false") } - } func TestAccountManager_GetAccount(t *testing.T) { @@ -363,7 +360,7 @@ func TestAccountManager_GetAccount(t *testing.T) { t.Fatal(err) } - //AddAccount has been already tested so we can assume it is correct and compare results + // AddAccount has been already tested so we can assume it is correct and compare results getAccount, err := manager.GetAccountById(expectedId) if err != nil { t.Fatal(err) @@ -385,7 +382,6 @@ func TestAccountManager_GetAccount(t *testing.T) { t.Errorf("expected account to have setup key %s, not found", key.Key) } } - } func TestAccountManager_AddPeer(t *testing.T) { @@ -400,7 +396,7 @@ func TestAccountManager_AddPeer(t *testing.T) { t.Fatal(err) } - serial := account.Network.CurrentSerial() //should be 0 + serial := account.Network.CurrentSerial() // should be 0 var setupKey *SetupKey for _, key := range account.SetupKeys { @@ -457,7 +453,6 @@ func TestAccountManager_AddPeer(t *testing.T) { if account.Network.CurrentSerial() != 1 { t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial()) } - } func TestAccountManager_AddPeerWithUserID(t *testing.T) { @@ -474,7 +469,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { t.Fatal(err) } - serial := account.Network.CurrentSerial() //should be 0 + serial := account.Network.CurrentSerial() // should be 0 if account.Network.Serial != 0 { t.Errorf("expecting account network to have an initial Serial=0") @@ -521,7 +516,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { if account.Network.CurrentSerial() != 1 { t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial()) } - } func TestAccountManager_DeletePeer(t *testing.T) { @@ -573,7 +567,6 @@ func TestAccountManager_DeletePeer(t *testing.T) { if account.Network.CurrentSerial() != 2 { t.Errorf("expecting Network Serial=%d to be incremented and be equal to 2 after adding and deleteing a peer", account.Network.CurrentSerial()) } - } func TestGetUsersFromAccount(t *testing.T) { @@ -614,7 +607,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { if err != nil { return nil, err } - return NewManager(store, NewPeersUpdateManager(), nil), nil + return BuildManager(store, NewPeersUpdateManager(), nil) } func createStore(t *testing.T) (Store, error) { diff --git a/management/server/file_store.go b/management/server/file_store.go index 4f89a32ca..4648098b1 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -1,6 +1,7 @@ package server import ( + "fmt" "os" "path/filepath" "strings" @@ -18,10 +19,12 @@ const storeFileName = "store.json" // FileStore represents an account storage backed by a file persisted to disk type FileStore struct { Accounts map[string]*Account - SetupKeyId2AccountId map[string]string `json:"-"` - PeerKeyId2AccountId map[string]string `json:"-"` - UserId2AccountId map[string]string `json:"-"` - PrivateDomain2AccountId map[string]string `json:"-"` + SetupKeyId2AccountId map[string]string `json:"-"` + PeerKeyId2AccountId map[string]string `json:"-"` + UserId2AccountId map[string]string `json:"-"` + PrivateDomain2AccountId map[string]string `json:"-"` + PeerKeyId2SrcRulesId map[string]map[string]struct{} `json:"-"` + PeerKeyId2DstRulesId map[string]map[string]struct{} `json:"-"` // mutex to synchronise Store read/write operations mux sync.Mutex `json:"-"` @@ -47,6 +50,8 @@ func restore(file string) (*FileStore, error) { PeerKeyId2AccountId: make(map[string]string), UserId2AccountId: make(map[string]string), PrivateDomain2AccountId: make(map[string]string), + PeerKeyId2SrcRulesId: make(map[string]map[string]struct{}), + PeerKeyId2DstRulesId: make(map[string]map[string]struct{}), storeFile: file, } @@ -69,10 +74,39 @@ func restore(file string) (*FileStore, error) { store.PeerKeyId2AccountId = make(map[string]string) store.UserId2AccountId = make(map[string]string) store.PrivateDomain2AccountId = make(map[string]string) + store.PeerKeyId2SrcRulesId = map[string]map[string]struct{}{} + store.PeerKeyId2DstRulesId = map[string]map[string]struct{}{} + for accountId, account := range store.Accounts { for setupKeyId := range account.SetupKeys { store.SetupKeyId2AccountId[strings.ToUpper(setupKeyId)] = accountId } + for _, rule := range account.Rules { + for _, groupID := range rule.Source { + if group, ok := account.Groups[groupID]; ok { + for _, peerID := range group.Peers { + rules := store.PeerKeyId2SrcRulesId[peerID] + if rules == nil { + rules = map[string]struct{}{} + store.PeerKeyId2SrcRulesId[peerID] = rules + } + rules[rule.ID] = struct{}{} + } + } + } + for _, groupID := range rule.Destination { + if group, ok := account.Groups[groupID]; ok { + for _, peerID := range group.Peers { + rules := store.PeerKeyId2DstRulesId[peerID] + if rules == nil { + rules = map[string]struct{}{} + store.PeerKeyId2DstRulesId[peerID] = rules + } + rules[rule.ID] = struct{}{} + } + } + } + } for _, peer := range account.Peers { store.PeerKeyId2AccountId[peer.Key] = accountId } @@ -82,7 +116,8 @@ func restore(file string) (*FileStore, error) { for _, user := range account.Users { store.UserId2AccountId[user.Id] = accountId } - if account.Domain != "" && account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount { + if account.Domain != "" && account.DomainCategory == PrivateCategory && + account.IsDomainPrimaryAccount { store.PrivateDomain2AccountId[account.Domain] = accountId } } @@ -106,6 +141,24 @@ func (s *FileStore) SavePeer(accountId string, peer *Peer) error { return err } + // if it is new peer, add it to default 'All' group + allGroup, err := account.GetGroupAll() + if err != nil { + return err + } + + ind := -1 + for i, pid := range allGroup.Peers { + if pid == peer.Key { + ind = i + break + } + } + + if ind < 0 { + allGroup.Peers = append(allGroup.Peers, peer.Key) + } + account.Peers[peer.Key] = peer return s.persist(s.storeFile) } @@ -176,6 +229,29 @@ func (s *FileStore) SaveAccount(account *Account) error { s.PeerKeyId2AccountId[peer.Key] = account.Id } + for _, rule := range account.Rules { + for _, gid := range rule.Source { + for _, pid := range account.Groups[gid].Peers { + rules := s.PeerKeyId2SrcRulesId[pid] + if rules == nil { + rules = map[string]struct{}{} + s.PeerKeyId2SrcRulesId[pid] = rules + } + rules[rule.ID] = struct{}{} + } + } + for _, gid := range rule.Destination { + for _, pid := range account.Groups[gid].Peers { + rules := s.PeerKeyId2DstRulesId[pid] + if rules == nil { + rules = map[string]struct{}{} + s.PeerKeyId2DstRulesId[pid] = rules + } + rules[rule.ID] = struct{}{} + } + } + } + for _, user := range account.Users { s.UserId2AccountId[user.Id] = account.Id } @@ -190,7 +266,10 @@ func (s *FileStore) SaveAccount(account *Account) error { func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)] if !accountIdFound { - return nil, status.Errorf(codes.NotFound, "provided domain is not registered or is not private") + return nil, status.Errorf( + codes.NotFound, + "provided domain is not registered or is not private", + ) } account, err := s.GetAccount(accountId) @@ -232,6 +311,14 @@ func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) { return peers, nil } +func (s *FileStore) GetAllAccounts() (all []*Account) { + for _, a := range s.Accounts { + all = append(all, a) + } + + return all +} + func (s *FileStore) GetAccount(accountId string) (*Account, error) { account, accountFound := s.Accounts[accountId] if !accountFound { @@ -265,18 +352,52 @@ func (s *FileStore) GetPeerAccount(peerKey string) (*Account, error) { return s.GetAccount(accountId) } -func (s *FileStore) GetGroup(groupID string) (*Group, error) { - return nil, nil +func (s *FileStore) GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.GetAccount(accountId) + if err != nil { + return nil, err + } + + ruleIDs, ok := s.PeerKeyId2SrcRulesId[peerKey] + if !ok { + return nil, fmt.Errorf("no rules for peer: %v", ruleIDs) + } + + rules := []*Rule{} + for id := range ruleIDs { + rule, ok := account.Rules[id] + if ok { + rules = append(rules, rule) + } + } + + return rules, nil } -func (s *FileStore) SaveGroup(group *Group) error { - return nil -} +func (s *FileStore) GetPeerDstRules(accountId, peerKey string) ([]*Rule, error) { + s.mux.Lock() + defer s.mux.Unlock() -func (s *FileStore) DeleteGroup(groupID string) error { - return nil -} + account, err := s.GetAccount(accountId) + if err != nil { + return nil, err + } -func (s *FileStore) ListGroups() ([]*Group, error) { - return nil, nil + ruleIDs, ok := s.PeerKeyId2DstRulesId[peerKey] + if !ok { + return nil, fmt.Errorf("no rules for peer: %v", ruleIDs) + } + + rules := []*Rule{} + for id := range ruleIDs { + rule, ok := account.Rules[id] + if ok { + rules = append(rules, rule) + } + } + + return rules, nil } diff --git a/management/server/group.go b/management/server/group.go index 807acbd4a..e71f33eb3 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -17,6 +17,14 @@ type Group struct { Peers []string } +func (g *Group) Copy() *Group { + return &Group{ + ID: g.ID, + Name: g.Name, + Peers: g.Peers[:], + } +} + // GetGroup object of the peers func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) { am.mux.Lock() diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index d6c00d2cf..424599c50 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -3,11 +3,12 @@ package server import ( "context" "fmt" - "github.com/netbirdio/netbird/management/server/http/middleware" - "github.com/netbirdio/netbird/management/server/jwtclaims" "strings" "time" + "github.com/netbirdio/netbird/management/server/http/middleware" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/golang/protobuf/ptypes/timestamp" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" @@ -64,7 +65,6 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager } func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { - // todo introduce something more meaningful with the key expiration/rotation now := time.Now().Add(24 * time.Hour) secs := int64(now.Second()) @@ -77,10 +77,9 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser }, nil } -//Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and +// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { - log.Debugf("Sync request from peer %s", req.WgPubKey) peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) @@ -155,7 +154,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S } func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) { - var ( reqSetupKey string userId string @@ -209,7 +207,7 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe return nil, status.Errorf(codes.NotFound, "provided setup key doesn't exists") } - //todo move to DefaultAccountManager the code below + // todo move to DefaultAccountManager the code below networkMap, err := s.accountManager.GetNetworkMap(peer.Key) if err != nil { return nil, status.Errorf(codes.Internal, "unable to fetch network map after registering peer, error: %v", err) @@ -240,7 +238,6 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe // In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. // In case of the successful registration login is also successful func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { - log.Debugf("Login request from peer %s", req.WgPubKey) peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) @@ -252,18 +249,18 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto peer, err := s.accountManager.GetPeer(peerKey.String()) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.NotFound { - //peer doesn't exist -> check if setup key was provided + // peer doesn't exist -> check if setup key was provided loginReq := &proto.LoginRequest{} err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, loginReq) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid request message") } if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" { - //absent setup key -> permission denied + // absent setup key -> permission denied return nil, status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered and no setup key or jwt was provided", peerKey.String()) } - //setup key or jwt is present -> try normal registration flow + // setup key or jwt is present -> try normal registration flow peer, err = s.registerPeer(peerKey, loginReq) if err != nil { return nil, err @@ -303,13 +300,12 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol { case TCP: return proto.HostConfig_TCP default: - //mbragin: todo something better? + // mbragin: todo something better? panic(fmt.Errorf("unexpected config protocol type %v", configProto)) } } func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig { - var stuns []*proto.HostConfig for _, stun := range config.Stuns { stuns = append(stuns, &proto.HostConfig{ @@ -350,26 +346,23 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot func toPeerConfig(peer *Peer) *proto.PeerConfig { return &proto.PeerConfig{ - Address: peer.IP.String() + "/24", //todo make it explicit + Address: peer.IP.String() + "/24", // todo make it explicit } } func toRemotePeerConfig(peers []*Peer) []*proto.RemotePeerConfig { - remotePeers := []*proto.RemotePeerConfig{} for _, rPeer := range peers { remotePeers = append(remotePeers, &proto.RemotePeerConfig{ WgPubKey: rPeer.Key, - AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)}, //todo /32 + AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)}, // todo /32 }) } return remotePeers - } func toSyncResponse(config *Config, peer *Peer, peers []*Peer, turnCredentials *TURNCredentials, serial uint64) *proto.SyncResponse { - wtConfig := toWiretrusteeConfig(config, turnCredentials) pConfig := toPeerConfig(peer) @@ -397,7 +390,6 @@ func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.ManagementService_SyncServer) error { - networkMap, err := s.accountManager.GetNetworkMap(peer.Key) if err != nil { log.Warnf("error getting a list of peers for a peer %s", peer.Key) @@ -436,7 +428,6 @@ func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.Mana // This is used for initiating an Oauth 2 device authorization grant flow // which will be used by our clients to Login func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { - peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey) diff --git a/management/server/http/handler/groups.go b/management/server/http/handler/groups.go index 4bab4e7a8..cf179a522 100644 --- a/management/server/http/handler/groups.go +++ b/management/server/http/handler/groups.go @@ -13,20 +13,33 @@ import ( log "github.com/sirupsen/logrus" ) -// Groups is a handler that returns groups of the account -type Groups struct { - accountManager server.AccountManager - authAudience string - jwtExtractor jwtclaims.ClaimsExtractor -} - // GroupResponse is a response sent to the client type GroupResponse struct { + ID string + Name string + Peers []GroupPeerResponse `json:",omitempty"` +} + +// GroupPeerResponse is a response sent to the client +type GroupPeerResponse struct { + Key string + Name string +} + +// GroupRequest to create or update group +type GroupRequest struct { ID string Name string Peers []string } +// Groups is a handler that returns groups of the account +type Groups struct { + jwtExtractor jwtclaims.ClaimsExtractor + accountManager server.AccountManager + authAudience string +} + func NewGroups(accountManager server.AccountManager, authAudience string) *Groups { return &Groups{ accountManager: accountManager, @@ -44,7 +57,12 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) { return } - writeJSONObject(w, account.Groups) + var groups []*GroupResponse + for _, g := range account.Groups { + groups = append(groups, toGroupResponse(account, g)) + } + + writeJSONObject(w, groups) } func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Request) { @@ -54,7 +72,7 @@ func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Reque return } - var req server.Group + var req GroupRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -64,13 +82,19 @@ func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Reque req.ID = xid.New().String() } - if err := h.accountManager.SaveGroup(account.Id, &req); err != nil { + group := server.Group{ + ID: req.ID, + Name: req.Name, + Peers: req.Peers, + } + + if err := h.accountManager.SaveGroup(account.Id, &group); err != nil { log.Errorf("failed updating group %s under account %s %v", req.ID, account.Id, err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - writeJSONObject(w, &req) + writeJSONObject(w, toGroupResponse(account, &group)) } func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) { @@ -117,7 +141,7 @@ func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) { return } - writeJSONObject(w, group) + writeJSONObject(w, toGroupResponse(account, group)) default: http.Error(w, "", http.StatusNotFound) } @@ -133,3 +157,29 @@ func (h *Groups) getGroupAccount(r *http.Request) (*server.Account, error) { return account, nil } + +func toGroupResponse(account *server.Account, group *server.Group) *GroupResponse { + cache := make(map[string]GroupPeerResponse) + gr := GroupResponse{ + ID: group.ID, + Name: group.Name, + } + + for _, pid := range group.Peers { + peerResp, ok := cache[pid] + if !ok { + peer, ok := account.Peers[pid] + if !ok { + continue + } + peerResp = GroupPeerResponse{ + Key: peer.Key, + Name: peer.Name, + } + cache[pid] = peerResp + } + gr.Peers = append(gr.Peers, peerResp) + } + + return &gr +} diff --git a/management/server/http/handler/rules.go b/management/server/http/handler/rules.go new file mode 100644 index 000000000..d51ea8ca4 --- /dev/null +++ b/management/server/http/handler/rules.go @@ -0,0 +1,211 @@ +package handler + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/rs/xid" + + "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" +) + +const FlowBidirectString = "bidirect" + +// RuleResponse is a response sent to the client +type RuleResponse struct { + ID string + Name string + Source []RuleGroupResponse + Destination []RuleGroupResponse + Flow string +} + +// RuleGroupResponse is a response sent to the client +type RuleGroupResponse struct { + ID string + Name string + PeersCount int +} + +// RuleRequest to create or update rule +type RuleRequest struct { + ID string + Name string + Source []string + Destination []string + Flow string +} + +// Rules is a handler that returns rules of the account +type Rules struct { + jwtExtractor jwtclaims.ClaimsExtractor + accountManager server.AccountManager + authAudience string +} + +func NewRules(accountManager server.AccountManager, authAudience string) *Rules { + return &Rules{ + accountManager: accountManager, + authAudience: authAudience, + jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + } +} + +// GetAllRulesHandler list for the account +func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) { + account, err := h.getRuleAccount(r) + if err != nil { + log.Error(err) + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + + var rules []*RuleResponse + for _, r := range account.Rules { + rules = append(rules, toRuleResponse(account, r)) + } + + writeJSONObject(w, rules) +} + +func (h *Rules) CreateOrUpdateRuleHandler(w http.ResponseWriter, r *http.Request) { + account, err := h.getRuleAccount(r) + if err != nil { + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + + var req RuleRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if r.Method == http.MethodPost { + req.ID = xid.New().String() + } + + rule := server.Rule{ + ID: req.ID, + Name: req.Name, + Source: req.Source, + Destination: req.Destination, + } + + switch req.Flow { + case FlowBidirectString: + rule.Flow = server.TrafficFlowBidirect + default: + http.Error(w, "unknown flow type", http.StatusBadRequest) + return + } + + if err := h.accountManager.SaveRule(account.Id, &rule); err != nil { + log.Errorf("failed updating rule %s under account %s %v", req.ID, account.Id, err) + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + + writeJSONObject(w, &req) +} + +func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) { + account, err := h.getRuleAccount(r) + if err != nil { + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + aID := account.Id + + rID := mux.Vars(r)["id"] + if len(rID) == 0 { + http.Error(w, "invalid rule ID", http.StatusBadRequest) + return + } + + if err := h.accountManager.DeleteRule(aID, rID); err != nil { + log.Errorf("failed delete rule %s under account %s %v", rID, aID, err) + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + + writeJSONObject(w, "") +} + +func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) { + account, err := h.getRuleAccount(r) + if err != nil { + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + + switch r.Method { + case http.MethodGet: + ruleID := mux.Vars(r)["id"] + if len(ruleID) == 0 { + http.Error(w, "invalid rule ID", http.StatusBadRequest) + return + } + + rule, err := h.accountManager.GetRule(account.Id, ruleID) + if err != nil { + http.Error(w, "rule not found", http.StatusNotFound) + return + } + + writeJSONObject(w, toRuleResponse(account, rule)) + default: + http.Error(w, "", http.StatusNotFound) + } +} + +func (h *Rules) getRuleAccount(r *http.Request) (*server.Account, error) { + jwtClaims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + + account, err := h.accountManager.GetAccountWithAuthorizationClaims(jwtClaims) + if err != nil { + return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err) + } + + return account, nil +} + +func toRuleResponse(account *server.Account, rule *server.Rule) *RuleResponse { + gr := RuleResponse{ + ID: rule.ID, + Name: rule.Name, + } + + switch rule.Flow { + case server.TrafficFlowBidirect: + gr.Flow = FlowBidirectString + default: + gr.Flow = "unknown" + } + + for _, gid := range rule.Source { + if group, ok := account.Groups[gid]; ok { + gr.Source = append(gr.Source, RuleGroupResponse{ + ID: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + }) + } + } + + for _, gid := range rule.Destination { + if group, ok := account.Groups[gid]; ok { + gr.Destination = append(gr.Destination, RuleGroupResponse{ + ID: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + }) + } + } + + return &gr +} diff --git a/management/server/http/handler/rules_test.go b/management/server/http/handler/rules_test.go new file mode 100644 index 000000000..958beb4c6 --- /dev/null +++ b/management/server/http/handler/rules_test.go @@ -0,0 +1,211 @@ +package handler + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server/jwtclaims" + + "github.com/magiconair/properties/assert" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/mock_server" +) + +func initRulesTestData(rules ...*server.Rule) *Rules { + return &Rules{ + accountManager: &mock_server.MockAccountManager{ + SaveRuleFunc: func(_ string, rule *server.Rule) error { + if !strings.HasPrefix(rule.ID, "id-") { + rule.ID = "id-was-set" + } + return nil + }, + GetRuleFunc: func(_, ruleID string) (*server.Rule, error) { + if ruleID != "idoftherule" { + return nil, fmt.Errorf("not found") + } + return &server.Rule{ + ID: "idoftherule", + Name: "Rule", + Source: []string{"idofsrcrule"}, + Destination: []string{"idofdestrule"}, + Flow: server.TrafficFlowBidirect, + }, nil + }, + GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + return &server.Account{ + Id: claims.AccountId, + Domain: "hotmail.com", + }, nil + }, + }, + authAudience: "", + jwtExtractor: jwtclaims.ClaimsExtractor{ + ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + return jwtclaims.AuthorizationClaims{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + } + }, + }, + } +} + +func TestRulesGetRule(t *testing.T) { + tt := []struct { + name string + expectedStatus int + expectedBody bool + requestType string + requestPath string + requestBody io.Reader + }{ + { + name: "GetRule OK", + expectedBody: true, + requestType: http.MethodGet, + requestPath: "/api/rules/idoftherule", + expectedStatus: http.StatusOK, + }, + { + name: "GetRule not found", + requestType: http.MethodGet, + requestPath: "/api/rules/notexists", + expectedStatus: http.StatusNotFound, + }, + } + + rule := &server.Rule{ + ID: "idoftherule", + Name: "Rule", + } + + p := initRulesTestData(rule) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + + router := mux.NewRouter() + router.HandleFunc("/api/rules/{id}", p.GetRuleHandler).Methods("GET") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + if status := recorder.Code; status != tc.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tc.expectedStatus) + return + } + + if !tc.expectedBody { + return + } + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("I don't know what I expected; %v", err) + } + + var got RuleResponse + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, got.ID, rule.ID) + assert.Equal(t, got.Name, rule.Name) + }) + } +} + +func TestRulesSaveRule(t *testing.T) { + tt := []struct { + name string + expectedStatus int + expectedBody bool + expectedRule *server.Rule + requestType string + requestPath string + requestBody io.Reader + }{ + { + name: "SaveRule POST OK", + requestType: http.MethodPost, + requestPath: "/api/rules", + requestBody: bytes.NewBuffer( + []byte(`{"Name":"Default POSTed Rule","Flow":"bidirect"}`)), + expectedStatus: http.StatusOK, + expectedBody: true, + expectedRule: &server.Rule{ + ID: "id-was-set", + Name: "Default POSTed Rule", + Flow: server.TrafficFlowBidirect, + }, + }, + { + name: "SaveRule PUT OK", + requestType: http.MethodPut, + requestPath: "/api/rules", + requestBody: bytes.NewBuffer( + []byte(`{"ID":"id-existed","Name":"Default POSTed Rule","Flow":"bidirect"}`)), + expectedStatus: http.StatusOK, + expectedRule: &server.Rule{ + ID: "id-existed", + Name: "Default POSTed Rule", + Flow: server.TrafficFlowBidirect, + }, + }, + } + + p := initRulesTestData() + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + + router := mux.NewRouter() + router.HandleFunc("/api/rules", p.CreateOrUpdateRuleHandler).Methods("PUT", "POST") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("I don't know what I expected; %v", err) + } + + if status := recorder.Code; status != tc.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v, content: %s", + status, tc.expectedStatus, string(content)) + return + } + + if !tc.expectedBody { + return + } + + got := &RuleRequest{} + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + if tc.requestType != http.MethodPost { + assert.Equal(t, got.ID, tc.expectedRule.ID) + } + assert.Equal(t, got.Name, tc.expectedRule.Name) + assert.Equal(t, got.Flow, "bidirect") + }) + } +} diff --git a/management/server/http/server.go b/management/server/http/server.go index 70e8c392a..f8c9b6d33 100644 --- a/management/server/http/server.go +++ b/management/server/http/server.go @@ -96,6 +96,7 @@ func (s *Server) Start() error { r.Use(jwtMiddleware.Handler, corsMiddleware.Handler) groupsHandler := handler.NewGroups(s.accountManager, s.config.AuthAudience) + rulesHandler := handler.NewRules(s.accountManager, s.config.AuthAudience) peersHandler := handler.NewPeers(s.accountManager, s.config.AuthAudience) keysHandler := handler.NewSetupKeysHandler(s.accountManager, s.config.AuthAudience) r.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS") @@ -112,6 +113,12 @@ func (s *Server) Start() error { r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey). Methods("GET", "PUT", "DELETE", "OPTIONS") + r.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS") + r.HandleFunc("/api/rules", rulesHandler.CreateOrUpdateRuleHandler). + Methods("POST", "PUT", "OPTIONS") + r.HandleFunc("/api/rules/{id}", rulesHandler.GetRuleHandler).Methods("GET", "OPTIONS") + r.HandleFunc("/api/rules/{id}", rulesHandler.DeleteRuleHandler).Methods("DELETE", "OPTIONS") + r.HandleFunc("/api/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS") r.HandleFunc("/api/groups", groupsHandler.CreateOrUpdateGroupHandler). Methods("POST", "PUT", "OPTIONS") diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 03361aad1..e2850a665 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -3,6 +3,13 @@ package server import ( "context" "fmt" + "net" + "os" + "path/filepath" + "runtime" + "testing" + "time" + "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/util" @@ -11,12 +18,6 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" - "net" - "os" - "path/filepath" - "runtime" - "testing" - "time" ) var ( @@ -39,8 +40,7 @@ const ( // registerPeers registers peersNum peers on the management service and returns their Wireguard keys func registerPeers(peersNum int, client mgmtProto.ManagementServiceClient) ([]*wgtypes.Key, error) { - - var peers = []*wgtypes.Key{} + peers := []*wgtypes.Key{} for i := 0; i < peersNum; i++ { key, err := wgtypes.GeneratePrivateKey() if err != nil { @@ -60,7 +60,6 @@ func registerPeers(peersNum int, client mgmtProto.ManagementServiceClient) ([]*w // getServerKey gets Management Service Wireguard public key func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error) { - keyResp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{}) if err != nil { return nil, err @@ -75,7 +74,6 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error } func Test_SyncProtocol(t *testing.T) { - dir := t.TempDir() err := util.CopyFileContents("testdata/store.json", filepath.Join(dir, "store.json")) if err != nil { @@ -263,7 +261,6 @@ func Test_SyncProtocol(t *testing.T) { } func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServiceClient) (*mgmtProto.LoginResponse, error) { - serverKey, err := getServerKey(client) if err != nil { return nil, err @@ -298,11 +295,9 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ } return loginResp, nil - } func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { - testingServerKey, err := wgtypes.GeneratePrivateKey() if err != nil { t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err) @@ -362,7 +357,6 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - mgmtServer := &Server{ wgKey: testingServerKey, config: &Config{ @@ -397,7 +391,6 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { } func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, error) { - lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) if err != nil { return nil, err @@ -408,7 +401,10 @@ func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, erro return nil, err } peersUpdateManager := NewPeersUpdateManager() - accountManager := NewManager(store, peersUpdateManager, nil) + accountManager, err := BuildManager(store, peersUpdateManager, nil) + if err != nil { + return nil, err + } turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager) if err != nil { diff --git a/management/server/management_test.go b/management/server/management_test.go index 379ccd559..f13e10fd0 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -2,8 +2,6 @@ package server_test import ( "context" - server "github.com/netbirdio/netbird/management/server" - "google.golang.org/grpc/credentials/insecure" "io/ioutil" "math/rand" "net" @@ -13,6 +11,9 @@ import ( sync2 "sync" "time" + server "github.com/netbirdio/netbird/management/server" + "google.golang.org/grpc/credentials/insecure" + pb "github.com/golang/protobuf/proto" //nolint "github.com/netbirdio/netbird/encryption" log "github.com/sirupsen/logrus" @@ -31,7 +32,6 @@ const ( ) var _ = Describe("Management service", func() { - var ( addr string s *grpc.Server @@ -66,7 +66,6 @@ var _ = Describe("Management service", func() { Expect(err).NotTo(HaveOccurred()) serverPubKey, err = wgtypes.ParseKey(resp.Key) Expect(err).NotTo(HaveOccurred()) - }) AfterEach(func() { @@ -78,7 +77,6 @@ var _ = Describe("Management service", func() { Context("when calling IsHealthy endpoint", func() { Specify("a non-error result is returned", func() { - healthy, err := client.IsHealthy(context.TODO(), &mgmtProto.Empty{}) Expect(err).NotTo(HaveOccurred()) @@ -87,7 +85,6 @@ var _ = Describe("Management service", func() { }) Context("when calling Sync endpoint", func() { - Context("when there is a new peer registered", func() { Specify("a proper configuration is returned", func() { key, _ := wgtypes.GenerateKey() @@ -168,7 +165,6 @@ var _ = Describe("Management service", func() { Expect(resp.GetRemotePeers()).To(HaveLen(2)) peers := []string{resp.GetRemotePeers()[0].WgPubKey, resp.GetRemotePeers()[1].WgPubKey} Expect(peers).To(ContainElements(key1.PublicKey().String(), key2.PublicKey().String())) - }) }) @@ -211,7 +207,6 @@ var _ = Describe("Management service", func() { resp = &mgmtProto.SyncResponse{} err = pb.Unmarshal(decryptedBytes, resp) wg.Done() - }() // register a new peer @@ -229,7 +224,6 @@ var _ = Describe("Management service", func() { Context("when calling GetServerKey endpoint", func() { Specify("a public Wireguard key of the service is returned", func() { - resp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{}) Expect(err).NotTo(HaveOccurred()) @@ -237,19 +231,16 @@ var _ = Describe("Management service", func() { Expect(resp.Key).ToNot(BeNil()) Expect(resp.ExpiresAt).ToNot(BeNil()) - //check if the key is a valid Wireguard key + // check if the key is a valid Wireguard key key, err := wgtypes.ParseKey(resp.Key) Expect(err).NotTo(HaveOccurred()) Expect(key).ToNot(BeNil()) - }) }) Context("when calling Login endpoint", func() { - Context("with an invalid setup key", func() { Specify("an error is returned", func() { - key, _ := wgtypes.GenerateKey() message, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.LoginRequest{SetupKey: "invalid setup key"}) Expect(err).NotTo(HaveOccurred()) @@ -261,24 +252,20 @@ var _ = Describe("Management service", func() { Expect(err).To(HaveOccurred()) Expect(resp).To(BeNil()) - }) }) Context("with a valid setup key", func() { It("a non error result is returned", func() { - key, _ := wgtypes.GenerateKey() resp := loginPeerWithValidSetupKey(serverPubKey, key, client) Expect(resp).ToNot(BeNil()) - }) }) Context("with a registered peer", func() { It("a non error result is returned", func() { - key, _ := wgtypes.GenerateKey() regResp := loginPeerWithValidSetupKey(serverPubKey, key, client) Expect(regResp).NotTo(BeNil()) @@ -324,7 +311,6 @@ var _ = Describe("Management service", func() { Context("when there are 50 peers registered under one account", func() { Context("when there are 10 more peers registered under the same account", func() { Specify("all of the 50 peers will get updates of 10 newly registered peers", func() { - initialPeers := 20 additionalPeers := 10 @@ -369,7 +355,7 @@ var _ = Describe("Management service", func() { err = pb.Unmarshal(decryptedBytes, resp) Expect(err).NotTo(HaveOccurred()) if len(resp.GetRemotePeers()) > 0 { - //only consider peer updates + // only consider peer updates wg.Done() } } @@ -397,7 +383,6 @@ var _ = Describe("Management service", func() { Context("when there are peers registered under one account concurrently", func() { Specify("then there are no duplicate IPs", func() { - initialPeers := 30 ipChannel := make(chan string, 20) @@ -423,7 +408,6 @@ var _ = Describe("Management service", func() { Expect(err).NotTo(HaveOccurred()) ipChannel <- resp.GetPeerConfig().Address - }() } @@ -443,6 +427,7 @@ var _ = Describe("Management service", func() { }) func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { + defer GinkgoRecover() meta := &mgmtProto.PeerSystemMeta{ Hostname: key.PublicKey().String(), @@ -467,7 +452,6 @@ func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, clien err = encryption.DecryptMessage(serverPubKey, key, resp.Body, loginResp) Expect(err).NotTo(HaveOccurred()) return loginResp - } func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn) { @@ -496,7 +480,10 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } peersUpdateManager := server.NewPeersUpdateManager() - accountManager := server.NewManager(store, peersUpdateManager, nil) + accountManager, err := server.BuildManager(store, peersUpdateManager, nil) + if err != nil { + log.Fatalf("failed creating a manager: %v", err) + } turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager) Expect(err).NotTo(HaveOccurred()) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 036d0643f..465b0ceb2 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -33,6 +33,10 @@ type MockAccountManager struct { GroupAddPeerFunc func(accountID, groupID, peerKey string) error GroupDeletePeerFunc func(accountID, groupID, peerKey string) error GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error) + GetRuleFunc func(accountID, ruleID string) (*server.Rule, error) + SaveRuleFunc func(accountID string, rule *server.Rule) error + DeleteRuleFunc func(accountID, ruleID string) error + ListRulesFunc func(accountID string) ([]*server.Rule, error) GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error) } @@ -41,7 +45,6 @@ func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.U return am.GetUsersFromAccountFunc(accountID) } return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount not implemented") - } func (am *MockAccountManager) GetOrCreateAccountByUser( @@ -207,7 +210,7 @@ func (am *MockAccountManager) SaveGroup(accountID string, group *server.Group) e if am.SaveGroupFunc != nil { return am.SaveGroupFunc(accountID, group) } - return status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented") + return status.Errorf(codes.Unimplemented, "method SaveGroup not implemented") } func (am *MockAccountManager) DeleteGroup(accountID, groupID string) error { @@ -244,3 +247,31 @@ func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*serv } return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers not implemented") } + +func (am *MockAccountManager) GetRule(accountID, ruleID string) (*server.Rule, error) { + if am.GetRuleFunc != nil { + return am.GetRuleFunc(accountID, ruleID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetRule not implemented") +} + +func (am *MockAccountManager) SaveRule(accountID string, rule *server.Rule) error { + if am.SaveRuleFunc != nil { + return am.SaveRuleFunc(accountID, rule) + } + return status.Errorf(codes.Unimplemented, "method SaveRule not implemented") +} + +func (am *MockAccountManager) DeleteRule(accountID, ruleID string) error { + if am.DeleteRuleFunc != nil { + return am.DeleteRuleFunc(accountID, ruleID) + } + return status.Errorf(codes.Unimplemented, "method DeleteRule not implemented") +} + +func (am *MockAccountManager) ListRules(accountID string) ([]*server.Rule, error) { + if am.ListRulesFunc != nil { + return am.ListRulesFunc(accountID) + } + return nil, status.Errorf(codes.Unimplemented, "method ListRules not implemented") +} diff --git a/management/server/peer.go b/management/server/peer.go index 1afae6f35..9fda87f91 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1,12 +1,13 @@ package server import ( - "github.com/netbirdio/netbird/management/proto" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "net" "strings" "time" + + "github.com/netbirdio/netbird/management/proto" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // PeerSystemMeta is a metadata of a Peer machine system @@ -21,31 +22,31 @@ type PeerSystemMeta struct { } type PeerStatus struct { - //LastSeen is the last time peer was connected to the management service + // LastSeen is the last time peer was connected to the management service LastSeen time.Time - //Connected indicates whether peer is connected to the management service or not + // Connected indicates whether peer is connected to the management service or not Connected bool } -//Peer represents a machine connected to the network. -//The Peer is a Wireguard peer identified by a public key +// Peer represents a machine connected to the network. +// The Peer is a Wireguard peer identified by a public key type Peer struct { - //Wireguard public key + // Wireguard public key Key string - //A setup key this peer was registered with + // A setup key this peer was registered with SetupKey string - //IP address of the Peer + // IP address of the Peer IP net.IP - //Meta is a Peer system meta data + // Meta is a Peer system meta data Meta PeerSystemMeta - //Name is peer's name (machine name) + // Name is peer's name (machine name) Name string Status *PeerStatus - //The user ID that registered the peer + // The user ID that registered the peer UserID string } -//Copy copies Peer object +// Copy copies Peer object func (p *Peer) Copy() *Peer { return &Peer{ Key: p.Key, @@ -58,7 +59,7 @@ func (p *Peer) Copy() *Peer { } } -//GetPeer returns a peer from a Store +// GetPeer returns a peer from a Store func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) { am.mux.Lock() defer am.mux.Unlock() @@ -71,7 +72,7 @@ func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) { return peer, nil } -//MarkPeerConnected marks peer as connected (true) or disconnected (false) +// MarkPeerConnected marks peer as connected (true) or disconnected (false) func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected bool) error { am.mux.Lock() defer am.mux.Unlock() @@ -96,8 +97,12 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected boo return nil } -//RenamePeer changes peer's name -func (am *DefaultAccountManager) RenamePeer(accountId string, peerKey string, newName string) (*Peer, error) { +// RenamePeer changes peer's name +func (am *DefaultAccountManager) RenamePeer( + accountId string, + peerKey string, + newName string, +) (*Peer, error) { am.mux.Lock() defer am.mux.Unlock() @@ -116,7 +121,7 @@ func (am *DefaultAccountManager) RenamePeer(accountId string, peerKey string, ne return peerCopy, nil } -//DeletePeer removes peer from the account by it's IP +// DeletePeer removes peer from the account by it's IP func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*Peer, error) { am.mux.Lock() defer am.mux.Unlock() @@ -149,12 +154,13 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (* RemotePeers: []*proto.RemotePeerConfig{}, RemotePeersIsEmpty: true, }, - }}) + }, + }) if err != nil { return nil, err } - //notify other peers of the change + // notify other peers of the change peers, err := am.Store.GetAccountPeers(accountId) if err != nil { return nil, err @@ -180,7 +186,8 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (* RemotePeers: update, RemotePeersIsEmpty: len(update) == 0, }, - }}) + }, + }) if err != nil { return nil, err } @@ -190,7 +197,7 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (* return peer, nil } -//GetPeerByIP returns peer by it's IP +// GetPeerByIP returns peer by it's IP func (am *DefaultAccountManager) GetPeerByIP(accountId string, peerIP string) (*Peer, error) { am.mux.Lock() defer am.mux.Unlock() @@ -220,10 +227,46 @@ func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, err } var res []*Peer - for _, peer := range account.Peers { - // exclude original peer - if peer.Key != peerKey { - res = append(res, peer.Copy()) + srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey) + if err != nil { + return &NetworkMap{ + Peers: res, + Network: account.Network.Copy(), + }, nil + } + + dstRules, err := am.Store.GetPeerDstRules(account.Id, peerKey) + if err != nil { + return &NetworkMap{ + Peers: res, + Network: account.Network.Copy(), + }, nil + } + + groups := map[string]*Group{} + for _, r := range srcRules { + if r.Flow == TrafficFlowBidirect { + for _, gid := range r.Destination { + groups[gid] = account.Groups[gid] + } + } + } + + for _, r := range dstRules { + if r.Flow == TrafficFlowBidirect { + for _, gid := range r.Source { + groups[gid] = account.Groups[gid] + } + } + } + + for _, g := range groups { + for _, pid := range g.Peers { + peer := account.Peers[pid] + // exclude original peer + if peer.Key != peerKey { + res = append(res, peer.Copy()) + } } } @@ -240,7 +283,11 @@ func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, err // to it. We also add the User ID to the peer metadata to identify registrant. // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further -func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *Peer) (*Peer, error) { +func (am *DefaultAccountManager) AddPeer( + setupKey string, + userID string, + peer *Peer, +) (*Peer, error) { am.mux.Lock() defer am.mux.Unlock() @@ -252,17 +299,28 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P if len(upperKey) != 0 { account, err = am.Store.GetAccountBySetupKey(upperKey) if err != nil { - return nil, status.Errorf(codes.NotFound, "unable to register peer, unable to find account with setupKey %s", upperKey) + return nil, status.Errorf( + codes.NotFound, + "unable to register peer, unable to find account with setupKey %s", + upperKey, + ) } sk = getAccountSetupKeyByKey(account, upperKey) if sk == nil { // shouldn't happen actually - return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown setupKey %s", upperKey) + return nil, status.Errorf( + codes.NotFound, + "unable to register peer, unknown setupKey %s", + upperKey, + ) } if !sk.IsValid() { - return nil, status.Errorf(codes.FailedPrecondition, "unable to register peer, its setup key is invalid (expired, overused or revoked)") + return nil, status.Errorf( + codes.FailedPrecondition, + "unable to register peer, its setup key is invalid (expired, overused or revoked)", + ) } } else if len(userID) != 0 { @@ -293,6 +351,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P Status: &PeerStatus{Connected: false, LastSeen: time.Now()}, } + // add peer to 'All' group + group, err := account.GetGroupAll() + if err != nil { + return nil, err + } + group.Peers = append(group.Peers, newPeer.Key) + account.Peers[newPeer.Key] = newPeer if len(upperKey) != 0 { account.SetupKeys[sk.Key] = sk.IncrementUsage() @@ -305,5 +370,4 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P } return newPeer, nil - } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 838ff4f14..ba87ca773 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1,8 +1,10 @@ package server import ( - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "testing" + + "github.com/rs/xid" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) func TestAccountManager_GetNetworkMap(t *testing.T) { @@ -70,7 +72,151 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { } if networkMap.Peers[0].Key != peerKey2.PublicKey().String() { - t.Errorf("expecting Account NetworkMap to have peer with a key %s, got %s", peerKey2.PublicKey().String(), networkMap.Peers[0].Key) + t.Errorf( + "expecting Account NetworkMap to have peer with a key %s, got %s", + peerKey2.PublicKey().String(), + networkMap.Peers[0].Key, + ) + } +} + +func TestAccountManager_GetNetworkMapWithRule(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return } + expectedId := "test_account" + userId := "account_creator" + account, err := manager.AddAccount(expectedId, userId, "") + if err != nil { + t.Fatal(err) + } + + var setupKey *SetupKey + for _, key := range account.SetupKeys { + if key.Type == SetupKeyReusable { + setupKey = key + } + } + + peerKey1, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + + _, err = manager.AddPeer(setupKey.Key, "", &Peer{ + Key: peerKey1.PublicKey().String(), + Meta: PeerSystemMeta{}, + Name: "test-peer-2", + }) + + if err != nil { + t.Errorf("expecting peer to be added, got failure %v", err) + return + } + + peerKey2, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + _, err = manager.AddPeer(setupKey.Key, "", &Peer{ + Key: peerKey2.PublicKey().String(), + Meta: PeerSystemMeta{}, + Name: "test-peer-2", + }) + + if err != nil { + t.Errorf("expecting peer to be added, got failure %v", err) + return + } + + rules, err := manager.ListRules(account.Id) + if err != nil { + t.Errorf("expecting to get a list of rules, got failure %v", err) + return + } + + err = manager.DeleteRule(account.Id, rules[0].ID) + if err != nil { + t.Errorf("expecting to delete 1 group, got failure %v", err) + return + } + var ( + group1 Group + group2 Group + rule Rule + ) + + group1.ID = xid.New().String() + group2.ID = xid.New().String() + group1.Name = "src" + group2.Name = "dst" + rule.ID = xid.New().String() + group1.Peers = append(group1.Peers, peerKey1.PublicKey().String()) + group2.Peers = append(group2.Peers, peerKey2.PublicKey().String()) + + err = manager.SaveGroup(account.Id, &group1) + if err != nil { + t.Errorf("expecting group1 to be added, got failure %v", err) + return + } + err = manager.SaveGroup(account.Id, &group2) + if err != nil { + t.Errorf("expecting group2 to be added, got failure %v", err) + return + } + + rule.Name = "test" + rule.Source = append(rule.Source, group1.ID) + rule.Destination = append(rule.Destination, group2.ID) + rule.Flow = TrafficFlowBidirect + err = manager.SaveRule(account.Id, &rule) + if err != nil { + t.Errorf("expecting rule to be added, got failure %v", err) + return + } + + networkMap1, err := manager.GetNetworkMap(peerKey1.PublicKey().String()) + if err != nil { + t.Fatal(err) + return + } + + if len(networkMap1.Peers) != 1 { + t.Errorf( + "expecting Account NetworkMap to have 1 peers, got %v: %v", + len(networkMap1.Peers), + networkMap1.Peers, + ) + } + + if networkMap1.Peers[0].Key != peerKey2.PublicKey().String() { + t.Errorf( + "expecting Account NetworkMap to have peer with a key %s, got %s", + peerKey2.PublicKey().String(), + networkMap1.Peers[0].Key, + ) + } + + networkMap2, err := manager.GetNetworkMap(peerKey2.PublicKey().String()) + if err != nil { + t.Fatal(err) + return + } + + if len(networkMap2.Peers) != 1 { + t.Errorf("expecting Account NetworkMap to have 1 peers, got %v", len(networkMap2.Peers)) + } + + if len(networkMap2.Peers) > 0 && networkMap2.Peers[0].Key != peerKey1.PublicKey().String() { + t.Errorf( + "expecting Account NetworkMap to have peer with a key %s, got %s", + peerKey1.PublicKey().String(), + networkMap2.Peers[0].Key, + ) + } } diff --git a/management/server/rule.go b/management/server/rule.go new file mode 100644 index 000000000..3cbf172d7 --- /dev/null +++ b/management/server/rule.go @@ -0,0 +1,107 @@ +package server + +import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// TrafficFlowType defines allowed direction of the traffic in the rule +type TrafficFlowType int + +const ( + // TrafficFlowBidirect allows traffic to both direction + TrafficFlowBidirect TrafficFlowType = iota +) + +// Rule of ACL for groups +type Rule struct { + // ID of the rule + ID string + + // Name of the rule visible in the UI + Name string + + // Source list of groups IDs of peers + Source []string + + // Destination list of groups IDs of peers + Destination []string + + // Flow of the traffic allowed by the rule + Flow TrafficFlowType +} + +func (r *Rule) Copy() *Rule { + return &Rule{ + ID: r.ID, + Name: r.Name, + Source: r.Source[:], + Destination: r.Destination[:], + Flow: r.Flow, + } +} + +// GetRule of ACL from the store +func (am *DefaultAccountManager) GetRule(accountID, ruleID string) (*Rule, error) { + am.mux.Lock() + defer am.mux.Unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, status.Errorf(codes.NotFound, "account not found") + } + + rule, ok := account.Rules[ruleID] + if ok { + return rule, nil + } + + return nil, status.Errorf(codes.NotFound, "rule with ID %s not found", ruleID) +} + +// SaveRule of ACL in the store +func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error { + am.mux.Lock() + defer am.mux.Unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return status.Errorf(codes.NotFound, "account not found") + } + + account.Rules[rule.ID] = rule + return am.Store.SaveAccount(account) +} + +// DeleteRule of ACL from the store +func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error { + am.mux.Lock() + defer am.mux.Unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return status.Errorf(codes.NotFound, "account not found") + } + + delete(account.Rules, ruleID) + + return am.Store.SaveAccount(account) +} + +// ListRules of ACL from the store +func (am *DefaultAccountManager) ListRules(accountID string) ([]*Rule, error) { + am.mux.Lock() + defer am.mux.Unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, status.Errorf(codes.NotFound, "account not found") + } + + rules := make([]*Rule, 0, len(account.Rules)) + for _, item := range account.Rules { + rules = append(rules, item) + } + + return rules, nil +} diff --git a/management/server/store.go b/management/server/store.go index 58fbc45b0..3ce548d9b 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -4,10 +4,13 @@ type Store interface { GetPeer(peerKey string) (*Peer, error) DeletePeer(accountId string, peerKey string) (*Peer, error) SavePeer(accountId string, peer *Peer) error + GetAllAccounts() []*Account GetAccount(accountId string) (*Account, error) GetUserAccount(userId string) (*Account, error) GetAccountPeers(accountId string) ([]*Peer, error) GetPeerAccount(peerKey string) (*Account, error) + GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error) + GetPeerDstRules(accountId, peerKey string) ([]*Rule, error) GetAccountBySetupKey(setupKey string) (*Account, error) GetAccountByPrivateDomain(domain string) (*Account, error) SaveAccount(account *Account) error diff --git a/management/server/user.go b/management/server/user.go index 55794b333..9b9db6b8e 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -58,6 +58,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { account = NewAccount(userId, lowerDomain) account.Users[userId] = NewAdminUser(userId) + am.addAllGroup(account) err = am.Store.SaveAccount(account) if err != nil { return nil, status.Errorf(codes.Internal, "failed creating account")