mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client, management] Feature/ssh fine grained access (#4969)
Add fine-grained SSH access control with authorized users/groups
This commit is contained in:
@@ -1151,6 +1151,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
|
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
|
||||||
log.Warnf("failed to update SSH client config: %v", err)
|
log.Warnf("failed to update SSH client config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||||
}
|
}
|
||||||
|
|
||||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||||
|
|||||||
@@ -11,15 +11,18 @@ import (
|
|||||||
|
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sshServer interface {
|
type sshServer interface {
|
||||||
Start(ctx context.Context, addr netip.AddrPort) error
|
Start(ctx context.Context, addr netip.AddrPort) error
|
||||||
Stop() error
|
Stop() error
|
||||||
GetStatus() (bool, []sshserver.SessionInfo)
|
GetStatus() (bool, []sshserver.SessionInfo)
|
||||||
|
UpdateSSHAuth(config *sshauth.Config)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) setupSSHPortRedirection() error {
|
func (e *Engine) setupSSHPortRedirection() error {
|
||||||
@@ -353,3 +356,38 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio
|
|||||||
|
|
||||||
return sshServer.GetStatus()
|
return sshServer.GetStatus()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server
|
||||||
|
func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) {
|
||||||
|
if sshAuth == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.sshServer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
protoUsers := sshAuth.GetAuthorizedUsers()
|
||||||
|
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||||
|
for i, hash := range protoUsers {
|
||||||
|
if len(hash) != 16 {
|
||||||
|
log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
machineUsers := make(map[string][]uint32)
|
||||||
|
for osUser, indexes := range sshAuth.GetMachineUsers() {
|
||||||
|
machineUsers[osUser] = indexes.GetIndexes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update SSH server with new authorization configuration
|
||||||
|
authConfig := &sshauth.Config{
|
||||||
|
UserIDClaim: sshAuth.GetUserIDClaim(),
|
||||||
|
AuthorizedUsers: authorizedUsers,
|
||||||
|
MachineUsers: machineUsers,
|
||||||
|
}
|
||||||
|
|
||||||
|
e.sshServer.UpdateSSHAuth(authConfig)
|
||||||
|
}
|
||||||
|
|||||||
184
client/ssh/auth/auth.go
Normal file
184
client/ssh/auth/auth.go
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultUserIDClaim is the default JWT claim used to extract user IDs
|
||||||
|
DefaultUserIDClaim = "sub"
|
||||||
|
// Wildcard is a special user ID that matches all users
|
||||||
|
Wildcard = "*"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrEmptyUserID = errors.New("JWT user ID is empty")
|
||||||
|
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||||
|
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||||
|
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Authorizer handles SSH fine-grained access control authorization
|
||||||
|
type Authorizer struct {
|
||||||
|
// UserIDClaim is the JWT claim to extract the user ID from
|
||||||
|
userIDClaim string
|
||||||
|
|
||||||
|
// authorizedUsers is a list of hashed user IDs authorized to access this peer
|
||||||
|
authorizedUsers []sshuserhash.UserIDHash
|
||||||
|
|
||||||
|
// machineUsers maps OS login usernames to lists of authorized user indexes
|
||||||
|
machineUsers map[string][]uint32
|
||||||
|
|
||||||
|
// mu protects the list of users
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config contains configuration for the SSH authorizer
|
||||||
|
type Config struct {
|
||||||
|
// UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email")
|
||||||
|
UserIDClaim string
|
||||||
|
|
||||||
|
// AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer
|
||||||
|
AuthorizedUsers []sshuserhash.UserIDHash
|
||||||
|
|
||||||
|
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
|
||||||
|
// If a user wants to login as a specific OS user, their index must be in the corresponding list
|
||||||
|
MachineUsers map[string][]uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthorizer creates a new SSH authorizer with empty configuration
|
||||||
|
func NewAuthorizer() *Authorizer {
|
||||||
|
a := &Authorizer{
|
||||||
|
userIDClaim: DefaultUserIDClaim,
|
||||||
|
machineUsers: make(map[string][]uint32),
|
||||||
|
}
|
||||||
|
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update updates the authorizer configuration with new values
|
||||||
|
func (a *Authorizer) Update(config *Config) {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
|
||||||
|
if config == nil {
|
||||||
|
// Clear authorization
|
||||||
|
a.userIDClaim = DefaultUserIDClaim
|
||||||
|
a.authorizedUsers = []sshuserhash.UserIDHash{}
|
||||||
|
a.machineUsers = make(map[string][]uint32)
|
||||||
|
log.Info("SSH authorization cleared")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userIDClaim := config.UserIDClaim
|
||||||
|
if userIDClaim == "" {
|
||||||
|
userIDClaim = DefaultUserIDClaim
|
||||||
|
}
|
||||||
|
a.userIDClaim = userIDClaim
|
||||||
|
|
||||||
|
// Store authorized users list
|
||||||
|
a.authorizedUsers = config.AuthorizedUsers
|
||||||
|
|
||||||
|
// Store machine users mapping
|
||||||
|
machineUsers := make(map[string][]uint32)
|
||||||
|
for osUser, indexes := range config.MachineUsers {
|
||||||
|
if len(indexes) > 0 {
|
||||||
|
machineUsers[osUser] = indexes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
a.machineUsers = machineUsers
|
||||||
|
|
||||||
|
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
|
||||||
|
len(config.AuthorizedUsers), len(machineUsers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorize validates if a user is authorized to login as the specified OS user
|
||||||
|
// Returns nil if authorized, or an error describing why authorization failed
|
||||||
|
func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
|
||||||
|
if jwtUserID == "" {
|
||||||
|
log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername)
|
||||||
|
return ErrEmptyUserID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash the JWT user ID for comparison
|
||||||
|
hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err)
|
||||||
|
return fmt.Errorf("failed to hash user ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
a.mu.RLock()
|
||||||
|
defer a.mu.RUnlock()
|
||||||
|
|
||||||
|
// Find the index of this user in the authorized list
|
||||||
|
userIndex, found := a.findUserIndex(hashedUserID)
|
||||||
|
if !found {
|
||||||
|
log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername)
|
||||||
|
return ErrUserNotAuthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
|
||||||
|
// Checks wildcard mapping first, then specific OS user mappings
|
||||||
|
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error {
|
||||||
|
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user
|
||||||
|
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
|
||||||
|
if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
|
||||||
|
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for specific OS username mapping
|
||||||
|
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
|
||||||
|
if !hasMachineUserMapping {
|
||||||
|
// No mapping for this OS user - deny by default (fail closed)
|
||||||
|
log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID)
|
||||||
|
return ErrNoMachineUserMapping
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if user's index is in the allowed indexes for this specific OS user
|
||||||
|
if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
|
||||||
|
log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex)
|
||||||
|
return ErrUserNotMappedToOSUser
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserIDClaim returns the JWT claim name used to extract user IDs
|
||||||
|
func (a *Authorizer) GetUserIDClaim() string {
|
||||||
|
a.mu.RLock()
|
||||||
|
defer a.mu.RUnlock()
|
||||||
|
return a.userIDClaim
|
||||||
|
}
|
||||||
|
|
||||||
|
// findUserIndex finds the index of a hashed user ID in the authorized users list
|
||||||
|
// Returns the index and true if found, 0 and false if not found
|
||||||
|
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
|
||||||
|
for i, id := range a.authorizedUsers {
|
||||||
|
if id == hashedUserID {
|
||||||
|
return i, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIndexInList checks if an index exists in a list of indexes
|
||||||
|
func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool {
|
||||||
|
for _, idx := range indexes {
|
||||||
|
if idx == index {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
612
client/ssh/auth/auth_test.go
Normal file
612
client/ssh/auth/auth_test.go
Normal file
@@ -0,0 +1,612 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/sshauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up authorized users list with one user
|
||||||
|
authorizedUserHash, err := sshauth.HashUserID("authorized-user")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash},
|
||||||
|
MachineUsers: map[string][]uint32{},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// Try to authorize a different user
|
||||||
|
err = authorizer.Authorize("unauthorized-user", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||||
|
MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed)
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// All attempts should fail when no machine user mappings exist (fail closed)
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "admin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user1", "postgres")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user3Hash, err := sshauth.HashUserID("user3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0, 1}, // user1 and user2 can access root
|
||||||
|
"postgres": {1, 2}, // user2 and user3 can access postgres
|
||||||
|
"admin": {0}, // only user1 can access admin
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// user1 (index 0) should access root and admin
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user1", "admin")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// user2 (index 1) should access root and postgres
|
||||||
|
err = authorizer.Authorize("user2", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "postgres")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// user3 (index 2) should access postgres
|
||||||
|
err = authorizer.Authorize("user3", "postgres")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up authorized users list
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user3Hash, err := sshauth.HashUserID("user3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0, 1}, // user1 and user2 can access root
|
||||||
|
"postgres": {1, 2}, // user2 and user3 can access postgres
|
||||||
|
"admin": {0}, // only user1 can access admin
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// user1 (index 0) should NOT access postgres
|
||||||
|
err = authorizer.Authorize("user1", "postgres")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
|
// user2 (index 1) should NOT access admin
|
||||||
|
err = authorizer.Authorize("user2", "admin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
|
// user3 (index 2) should NOT access root
|
||||||
|
err = authorizer.Authorize("user3", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
|
// user3 (index 2) should NOT access admin
|
||||||
|
err = authorizer.Authorize("user3", "admin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up authorized users list
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0}, // only root is mapped
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// user1 should NOT access an unmapped OS user (fail closed)
|
||||||
|
err = authorizer.Authorize("user1", "postgres")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up authorized users list
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// Empty user ID should fail
|
||||||
|
err = authorizer.Authorize("", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrEmptyUserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up multiple authorized users
|
||||||
|
userHashes := make([]sshauth.UserIDHash, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
hash, err := sshauth.HashUserID("user" + string(rune('0'+i)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
userHashes[i] = hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create machine user mapping for all users
|
||||||
|
rootIndexes := make([]uint32, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
rootIndexes[i] = uint32(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: userHashes,
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": rootIndexes,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// All users should be authorized for root
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
|
||||||
|
assert.NoError(t, err, "user%d should be authorized", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// User not in list should fail
|
||||||
|
err := authorizer.Authorize("unknown-user", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up initial configuration
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{"root": {0}},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// user1 should be authorized
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Clear configuration
|
||||||
|
authorizer.Update(nil)
|
||||||
|
|
||||||
|
// user1 should no longer be authorized
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Machine users with empty index lists should be filtered out
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0},
|
||||||
|
"postgres": {}, // empty list - should be filtered out
|
||||||
|
"admin": nil, // nil list - should be filtered out
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// root should work
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// postgres should fail (no mapping)
|
||||||
|
err = authorizer.Authorize("user1", "postgres")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
|
// admin should fail (no mapping)
|
||||||
|
err = authorizer.Authorize("user1", "admin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up with custom user ID claim
|
||||||
|
user1Hash, err := sshauth.HashUserID("user@example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: "email",
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// Verify the custom claim is set
|
||||||
|
assert.Equal(t, "email", authorizer.GetUserIDClaim())
|
||||||
|
|
||||||
|
// Authorize with email as user ID
|
||||||
|
err = authorizer.Authorize("user@example.com", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_DefaultUserIDClaim(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Verify default claim
|
||||||
|
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
|
||||||
|
assert.Equal(t, "sub", authorizer.GetUserIDClaim())
|
||||||
|
|
||||||
|
// Set up with empty user ID claim (should use default)
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: "", // empty - should use default
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// Should fall back to default
|
||||||
|
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Create a large authorized users list
|
||||||
|
const numUsers = 1000
|
||||||
|
userHashes := make([]sshauth.UserIDHash, numUsers)
|
||||||
|
for i := 0; i < numUsers; i++ {
|
||||||
|
hash, err := sshauth.HashUserID("user" + string(rune(i)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
userHashes[i] = hash
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: userHashes,
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0, 500, 999}, // first, middle, and last user
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// First user should have access
|
||||||
|
err := authorizer.Authorize("user"+string(rune(0)), "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Middle user should have access
|
||||||
|
err = authorizer.Authorize("user"+string(rune(500)), "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Last user should have access
|
||||||
|
err = authorizer.Authorize("user"+string(rune(999)), "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// User not in mapping should NOT have access
|
||||||
|
err = authorizer.Authorize("user"+string(rune(100)), "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Set up authorized users
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0, 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// Test concurrent authorization calls (should be safe to read concurrently)
|
||||||
|
const numGoroutines = 100
|
||||||
|
errChan := make(chan error, numGoroutines)
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func(idx int) {
|
||||||
|
user := "user1"
|
||||||
|
if idx%2 == 0 {
|
||||||
|
user = "user2"
|
||||||
|
}
|
||||||
|
err := authorizer.Authorize(user, "root")
|
||||||
|
errChan <- err
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all goroutines to complete and collect errors
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
err := <-errChan
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user3Hash, err := sshauth.HashUserID("user3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Configure with wildcard - all authorized users can access any OS user
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"*": {0, 1, 2}, // wildcard with all user indexes
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// All authorized users should be able to access any OS user
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "postgres")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user3", "admin")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user1", "ubuntu")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "nginx")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user3", "docker")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Configure with wildcard
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"*": {0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// user1 should have access
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Unauthorized user should still be denied even with wildcard
|
||||||
|
err = authorizer.Authorize("unauthorized-user", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotAuthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Configure with both wildcard and specific mappings
|
||||||
|
// Wildcard takes precedence for users in the wildcard index list
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"*": {0, 1}, // wildcard for both users
|
||||||
|
"root": {0}, // specific mapping that would normally restrict to user1 only
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// Both users should be able to access root via wildcard (takes precedence over specific mapping)
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Both users should be able to access any other OS user via wildcard
|
||||||
|
err = authorizer.Authorize("user1", "postgres")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "admin")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
user1Hash, err := sshauth.HashUserID("user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Configure WITHOUT wildcard - only specific mappings
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"root": {0}, // only user1
|
||||||
|
"postgres": {1}, // only user2
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// user1 can access root
|
||||||
|
err = authorizer.Authorize("user1", "root")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// user2 can access postgres
|
||||||
|
err = authorizer.Authorize("user2", "postgres")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// user1 cannot access postgres
|
||||||
|
err = authorizer.Authorize("user1", "postgres")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
|
// user2 cannot access root
|
||||||
|
err = authorizer.Authorize("user2", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
|
||||||
|
|
||||||
|
// Neither can access unmapped OS users
|
||||||
|
err = authorizer.Authorize("user1", "admin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "admin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
|
||||||
|
// This test covers the scenario where wildcard exists with limited indexes.
|
||||||
|
// Only users whose indexes are in the wildcard list can access any OS user via wildcard.
|
||||||
|
// Other users can only access OS users they are explicitly mapped to.
|
||||||
|
authorizer := NewAuthorizer()
|
||||||
|
|
||||||
|
// Create two authorized user hashes (simulating the base64-encoded hashes in the config)
|
||||||
|
wasmHash, err := sshauth.HashUserID("wasm")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2Hash, err := sshauth.HashUserID("user2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Configure with wildcard having only index 0, and specific mappings for other OS users
|
||||||
|
config := &Config{
|
||||||
|
UserIDClaim: "sub",
|
||||||
|
AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
"*": {0}, // wildcard with only index 0 - only wasm has wildcard access
|
||||||
|
"alice": {1}, // specific mapping for user2
|
||||||
|
"bob": {1}, // specific mapping for user2
|
||||||
|
},
|
||||||
|
}
|
||||||
|
authorizer.Update(config)
|
||||||
|
|
||||||
|
// wasm (index 0) should access any OS user via wildcard
|
||||||
|
err = authorizer.Authorize("wasm", "root")
|
||||||
|
assert.NoError(t, err, "wasm should access root via wildcard")
|
||||||
|
|
||||||
|
err = authorizer.Authorize("wasm", "alice")
|
||||||
|
assert.NoError(t, err, "wasm should access alice via wildcard")
|
||||||
|
|
||||||
|
err = authorizer.Authorize("wasm", "bob")
|
||||||
|
assert.NoError(t, err, "wasm should access bob via wildcard")
|
||||||
|
|
||||||
|
err = authorizer.Authorize("wasm", "postgres")
|
||||||
|
assert.NoError(t, err, "wasm should access postgres via wildcard")
|
||||||
|
|
||||||
|
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
|
||||||
|
err = authorizer.Authorize("user2", "alice")
|
||||||
|
assert.NoError(t, err, "user2 should access alice via explicit mapping")
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "bob")
|
||||||
|
assert.NoError(t, err, "user2 should access bob via explicit mapping")
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "root")
|
||||||
|
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
|
err = authorizer.Authorize("user2", "postgres")
|
||||||
|
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
|
||||||
|
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
|
||||||
|
|
||||||
|
// Unauthorized user should still be denied
|
||||||
|
err = authorizer.Authorize("user3", "root")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
|
||||||
|
}
|
||||||
@@ -27,9 +27,11 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
"github.com/netbirdio/netbird/client/ssh/server"
|
"github.com/netbirdio/netbird/client/ssh/server"
|
||||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
@@ -137,6 +139,21 @@ func TestSSHProxy_Connect(t *testing.T) {
|
|||||||
sshServer := server.New(serverConfig)
|
sshServer := server.New(serverConfig)
|
||||||
sshServer.SetAllowRootLogin(true)
|
sshServer.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
// Configure SSH authorization for the test user
|
||||||
|
testUsername := testutil.GetTestUsername(t)
|
||||||
|
testJWTUser := "test-username"
|
||||||
|
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authConfig := &sshauth.Config{
|
||||||
|
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
testUsername: {0}, // Index 0 in AuthorizedUsers
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sshServer.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||||
defer func() { _ = sshServer.Stop() }()
|
defer func() { _ = sshServer.Stop() }()
|
||||||
|
|
||||||
@@ -150,10 +167,10 @@ func TestSSHProxy_Connect(t *testing.T) {
|
|||||||
|
|
||||||
mockDaemon.setHostKey(host, hostPubKey)
|
mockDaemon.setHostKey(host, hostPubKey)
|
||||||
|
|
||||||
validToken := generateValidJWT(t, privateKey, issuer, audience)
|
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||||
mockDaemon.setJWTToken(validToken)
|
mockDaemon.setJWTToken(validToken)
|
||||||
|
|
||||||
proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil)
|
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
clientConn, proxyConn := net.Pipe()
|
clientConn, proxyConn := net.Pipe()
|
||||||
@@ -347,12 +364,12 @@ func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
|||||||
return privateKey, jwksJSON
|
return privateKey, jwksJSON
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
claims := jwt.MapClaims{
|
claims := jwt.MapClaims{
|
||||||
"iss": issuer,
|
"iss": issuer,
|
||||||
"aud": audience,
|
"aud": audience,
|
||||||
"sub": "test-user",
|
"sub": user,
|
||||||
"exp": time.Now().Add(time.Hour).Unix(),
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
"iat": time.Now().Unix(),
|
"iat": time.Now().Unix(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,10 +23,12 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
"github.com/netbirdio/netbird/client/ssh/client"
|
"github.com/netbirdio/netbird/client/ssh/client"
|
||||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJWTEnforcement(t *testing.T) {
|
func TestJWTEnforcement(t *testing.T) {
|
||||||
@@ -577,6 +579,22 @@ func TestJWTAuthentication(t *testing.T) {
|
|||||||
tc.setupServer(server)
|
tc.setupServer(server)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always set up authorization for test-user to ensure tests fail at JWT validation stage
|
||||||
|
testUserHash, err := sshuserhash.HashUserID("test-user")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Get current OS username for machine user mapping
|
||||||
|
currentUser := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
|
authConfig := &sshauth.Config{
|
||||||
|
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||||
|
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||||
|
MachineUsers: map[string][]uint32{
|
||||||
|
currentUser: {0}, // Allow test-user (index 0) to access current OS user
|
||||||
|
},
|
||||||
|
}
|
||||||
|
server.UpdateSSHAuth(authConfig)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer require.NoError(t, server.Stop())
|
defer require.NoError(t, server.Stop())
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
"github.com/netbirdio/netbird/shared/auth"
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
@@ -138,6 +139,8 @@ type Server struct {
|
|||||||
jwtExtractor *jwt.ClaimsExtractor
|
jwtExtractor *jwt.ClaimsExtractor
|
||||||
jwtConfig *JWTConfig
|
jwtConfig *JWTConfig
|
||||||
|
|
||||||
|
authorizer *sshauth.Authorizer
|
||||||
|
|
||||||
suSupportsPty bool
|
suSupportsPty bool
|
||||||
loginIsUtilLinux bool
|
loginIsUtilLinux bool
|
||||||
}
|
}
|
||||||
@@ -179,6 +182,7 @@ func New(config *Config) *Server {
|
|||||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
||||||
jwtEnabled: config.JWT != nil,
|
jwtEnabled: config.JWT != nil,
|
||||||
jwtConfig: config.JWT,
|
jwtConfig: config.JWT,
|
||||||
|
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
@@ -320,6 +324,19 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
|||||||
s.wgAddress = addr
|
s.wgAddress = addr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateSSHAuth updates the SSH fine-grained access control configuration
|
||||||
|
// This should be called when network map updates include new SSH auth configuration
|
||||||
|
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// Reset JWT validator/extractor to pick up new userIDClaim
|
||||||
|
s.jwtValidator = nil
|
||||||
|
s.jwtExtractor = nil
|
||||||
|
|
||||||
|
s.authorizer.Update(config)
|
||||||
|
}
|
||||||
|
|
||||||
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
|
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
|
||||||
func (s *Server) ensureJWTValidator() error {
|
func (s *Server) ensureJWTValidator() error {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
@@ -328,6 +345,7 @@ func (s *Server) ensureJWTValidator() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
config := s.jwtConfig
|
config := s.jwtConfig
|
||||||
|
authorizer := s.authorizer
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
if config == nil {
|
if config == nil {
|
||||||
@@ -343,9 +361,16 @@ func (s *Server) ensureJWTValidator() error {
|
|||||||
true,
|
true,
|
||||||
)
|
)
|
||||||
|
|
||||||
extractor := jwt.NewClaimsExtractor(
|
// Use custom userIDClaim from authorizer if available
|
||||||
|
extractorOptions := []jwt.ClaimsExtractorOption{
|
||||||
jwt.WithAudience(config.Audience),
|
jwt.WithAudience(config.Audience),
|
||||||
)
|
}
|
||||||
|
if authorizer.GetUserIDClaim() != "" {
|
||||||
|
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
|
||||||
|
log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim())
|
||||||
|
}
|
||||||
|
|
||||||
|
extractor := jwt.NewClaimsExtractor(extractorOptions...)
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@@ -493,29 +518,41 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
||||||
|
osUsername := ctx.User()
|
||||||
|
remoteAddr := ctx.RemoteAddr()
|
||||||
|
|
||||||
if err := s.ensureJWTValidator(); err != nil {
|
if err := s.ensureJWTValidator(); err != nil {
|
||||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := s.validateJWTToken(password)
|
token, err := s.validateJWTToken(password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
userAuth, err := s.extractAndValidateUser(token)
|
userAuth, err := s.extractAndValidateUser(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
|
s.mu.RLock()
|
||||||
|
authorizer := s.authorizer
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil {
|
||||||
|
log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
key := newAuthKey(osUsername, remoteAddr)
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.pendingAuthJWT[key] = userAuth.UserId
|
s.pendingAuthJWT[key] = userAuth.UserId
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
|
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
if c.experimentalNetworkMap(accountID) {
|
||||||
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||||
@@ -224,7 +225,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
if c.experimentalNetworkMap(accountID) {
|
if c.experimentalNetworkMap(accountID) {
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||||
} else {
|
} else {
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||||
@@ -320,6 +321,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
postureChecks, err := c.getPeerPostureChecks(account, peerId)
|
postureChecks, err := c.getPeerPostureChecks(account, peerId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -338,7 +340,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
if c.experimentalNetworkMap(accountId) {
|
if c.experimentalNetworkMap(accountId) {
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||||
} else {
|
} else {
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
@@ -445,7 +447,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
if c.experimentalNetworkMap(accountID) {
|
if c.experimentalNetworkMap(accountID) {
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
|
||||||
} else {
|
} else {
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics)
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
@@ -811,7 +813,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
if c.experimentalNetworkMap(peer.AccountID) {
|
if c.experimentalNetworkMap(peer.AccountID) {
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
||||||
} else {
|
} else {
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
|
|||||||
@@ -158,5 +158,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,10 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||||
@@ -16,6 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/shared/sshauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||||
@@ -84,15 +88,15 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken
|
|||||||
return nbConfig
|
return nbConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig {
|
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, enableSSH bool) *proto.PeerConfig {
|
||||||
netmask, _ := network.Net.Mask.Size()
|
netmask, _ := network.Net.Mask.Size()
|
||||||
fqdn := peer.FQDN(dnsName)
|
fqdn := peer.FQDN(dnsName)
|
||||||
|
|
||||||
sshConfig := &proto.SSHConfig{
|
sshConfig := &proto.SSHConfig{
|
||||||
SshEnabled: peer.SSHEnabled,
|
SshEnabled: peer.SSHEnabled || enableSSH,
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.SSHEnabled {
|
if sshConfig.SshEnabled {
|
||||||
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
|
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,12 +114,12 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
|||||||
|
|
||||||
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||||
response := &proto.SyncResponse{
|
response := &proto.SyncResponse{
|
||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||||
NetworkMap: &proto.NetworkMap{
|
NetworkMap: &proto.NetworkMap{
|
||||||
Serial: networkMap.Network.CurrentSerial(),
|
Serial: networkMap.Network.CurrentSerial(),
|
||||||
Routes: toProtocolRoutes(networkMap.Routes),
|
Routes: toProtocolRoutes(networkMap.Routes),
|
||||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||||
},
|
},
|
||||||
Checks: toProtocolChecks(ctx, checks),
|
Checks: toProtocolChecks(ctx, checks),
|
||||||
}
|
}
|
||||||
@@ -151,9 +155,45 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
|||||||
response.NetworkMap.ForwardingRules = forwardingRules
|
response.NetworkMap.ForwardingRules = forwardingRules
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if networkMap.AuthorizedUsers != nil {
|
||||||
|
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||||
|
userIDClaim := auth.DefaultUserIDClaim
|
||||||
|
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||||
|
userIDClaim = httpConfig.AuthUserIDClaim
|
||||||
|
}
|
||||||
|
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||||
|
}
|
||||||
|
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||||
|
userIDToIndex := make(map[string]uint32)
|
||||||
|
var hashedUsers [][]byte
|
||||||
|
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
|
||||||
|
|
||||||
|
for machineUser, users := range authorizedUsers {
|
||||||
|
indexes := make([]uint32, 0, len(users))
|
||||||
|
for userID := range users {
|
||||||
|
idx, exists := userIDToIndex[userID]
|
||||||
|
if !exists {
|
||||||
|
hash, err := sshauth.HashUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idx = uint32(len(hashedUsers))
|
||||||
|
userIDToIndex[userID] = idx
|
||||||
|
hashedUsers = append(hashedUsers, hash[:])
|
||||||
|
}
|
||||||
|
indexes = append(indexes, idx)
|
||||||
|
}
|
||||||
|
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
|
||||||
|
}
|
||||||
|
|
||||||
|
return hashedUsers, machineUsers
|
||||||
|
}
|
||||||
|
|
||||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||||
for _, rPeer := range peers {
|
for _, rPeer := range peers {
|
||||||
dst = append(dst, &proto.RemotePeerConfig{
|
dst = append(dst, &proto.RemotePeerConfig{
|
||||||
|
|||||||
@@ -635,7 +635,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
|||||||
// if peer has reached this point then it has logged in
|
// if peer has reached this point then it has logged in
|
||||||
loginResp := &proto.LoginResponse{
|
loginResp := &proto.LoginResponse{
|
||||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow),
|
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
|
||||||
Checks: toProtocolChecks(ctx, postureChecks),
|
Checks: toProtocolChecks(ctx, postureChecks),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1456,21 +1456,19 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings.GroupsPropagationEnabled {
|
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
|
||||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
|
if err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
|
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
||||||
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
|
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -397,7 +397,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -299,7 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
||||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||||
}
|
}
|
||||||
@@ -369,6 +369,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
|
|||||||
PortRanges: []types.RulePortRange{portRange},
|
PortRanges: []types.RulePortRange{portRange},
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
|
if protocol == types.PolicyRuleProtocolNetbirdSSH {
|
||||||
|
policy.Rules[0].AuthorizedUser = userAuth.UserId
|
||||||
|
}
|
||||||
|
|
||||||
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
|
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -449,6 +452,18 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
|||||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||||
Ephemeral: peer.Ephemeral,
|
Ephemeral: peer.Ephemeral,
|
||||||
|
LocalFlags: &api.PeerLocalFlags{
|
||||||
|
BlockInbound: &peer.Meta.Flags.BlockInbound,
|
||||||
|
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
|
||||||
|
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
|
||||||
|
DisableDns: &peer.Meta.Flags.DisableDNS,
|
||||||
|
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
|
||||||
|
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
|
||||||
|
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
|
||||||
|
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
|
||||||
|
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
|
||||||
|
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if !approved {
|
if !approved {
|
||||||
@@ -463,7 +478,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
|||||||
if osVersion == "" {
|
if osVersion == "" {
|
||||||
osVersion = peer.Meta.Core
|
osVersion = peer.Meta.Core
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.PeerBatch{
|
return &api.PeerBatch{
|
||||||
CreatedAt: peer.CreatedAt,
|
CreatedAt: peer.CreatedAt,
|
||||||
Id: peer.ID,
|
Id: peer.ID,
|
||||||
@@ -492,6 +506,18 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
|||||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||||
Ephemeral: peer.Ephemeral,
|
Ephemeral: peer.Ephemeral,
|
||||||
|
LocalFlags: &api.PeerLocalFlags{
|
||||||
|
BlockInbound: &peer.Meta.Flags.BlockInbound,
|
||||||
|
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
|
||||||
|
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
|
||||||
|
DisableDns: &peer.Meta.Flags.DisableDNS,
|
||||||
|
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
|
||||||
|
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
|
||||||
|
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
|
||||||
|
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
|
||||||
|
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
|
||||||
|
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -221,6 +221,8 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
|||||||
pr.Protocol = types.PolicyRuleProtocolUDP
|
pr.Protocol = types.PolicyRuleProtocolUDP
|
||||||
case api.PolicyRuleUpdateProtocolIcmp:
|
case api.PolicyRuleUpdateProtocolIcmp:
|
||||||
pr.Protocol = types.PolicyRuleProtocolICMP
|
pr.Protocol = types.PolicyRuleProtocolICMP
|
||||||
|
case api.PolicyRuleUpdateProtocolNetbirdSsh:
|
||||||
|
pr.Protocol = types.PolicyRuleProtocolNetbirdSSH
|
||||||
default:
|
default:
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
|
||||||
return
|
return
|
||||||
@@ -254,6 +256,17 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if pr.Protocol == types.PolicyRuleProtocolNetbirdSSH && rule.AuthorizedGroups != nil && len(*rule.AuthorizedGroups) != 0 {
|
||||||
|
for _, sourceGroupID := range pr.Sources {
|
||||||
|
_, ok := (*rule.AuthorizedGroups)[sourceGroupID]
|
||||||
|
if !ok {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "authorized group for netbird-ssh protocol should be specified for each source group"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pr.AuthorizedGroups = *rule.AuthorizedGroups
|
||||||
|
}
|
||||||
|
|
||||||
// validate policy object
|
// validate policy object
|
||||||
if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP {
|
if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP {
|
||||||
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
|
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
|
||||||
@@ -380,6 +393,11 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
|
|||||||
DestinationResource: r.DestinationResource.ToAPIResponse(),
|
DestinationResource: r.DestinationResource.ToAPIResponse(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(r.AuthorizedGroups) != 0 {
|
||||||
|
authorizedGroupsCopy := r.AuthorizedGroups
|
||||||
|
rule.AuthorizedGroups = &authorizedGroupsCopy
|
||||||
|
}
|
||||||
|
|
||||||
if len(r.Ports) != 0 {
|
if len(r.Ports) != 0 {
|
||||||
portsCopy := r.Ports
|
portsCopy := r.Ports
|
||||||
rule.Ports = &portsCopy
|
rule.Ports = &portsCopy
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
|
|||||||
|
|
||||||
// fetch all the peers that have access to the user's peers
|
// fetch all the peers that have access to the user's peers
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
|
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers())
|
||||||
for _, p := range aclPeers {
|
for _, p := range aclPeers {
|
||||||
peersMap[p.ID] = p
|
peersMap[p.ID] = p
|
||||||
}
|
}
|
||||||
@@ -1057,7 +1057,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range userPeers {
|
for _, p := range userPeers {
|
||||||
aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
|
aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers())
|
||||||
for _, aclPeer := range aclPeers {
|
for _, aclPeer := range aclPeers {
|
||||||
if aclPeer.ID == peer.ID {
|
if aclPeer.ID == peer.ID {
|
||||||
return peer, nil
|
return peer, nil
|
||||||
|
|||||||
@@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("check that all peers get map", func(t *testing.T) {
|
t.Run("check that all peers get map", func(t *testing.T) {
|
||||||
for _, p := range account.Peers {
|
for _, p := range account.Peers {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers())
|
||||||
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
|
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
|
||||||
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
|
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check first peer map details", func(t *testing.T) {
|
t.Run("check first peer map details", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 8)
|
assert.Len(t, peers, 8)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
assert.Contains(t, peers, account.Peers["peerC"])
|
assert.Contains(t, peers, account.Peers["peerC"])
|
||||||
@@ -509,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check port ranges support for older peers", func(t *testing.T) {
|
t.Run("check port ranges support for older peers", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 1)
|
assert.Len(t, peers, 1)
|
||||||
assert.Contains(t, peers, account.Peers["peerI"])
|
assert.Contains(t, peers, account.Peers["peerI"])
|
||||||
|
|
||||||
@@ -635,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Run("check first peer map", func(t *testing.T) {
|
t.Run("check first peer map", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Contains(t, peers, account.Peers["peerC"])
|
assert.Contains(t, peers, account.Peers["peerC"])
|
||||||
|
|
||||||
expectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
@@ -665,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check second peer map", func(t *testing.T) {
|
t.Run("check second peer map", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Contains(t, peers, account.Peers["peerB"])
|
assert.Contains(t, peers, account.Peers["peerB"])
|
||||||
|
|
||||||
expectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
@@ -697,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
account.Policies[1].Rules[0].Bidirectional = false
|
account.Policies[1].Rules[0].Bidirectional = false
|
||||||
|
|
||||||
t.Run("check first peer map directional only", func(t *testing.T) {
|
t.Run("check first peer map directional only", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Contains(t, peers, account.Peers["peerC"])
|
assert.Contains(t, peers, account.Peers["peerC"])
|
||||||
|
|
||||||
expectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
@@ -719,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check second peer map directional only", func(t *testing.T) {
|
t.Run("check second peer map directional only", func(t *testing.T) {
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Contains(t, peers, account.Peers["peerB"])
|
assert.Contains(t, peers, account.Peers["peerB"])
|
||||||
|
|
||||||
expectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
@@ -917,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
||||||
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
||||||
// will establish a connection with all source peers satisfying the NB posture check.
|
// will establish a connection with all source peers satisfying the NB posture check.
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 4)
|
assert.Len(t, peers, 4)
|
||||||
assert.Len(t, firewallRules, 4)
|
assert.Len(t, firewallRules, 4)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
@@ -927,7 +927,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||||
// We expect a single permissive firewall rule which all outgoing connections
|
// We expect a single permissive firewall rule which all outgoing connections
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||||
assert.Len(t, firewallRules, 7)
|
assert.Len(t, firewallRules, 7)
|
||||||
expectedFirewallRules := []*types.FirewallRule{
|
expectedFirewallRules := []*types.FirewallRule{
|
||||||
@@ -992,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||||
// all source group peers satisfying the NB posture check should establish connection
|
// all source group peers satisfying the NB posture check should establish connection
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 4)
|
assert.Len(t, peers, 4)
|
||||||
assert.Len(t, firewallRules, 4)
|
assert.Len(t, firewallRules, 4)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
@@ -1002,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
||||||
// all source group peers satisfying the NB posture check should establish connection
|
// all source group peers satisfying the NB posture check should establish connection
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 4)
|
assert.Len(t, peers, 4)
|
||||||
assert.Len(t, firewallRules, 4)
|
assert.Len(t, firewallRules, 4)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
@@ -1017,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
||||||
// no connection should be established to any peer of destination group
|
// no connection should be established to any peer of destination group
|
||||||
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
|
peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 0)
|
assert.Len(t, peers, 0)
|
||||||
assert.Len(t, firewallRules, 0)
|
assert.Len(t, firewallRules, 0)
|
||||||
|
|
||||||
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
||||||
// no connection should be established to any peer of destination group
|
// no connection should be established to any peer of destination group
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 0)
|
assert.Len(t, peers, 0)
|
||||||
assert.Len(t, firewallRules, 0)
|
assert.Len(t, firewallRules, 0)
|
||||||
|
|
||||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||||
// We expect a single permissive firewall rule which all outgoing connections
|
// We expect a single permissive firewall rule which all outgoing connections
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||||
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
||||||
|
|
||||||
@@ -1044,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
|||||||
|
|
||||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||||
// all source group peers satisfying the NB posture check should establish connection
|
// all source group peers satisfying the NB posture check should establish connection
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 3)
|
assert.Len(t, peers, 3)
|
||||||
assert.Len(t, firewallRules, 3)
|
assert.Len(t, firewallRules, 3)
|
||||||
assert.Contains(t, peers, account.Peers["peerA"])
|
assert.Contains(t, peers, account.Peers["peerA"])
|
||||||
assert.Contains(t, peers, account.Peers["peerC"])
|
assert.Contains(t, peers, account.Peers["peerC"])
|
||||||
assert.Contains(t, peers, account.Peers["peerD"])
|
assert.Contains(t, peers, account.Peers["peerD"])
|
||||||
|
|
||||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
|
peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, peers, 5)
|
assert.Len(t, peers, 5)
|
||||||
// assert peers from Group Swarm
|
// assert peers from Group Swarm
|
||||||
assert.Contains(t, peers, account.Peers["peerD"])
|
assert.Contains(t, peers, account.Peers["peerD"])
|
||||||
|
|||||||
@@ -1910,16 +1910,16 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
|
|||||||
if len(policyIDs) == 0 {
|
if len(policyIDs) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)`
|
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges, authorized_groups, authorized_user FROM policy_rules WHERE policy_id = ANY($1)`
|
||||||
rows, err := s.pool.Query(ctx, query, policyIDs)
|
rows, err := s.pool.Query(ctx, query, policyIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
|
rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
|
||||||
var r types.PolicyRule
|
var r types.PolicyRule
|
||||||
var dest, destRes, sources, sourceRes, ports, portRanges []byte
|
var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte
|
||||||
var enabled, bidirectional sql.NullBool
|
var enabled, bidirectional sql.NullBool
|
||||||
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges)
|
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &r.AuthorizedUser)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if enabled.Valid {
|
if enabled.Valid {
|
||||||
r.Enabled = enabled.Bool
|
r.Enabled = enabled.Bool
|
||||||
@@ -1945,6 +1945,9 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
|
|||||||
if portRanges != nil {
|
if portRanges != nil {
|
||||||
_ = json.Unmarshal(portRanges, &r.PortRanges)
|
_ = json.Unmarshal(portRanges, &r.PortRanges)
|
||||||
}
|
}
|
||||||
|
if authorizedGroups != nil {
|
||||||
|
_ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return &r, err
|
return &r, err
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
@@ -45,8 +46,10 @@ const (
|
|||||||
|
|
||||||
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
|
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
|
||||||
nativeSSHPortString = "22022"
|
nativeSSHPortString = "22022"
|
||||||
|
nativeSSHPortNumber = 22022
|
||||||
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
|
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
|
||||||
defaultSSHPortString = "22"
|
defaultSSHPortString = "22"
|
||||||
|
defaultSSHPortNumber = 22
|
||||||
)
|
)
|
||||||
|
|
||||||
type supportedFeatures struct {
|
type supportedFeatures struct {
|
||||||
@@ -275,6 +278,7 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
resourcePolicies map[string][]*Policy,
|
resourcePolicies map[string][]*Policy,
|
||||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||||
metrics *telemetry.AccountManagerMetrics,
|
metrics *telemetry.AccountManagerMetrics,
|
||||||
|
groupIDToUserIDs map[string][]string,
|
||||||
) *NetworkMap {
|
) *NetworkMap {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
peer := a.Peers[peerID]
|
peer := a.Peers[peerID]
|
||||||
@@ -290,7 +294,7 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
|
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
|
||||||
// exclude expired peers
|
// exclude expired peers
|
||||||
var peersToConnect []*nbpeer.Peer
|
var peersToConnect []*nbpeer.Peer
|
||||||
var expiredPeers []*nbpeer.Peer
|
var expiredPeers []*nbpeer.Peer
|
||||||
@@ -338,6 +342,8 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
OfflinePeers: expiredPeers,
|
OfflinePeers: expiredPeers,
|
||||||
FirewallRules: firewallRules,
|
FirewallRules: firewallRules,
|
||||||
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
||||||
|
AuthorizedUsers: authorizedUsers,
|
||||||
|
EnableSSH: enableSSH,
|
||||||
}
|
}
|
||||||
|
|
||||||
if metrics != nil {
|
if metrics != nil {
|
||||||
@@ -1009,8 +1015,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
|
|||||||
// GetPeerConnectionResources for a given peer
|
// GetPeerConnectionResources for a given peer
|
||||||
//
|
//
|
||||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
|
||||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
||||||
|
authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs
|
||||||
|
sshEnabled := false
|
||||||
|
|
||||||
for _, policy := range a.Policies {
|
for _, policy := range a.Policies {
|
||||||
if !policy.Enabled {
|
if !policy.Enabled {
|
||||||
@@ -1053,10 +1061,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
|
|||||||
if peerInDestinations {
|
if peerInDestinations {
|
||||||
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||||
|
sshEnabled = true
|
||||||
|
switch {
|
||||||
|
case len(rule.AuthorizedGroups) > 0:
|
||||||
|
for groupID, localUsers := range rule.AuthorizedGroups {
|
||||||
|
userIDs, ok := groupIDToUserIDs[groupID]
|
||||||
|
if !ok {
|
||||||
|
log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(localUsers) == 0 {
|
||||||
|
localUsers = []string{auth.Wildcard}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, localUser := range localUsers {
|
||||||
|
if authorizedUsers[localUser] == nil {
|
||||||
|
authorizedUsers[localUser] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
for _, userID := range userIDs {
|
||||||
|
authorizedUsers[localUser][userID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case rule.AuthorizedUser != "":
|
||||||
|
if authorizedUsers[auth.Wildcard] == nil {
|
||||||
|
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
|
||||||
|
default:
|
||||||
|
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||||
|
}
|
||||||
|
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
|
||||||
|
sshEnabled = true
|
||||||
|
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return getAccumulatedResources()
|
peers, fwRules := getAccumulatedResources()
|
||||||
|
return peers, fwRules, authorizedUsers, sshEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) getAllowedUserIDs() map[string]struct{} {
|
||||||
|
users := make(map[string]struct{})
|
||||||
|
for _, nbUser := range a.Users {
|
||||||
|
if !nbUser.IsBlocked() && !nbUser.IsServiceUser {
|
||||||
|
users[nbUser.Id] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return users
|
||||||
}
|
}
|
||||||
|
|
||||||
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
||||||
@@ -1081,12 +1137,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
|||||||
peersExists[peer.ID] = struct{}{}
|
peersExists[peer.ID] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protocol := rule.Protocol
|
||||||
|
if protocol == PolicyRuleProtocolNetbirdSSH {
|
||||||
|
protocol = PolicyRuleProtocolTCP
|
||||||
|
}
|
||||||
|
|
||||||
fr := FirewallRule{
|
fr := FirewallRule{
|
||||||
PolicyID: rule.ID,
|
PolicyID: rule.ID,
|
||||||
PeerIP: peer.IP.String(),
|
PeerIP: peer.IP.String(),
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
Action: string(rule.Action),
|
Action: string(rule.Action),
|
||||||
Protocol: string(rule.Protocol),
|
Protocol: string(protocol),
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||||
@@ -1108,6 +1169,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||||
|
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
|
||||||
|
for _, pr := range portRanges {
|
||||||
|
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func portsIncludesSSH(ports []string) bool {
|
||||||
|
for _, port := range ports {
|
||||||
|
if port == defaultSSHPortString || port == nativeSSHPortString {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// getAllPeersFromGroups for given peer ID and list of groups
|
// getAllPeersFromGroups for given peer ID and list of groups
|
||||||
//
|
//
|
||||||
// Returns a list of peers from specified groups that pass specified posture checks
|
// Returns a list of peers from specified groups that pass specified posture checks
|
||||||
@@ -1660,6 +1743,26 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetActiveGroupUsers() map[string][]string {
|
||||||
|
allGroupID := ""
|
||||||
|
group, err := a.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get group all: %v", err)
|
||||||
|
} else {
|
||||||
|
allGroupID = group.ID
|
||||||
|
}
|
||||||
|
groups := make(map[string][]string, len(a.GroupsG))
|
||||||
|
for _, user := range a.Users {
|
||||||
|
if !user.IsBlocked() && !user.IsServiceUser {
|
||||||
|
for _, groupID := range user.AutoGroups {
|
||||||
|
groups[groupID] = append(groups[groupID], user.Id)
|
||||||
|
}
|
||||||
|
groups[allGroupID] = append(groups[allGroupID], user.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return groups
|
||||||
|
}
|
||||||
|
|
||||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||||
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
|
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
|
||||||
@@ -1691,7 +1794,7 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer
|
|||||||
expanded = append(expanded, &fr)
|
expanded = append(expanded, &fr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) {
|
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||||
expanded = addNativeSSHRule(base, expanded)
|
expanded = addNativeSSHRule(base, expanded)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1105,6 +1105,193 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_GetActiveGroupUsers(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account *Account
|
||||||
|
expected map[string][]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "all users are active",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{
|
||||||
|
"user1": {
|
||||||
|
Id: "user1",
|
||||||
|
AutoGroups: []string{"group1", "group2"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user2": {
|
||||||
|
Id: "user2",
|
||||||
|
AutoGroups: []string{"group2", "group3"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user3": {
|
||||||
|
Id: "user3",
|
||||||
|
AutoGroups: []string{"group1"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"group1": {"user1", "user3"},
|
||||||
|
"group2": {"user1", "user2"},
|
||||||
|
"group3": {"user2"},
|
||||||
|
"": {"user1", "user2", "user3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "some users are blocked",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{
|
||||||
|
"user1": {
|
||||||
|
Id: "user1",
|
||||||
|
AutoGroups: []string{"group1", "group2"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user2": {
|
||||||
|
Id: "user2",
|
||||||
|
AutoGroups: []string{"group2", "group3"},
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
"user3": {
|
||||||
|
Id: "user3",
|
||||||
|
AutoGroups: []string{"group1", "group3"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"group1": {"user1", "user3"},
|
||||||
|
"group2": {"user1"},
|
||||||
|
"group3": {"user3"},
|
||||||
|
"": {"user1", "user3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all users are blocked",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{
|
||||||
|
"user1": {
|
||||||
|
Id: "user1",
|
||||||
|
AutoGroups: []string{"group1"},
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
"user2": {
|
||||||
|
Id: "user2",
|
||||||
|
AutoGroups: []string{"group2"},
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user with no auto groups",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{
|
||||||
|
"user1": {
|
||||||
|
Id: "user1",
|
||||||
|
AutoGroups: []string{},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user2": {
|
||||||
|
Id: "user2",
|
||||||
|
AutoGroups: []string{"group1"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"group1": {"user2"},
|
||||||
|
"": {"user1", "user2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty account",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple users in same group",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{
|
||||||
|
"user1": {
|
||||||
|
Id: "user1",
|
||||||
|
AutoGroups: []string{"group1"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user2": {
|
||||||
|
Id: "user2",
|
||||||
|
AutoGroups: []string{"group1"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user3": {
|
||||||
|
Id: "user3",
|
||||||
|
AutoGroups: []string{"group1"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"group1": {"user1", "user2", "user3"},
|
||||||
|
"": {"user1", "user2", "user3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user in multiple groups with blocked users",
|
||||||
|
account: &Account{
|
||||||
|
Users: map[string]*User{
|
||||||
|
"user1": {
|
||||||
|
Id: "user1",
|
||||||
|
AutoGroups: []string{"group1", "group2", "group3"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
"user2": {
|
||||||
|
Id: "user2",
|
||||||
|
AutoGroups: []string{"group1", "group2"},
|
||||||
|
Blocked: true,
|
||||||
|
},
|
||||||
|
"user3": {
|
||||||
|
Id: "user3",
|
||||||
|
AutoGroups: []string{"group3"},
|
||||||
|
Blocked: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: map[string][]string{
|
||||||
|
"group1": {"user1"},
|
||||||
|
"group2": {"user1"},
|
||||||
|
"group3": {"user1", "user3"},
|
||||||
|
"": {"user1", "user3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.account.GetActiveGroupUsers()
|
||||||
|
|
||||||
|
// Check that the number of groups matches
|
||||||
|
assert.Equal(t, len(tt.expected), len(result), "number of groups should match")
|
||||||
|
|
||||||
|
// Check each group's users
|
||||||
|
for groupID, expectedUsers := range tt.expected {
|
||||||
|
actualUsers, exists := result[groupID]
|
||||||
|
assert.True(t, exists, "group %s should exist in result", groupID)
|
||||||
|
assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure no extra groups in result
|
||||||
|
for groupID := range result {
|
||||||
|
_, exists := tt.expected[groupID]
|
||||||
|
assert.True(t, exists, "unexpected group %s in result", groupID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -38,6 +38,8 @@ type NetworkMap struct {
|
|||||||
FirewallRules []*FirewallRule
|
FirewallRules []*FirewallRule
|
||||||
RoutesFirewallRules []*RouteFirewallRule
|
RoutesFirewallRules []*RouteFirewallRule
|
||||||
ForwardingRules []*ForwardingRule
|
ForwardingRules []*ForwardingRule
|
||||||
|
AuthorizedUsers map[string]map[string]struct{}
|
||||||
|
EnableSSH bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nm *NetworkMap) Merge(other *NetworkMap) {
|
func (nm *NetworkMap) Merge(other *NetworkMap) {
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(networkMap)
|
normalizeAndSortNetworkMap(networkMap)
|
||||||
|
|
||||||
@@ -141,7 +141,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
|||||||
b.Run("old builder", func(b *testing.B) {
|
b.Run("old builder", func(b *testing.B) {
|
||||||
for range b.N {
|
for range b.N {
|
||||||
for _, peerID := range peerIDs {
|
for _, peerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -201,7 +201,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(networkMap)
|
normalizeAndSortNetworkMap(networkMap)
|
||||||
|
|
||||||
@@ -320,7 +320,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
|||||||
b.Run("old builder after add", func(b *testing.B) {
|
b.Run("old builder after add", func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -395,7 +395,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(networkMap)
|
normalizeAndSortNetworkMap(networkMap)
|
||||||
|
|
||||||
@@ -550,7 +550,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
|||||||
b.Run("old builder after add", func(b *testing.B) {
|
b.Run("old builder after add", func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -604,7 +604,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(networkMap)
|
normalizeAndSortNetworkMap(networkMap)
|
||||||
|
|
||||||
@@ -730,7 +730,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
|||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(networkMap)
|
normalizeAndSortNetworkMap(networkMap)
|
||||||
|
|
||||||
@@ -847,7 +847,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
|||||||
b.Run("old builder after delete", func(b *testing.B) {
|
b.Run("old builder after delete", func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, testingPeerID := range peerIDs {
|
for _, testingPeerID := range peerIDs {
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ const (
|
|||||||
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
|
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
|
||||||
// PolicyRuleProtocolICMP type of traffic
|
// PolicyRuleProtocolICMP type of traffic
|
||||||
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
|
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
|
||||||
|
// PolicyRuleProtocolNetbirdSSH type of traffic
|
||||||
|
PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error)
|
|||||||
protocol = PolicyRuleProtocolUDP
|
protocol = PolicyRuleProtocolUDP
|
||||||
case "icmp":
|
case "icmp":
|
||||||
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
|
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
|
||||||
|
case "netbird-ssh":
|
||||||
|
return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil
|
||||||
default:
|
default:
|
||||||
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
|
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,6 +80,12 @@ type PolicyRule struct {
|
|||||||
|
|
||||||
// PortRanges a list of port ranges.
|
// PortRanges a list of port ranges.
|
||||||
PortRanges []RulePortRange `gorm:"serializer:json"`
|
PortRanges []RulePortRange `gorm:"serializer:json"`
|
||||||
|
|
||||||
|
// AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh
|
||||||
|
AuthorizedGroups map[string][]string `gorm:"serializer:json"`
|
||||||
|
|
||||||
|
// AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh
|
||||||
|
AuthorizedUser string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy returns a copy of a policy rule
|
// Copy returns a copy of a policy rule
|
||||||
@@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
|||||||
Protocol: pm.Protocol,
|
Protocol: pm.Protocol,
|
||||||
Ports: make([]string, len(pm.Ports)),
|
Ports: make([]string, len(pm.Ports)),
|
||||||
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
|
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
|
||||||
|
AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)),
|
||||||
|
AuthorizedUser: pm.AuthorizedUser,
|
||||||
}
|
}
|
||||||
copy(rule.Destinations, pm.Destinations)
|
copy(rule.Destinations, pm.Destinations)
|
||||||
copy(rule.Sources, pm.Sources)
|
copy(rule.Sources, pm.Sources)
|
||||||
copy(rule.Ports, pm.Ports)
|
copy(rule.Ports, pm.Ports)
|
||||||
copy(rule.PortRanges, pm.PortRanges)
|
copy(rule.PortRanges, pm.PortRanges)
|
||||||
|
for k, v := range pm.AuthorizedGroups {
|
||||||
|
rule.AuthorizedGroups[k] = make([]string, len(v))
|
||||||
|
copy(rule.AuthorizedGroups[k], v)
|
||||||
|
}
|
||||||
return rule
|
return rule
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -523,16 +523,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
_, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
||||||
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
|
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
|
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if userHadPeers {
|
updateAccountPeers = true
|
||||||
updateAccountPeers = true
|
|
||||||
}
|
|
||||||
|
|
||||||
err = transaction.SaveUser(ctx, updatedUser)
|
err = transaction.SaveUser(ctx, updatedUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -581,7 +579,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings.GroupsPropagationEnabled && updateAccountPeers {
|
if updateAccountPeers {
|
||||||
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
|
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||||
return nil, fmt.Errorf("failed to increment network serial: %w", err)
|
return nil, fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1379,11 +1379,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
|||||||
updateManager.CloseChannel(context.Background(), peer1.ID)
|
updateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Creating a new regular user should not update account peers and not send peer update
|
// Creating a new regular user should send peer update (as users are not filtered yet)
|
||||||
t.Run("creating new regular user with no groups", func(t *testing.T) {
|
t.Run("creating new regular user with no groups", func(t *testing.T) {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
peerShouldNotReceiveUpdate(t, updMsg)
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -1402,11 +1402,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// updating user with no linked peers should not update account peers and not send peer update
|
// updating user with no linked peers should update account peers and send peer update (as users are not filtered yet)
|
||||||
t.Run("updating user with no linked peers", func(t *testing.T) {
|
t.Run("updating user with no linked peers", func(t *testing.T) {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
peerShouldNotReceiveUpdate(t, updMsg)
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
@@ -488,6 +488,8 @@ components:
|
|||||||
description: Indicates whether the peer is ephemeral or not
|
description: Indicates whether the peer is ephemeral or not
|
||||||
type: boolean
|
type: boolean
|
||||||
example: false
|
example: false
|
||||||
|
local_flags:
|
||||||
|
$ref: '#/components/schemas/PeerLocalFlags'
|
||||||
required:
|
required:
|
||||||
- city_name
|
- city_name
|
||||||
- connected
|
- connected
|
||||||
@@ -514,6 +516,49 @@ components:
|
|||||||
- serial_number
|
- serial_number
|
||||||
- extra_dns_labels
|
- extra_dns_labels
|
||||||
- ephemeral
|
- ephemeral
|
||||||
|
PeerLocalFlags:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
rosenpass_enabled:
|
||||||
|
description: Indicates whether Rosenpass is enabled on this peer
|
||||||
|
type: boolean
|
||||||
|
example: true
|
||||||
|
rosenpass_permissive:
|
||||||
|
description: Indicates whether Rosenpass is in permissive mode or not
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
server_ssh_allowed:
|
||||||
|
description: Indicates whether SSH access this peer is allowed or not
|
||||||
|
type: boolean
|
||||||
|
example: true
|
||||||
|
disable_client_routes:
|
||||||
|
description: Indicates whether client routes are disabled on this peer or not
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
disable_server_routes:
|
||||||
|
description: Indicates whether server routes are disabled on this peer or not
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
disable_dns:
|
||||||
|
description: Indicates whether DNS management is disabled on this peer or not
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
disable_firewall:
|
||||||
|
description: Indicates whether firewall management is disabled on this peer or not
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
block_lan_access:
|
||||||
|
description: Indicates whether LAN access is blocked on this peer when used as a routing peer
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
block_inbound:
|
||||||
|
description: Indicates whether inbound traffic is blocked on this peer
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
|
lazy_connection_enabled:
|
||||||
|
description: Indicates whether lazy connection is enabled on this peer
|
||||||
|
type: boolean
|
||||||
|
example: false
|
||||||
PeerTemporaryAccessRequest:
|
PeerTemporaryAccessRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
@@ -936,7 +981,7 @@ components:
|
|||||||
protocol:
|
protocol:
|
||||||
description: Policy rule type of the traffic
|
description: Policy rule type of the traffic
|
||||||
type: string
|
type: string
|
||||||
enum: ["all", "tcp", "udp", "icmp"]
|
enum: ["all", "tcp", "udp", "icmp", "netbird-ssh"]
|
||||||
example: "tcp"
|
example: "tcp"
|
||||||
ports:
|
ports:
|
||||||
description: Policy rule affected ports
|
description: Policy rule affected ports
|
||||||
@@ -949,6 +994,14 @@ components:
|
|||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
$ref: '#/components/schemas/RulePortRange'
|
$ref: '#/components/schemas/RulePortRange'
|
||||||
|
authorized_groups:
|
||||||
|
description: Map of user group ids to a list of local users
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
example: "group1"
|
||||||
required:
|
required:
|
||||||
- name
|
- name
|
||||||
- enabled
|
- enabled
|
||||||
|
|||||||
@@ -130,10 +130,11 @@ const (
|
|||||||
|
|
||||||
// Defines values for PolicyRuleProtocol.
|
// Defines values for PolicyRuleProtocol.
|
||||||
const (
|
const (
|
||||||
PolicyRuleProtocolAll PolicyRuleProtocol = "all"
|
PolicyRuleProtocolAll PolicyRuleProtocol = "all"
|
||||||
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
|
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
|
||||||
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
|
PolicyRuleProtocolNetbirdSsh PolicyRuleProtocol = "netbird-ssh"
|
||||||
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
|
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
|
||||||
|
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines values for PolicyRuleMinimumAction.
|
// Defines values for PolicyRuleMinimumAction.
|
||||||
@@ -144,10 +145,11 @@ const (
|
|||||||
|
|
||||||
// Defines values for PolicyRuleMinimumProtocol.
|
// Defines values for PolicyRuleMinimumProtocol.
|
||||||
const (
|
const (
|
||||||
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
|
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
|
||||||
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
|
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
|
||||||
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
|
PolicyRuleMinimumProtocolNetbirdSsh PolicyRuleMinimumProtocol = "netbird-ssh"
|
||||||
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
|
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
|
||||||
|
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines values for PolicyRuleUpdateAction.
|
// Defines values for PolicyRuleUpdateAction.
|
||||||
@@ -158,10 +160,11 @@ const (
|
|||||||
|
|
||||||
// Defines values for PolicyRuleUpdateProtocol.
|
// Defines values for PolicyRuleUpdateProtocol.
|
||||||
const (
|
const (
|
||||||
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
|
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
|
||||||
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
|
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
|
||||||
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
|
PolicyRuleUpdateProtocolNetbirdSsh PolicyRuleUpdateProtocol = "netbird-ssh"
|
||||||
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
|
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
|
||||||
|
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines values for ResourceType.
|
// Defines values for ResourceType.
|
||||||
@@ -1077,7 +1080,8 @@ type Peer struct {
|
|||||||
LastLogin time.Time `json:"last_login"`
|
LastLogin time.Time `json:"last_login"`
|
||||||
|
|
||||||
// LastSeen Last time peer connected to Netbird's management service
|
// LastSeen Last time peer connected to Netbird's management service
|
||||||
LastSeen time.Time `json:"last_seen"`
|
LastSeen time.Time `json:"last_seen"`
|
||||||
|
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
|
||||||
|
|
||||||
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
||||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||||
@@ -1167,7 +1171,8 @@ type PeerBatch struct {
|
|||||||
LastLogin time.Time `json:"last_login"`
|
LastLogin time.Time `json:"last_login"`
|
||||||
|
|
||||||
// LastSeen Last time peer connected to Netbird's management service
|
// LastSeen Last time peer connected to Netbird's management service
|
||||||
LastSeen time.Time `json:"last_seen"`
|
LastSeen time.Time `json:"last_seen"`
|
||||||
|
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
|
||||||
|
|
||||||
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
|
||||||
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
|
||||||
@@ -1197,6 +1202,39 @@ type PeerBatch struct {
|
|||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeerLocalFlags defines model for PeerLocalFlags.
|
||||||
|
type PeerLocalFlags struct {
|
||||||
|
// BlockInbound Indicates whether inbound traffic is blocked on this peer
|
||||||
|
BlockInbound *bool `json:"block_inbound,omitempty"`
|
||||||
|
|
||||||
|
// BlockLanAccess Indicates whether LAN access is blocked on this peer when used as a routing peer
|
||||||
|
BlockLanAccess *bool `json:"block_lan_access,omitempty"`
|
||||||
|
|
||||||
|
// DisableClientRoutes Indicates whether client routes are disabled on this peer or not
|
||||||
|
DisableClientRoutes *bool `json:"disable_client_routes,omitempty"`
|
||||||
|
|
||||||
|
// DisableDns Indicates whether DNS management is disabled on this peer or not
|
||||||
|
DisableDns *bool `json:"disable_dns,omitempty"`
|
||||||
|
|
||||||
|
// DisableFirewall Indicates whether firewall management is disabled on this peer or not
|
||||||
|
DisableFirewall *bool `json:"disable_firewall,omitempty"`
|
||||||
|
|
||||||
|
// DisableServerRoutes Indicates whether server routes are disabled on this peer or not
|
||||||
|
DisableServerRoutes *bool `json:"disable_server_routes,omitempty"`
|
||||||
|
|
||||||
|
// LazyConnectionEnabled Indicates whether lazy connection is enabled on this peer
|
||||||
|
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
|
||||||
|
|
||||||
|
// RosenpassEnabled Indicates whether Rosenpass is enabled on this peer
|
||||||
|
RosenpassEnabled *bool `json:"rosenpass_enabled,omitempty"`
|
||||||
|
|
||||||
|
// RosenpassPermissive Indicates whether Rosenpass is in permissive mode or not
|
||||||
|
RosenpassPermissive *bool `json:"rosenpass_permissive,omitempty"`
|
||||||
|
|
||||||
|
// ServerSshAllowed Indicates whether SSH access this peer is allowed or not
|
||||||
|
ServerSshAllowed *bool `json:"server_ssh_allowed,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// PeerMinimum defines model for PeerMinimum.
|
// PeerMinimum defines model for PeerMinimum.
|
||||||
type PeerMinimum struct {
|
type PeerMinimum struct {
|
||||||
// Id Peer ID
|
// Id Peer ID
|
||||||
@@ -1349,6 +1387,9 @@ type PolicyRule struct {
|
|||||||
// Action Policy rule accept or drops packets
|
// Action Policy rule accept or drops packets
|
||||||
Action PolicyRuleAction `json:"action"`
|
Action PolicyRuleAction `json:"action"`
|
||||||
|
|
||||||
|
// AuthorizedGroups Map of user group ids to a list of local users
|
||||||
|
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
|
||||||
|
|
||||||
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
||||||
Bidirectional bool `json:"bidirectional"`
|
Bidirectional bool `json:"bidirectional"`
|
||||||
|
|
||||||
@@ -1393,6 +1434,9 @@ type PolicyRuleMinimum struct {
|
|||||||
// Action Policy rule accept or drops packets
|
// Action Policy rule accept or drops packets
|
||||||
Action PolicyRuleMinimumAction `json:"action"`
|
Action PolicyRuleMinimumAction `json:"action"`
|
||||||
|
|
||||||
|
// AuthorizedGroups Map of user group ids to a list of local users
|
||||||
|
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
|
||||||
|
|
||||||
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
||||||
Bidirectional bool `json:"bidirectional"`
|
Bidirectional bool `json:"bidirectional"`
|
||||||
|
|
||||||
@@ -1426,6 +1470,9 @@ type PolicyRuleUpdate struct {
|
|||||||
// Action Policy rule accept or drops packets
|
// Action Policy rule accept or drops packets
|
||||||
Action PolicyRuleUpdateAction `json:"action"`
|
Action PolicyRuleUpdateAction `json:"action"`
|
||||||
|
|
||||||
|
// AuthorizedGroups Map of user group ids to a list of local users
|
||||||
|
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
|
||||||
|
|
||||||
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
|
||||||
Bidirectional bool `json:"bidirectional"`
|
Bidirectional bool `json:"bidirectional"`
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -332,6 +332,24 @@ message NetworkMap {
|
|||||||
bool routesFirewallRulesIsEmpty = 11;
|
bool routesFirewallRulesIsEmpty = 11;
|
||||||
|
|
||||||
repeated ForwardingRule forwardingRules = 12;
|
repeated ForwardingRule forwardingRules = 12;
|
||||||
|
|
||||||
|
// SSHAuth represents SSH authorization configuration
|
||||||
|
SSHAuth sshAuth = 13;
|
||||||
|
}
|
||||||
|
|
||||||
|
message SSHAuth {
|
||||||
|
// UserIDClaim is the JWT claim to be used to get the users ID
|
||||||
|
string UserIDClaim = 1;
|
||||||
|
|
||||||
|
// AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH
|
||||||
|
repeated bytes AuthorizedUsers = 2;
|
||||||
|
|
||||||
|
// MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list
|
||||||
|
map<string, MachineUserIndexes> machine_users = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message MachineUserIndexes {
|
||||||
|
repeated uint32 indexes = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemotePeerConfig represents a configuration of a remote peer.
|
// RemotePeerConfig represents a configuration of a remote peer.
|
||||||
|
|||||||
28
shared/sshauth/userhash.go
Normal file
28
shared/sshauth/userhash.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package sshauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/blake2b"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserIDHash represents a hashed user ID (BLAKE2b-128)
|
||||||
|
type UserIDHash [16]byte
|
||||||
|
|
||||||
|
// HashUserID hashes a user ID using BLAKE2b-128 and returns the hash value
|
||||||
|
// This function must produce the same hash on both client and management server
|
||||||
|
func HashUserID(userID string) (UserIDHash, error) {
|
||||||
|
hash, err := blake2b.New(16, nil)
|
||||||
|
if err != nil {
|
||||||
|
return UserIDHash{}, err
|
||||||
|
}
|
||||||
|
hash.Write([]byte(userID))
|
||||||
|
var result UserIDHash
|
||||||
|
copy(result[:], hash.Sum(nil))
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the hexadecimal string representation of the hash
|
||||||
|
func (h UserIDHash) String() string {
|
||||||
|
return hex.EncodeToString(h[:])
|
||||||
|
}
|
||||||
210
shared/sshauth/userhash_test.go
Normal file
210
shared/sshauth/userhash_test.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
package sshauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHashUserID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userID string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple user ID",
|
||||||
|
userID: "user@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UUID format",
|
||||||
|
userID: "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "numeric ID",
|
||||||
|
userID: "12345",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
userID: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "special characters",
|
||||||
|
userID: "user+test@domain.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unicode characters",
|
||||||
|
userID: "用户@example.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
hash, err := HashUserID(tt.userID)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("HashUserID() error = %v, want nil", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify hash is non-zero for non-empty inputs
|
||||||
|
if tt.userID != "" && hash == [16]byte{} {
|
||||||
|
t.Errorf("HashUserID() returned zero hash for non-empty input")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashUserID_Consistency(t *testing.T) {
|
||||||
|
userID := "test@example.com"
|
||||||
|
|
||||||
|
hash1, err1 := HashUserID(userID)
|
||||||
|
if err1 != nil {
|
||||||
|
t.Fatalf("First HashUserID() error = %v", err1)
|
||||||
|
}
|
||||||
|
|
||||||
|
hash2, err2 := HashUserID(userID)
|
||||||
|
if err2 != nil {
|
||||||
|
t.Fatalf("Second HashUserID() error = %v", err2)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hash1 != hash2 {
|
||||||
|
t.Errorf("HashUserID() is not consistent: got %v and %v for same input", hash1, hash2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashUserID_Uniqueness(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
userID1 string
|
||||||
|
userID2 string
|
||||||
|
}{
|
||||||
|
{"user1@example.com", "user2@example.com"},
|
||||||
|
{"alice@domain.com", "bob@domain.com"},
|
||||||
|
{"test", "test1"},
|
||||||
|
{"", "a"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
hash1, err1 := HashUserID(tt.userID1)
|
||||||
|
if err1 != nil {
|
||||||
|
t.Fatalf("HashUserID(%s) error = %v", tt.userID1, err1)
|
||||||
|
}
|
||||||
|
|
||||||
|
hash2, err2 := HashUserID(tt.userID2)
|
||||||
|
if err2 != nil {
|
||||||
|
t.Fatalf("HashUserID(%s) error = %v", tt.userID2, err2)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hash1 == hash2 {
|
||||||
|
t.Errorf("HashUserID() collision: %s and %s produced same hash %v", tt.userID1, tt.userID2, hash1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserIDHash_String(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
hash UserIDHash
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "zero hash",
|
||||||
|
hash: [16]byte{},
|
||||||
|
expected: "00000000000000000000000000000000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "small value",
|
||||||
|
hash: [16]byte{15: 0xff},
|
||||||
|
expected: "000000000000000000000000000000ff",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large value",
|
||||||
|
hash: [16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe},
|
||||||
|
expected: "0000000000000000deadbeefcafebabe",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "max value",
|
||||||
|
hash: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||||
|
expected: "ffffffffffffffffffffffffffffffff",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.hash.String()
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("UserIDHash.String() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserIDHash_String_Length(t *testing.T) {
|
||||||
|
// Test that String() always returns 32 hex characters (16 bytes * 2)
|
||||||
|
userID := "test@example.com"
|
||||||
|
hash, err := HashUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HashUserID() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := hash.String()
|
||||||
|
if len(result) != 32 {
|
||||||
|
t.Errorf("UserIDHash.String() length = %d, want 32", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's valid hex
|
||||||
|
for i, c := range result {
|
||||||
|
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||||
|
t.Errorf("UserIDHash.String() contains non-hex character at position %d: %c", i, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashUserID_KnownValues(t *testing.T) {
|
||||||
|
// Test with known BLAKE2b-128 values to ensure correct implementation
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userID string
|
||||||
|
expected UserIDHash
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
userID: "",
|
||||||
|
// BLAKE2b-128 of empty string
|
||||||
|
expected: [16]byte{0xca, 0xe6, 0x69, 0x41, 0xd9, 0xef, 0xbd, 0x40, 0x4e, 0x4d, 0x88, 0x75, 0x8e, 0xa6, 0x76, 0x70},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single character 'a'",
|
||||||
|
userID: "a",
|
||||||
|
// BLAKE2b-128 of "a"
|
||||||
|
expected: [16]byte{0x27, 0xc3, 0x5e, 0x6e, 0x93, 0x73, 0x87, 0x7f, 0x29, 0xe5, 0x62, 0x46, 0x4e, 0x46, 0x49, 0x7e},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
hash, err := HashUserID(tt.userID)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("HashUserID() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hash != tt.expected {
|
||||||
|
t.Errorf("HashUserID(%q) = %x, want %x",
|
||||||
|
tt.userID, hash, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHashUserID(b *testing.B) {
|
||||||
|
userID := "user@example.com"
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = HashUserID(userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUserIDHash_String(b *testing.B) {
|
||||||
|
hash := UserIDHash([16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe})
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = hash.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user