[client, management] Feature/ssh fine grained access (#4969)

Add fine-grained SSH access control with authorized users/groups
This commit is contained in:
Zoltan Papp
2025-12-29 12:50:41 +01:00
committed by GitHub
parent 73201c4f3e
commit 67f7b2404e
32 changed files with 2345 additions and 512 deletions

View File

@@ -1151,6 +1151,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil { if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
log.Warnf("failed to update SSH client config: %v", err) log.Warnf("failed to update SSH client config: %v", err)
} }
e.updateSSHServerAuth(networkMap.GetSshAuth())
} }
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store // must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store

View File

@@ -11,15 +11,18 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config" sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server" sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
) )
type sshServer interface { type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error Start(ctx context.Context, addr netip.AddrPort) error
Stop() error Stop() error
GetStatus() (bool, []sshserver.SessionInfo) GetStatus() (bool, []sshserver.SessionInfo)
UpdateSSHAuth(config *sshauth.Config)
} }
func (e *Engine) setupSSHPortRedirection() error { func (e *Engine) setupSSHPortRedirection() error {
@@ -353,3 +356,38 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio
return sshServer.GetStatus() return sshServer.GetStatus()
} }
// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server
func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) {
if sshAuth == nil {
return
}
if e.sshServer == nil {
return
}
protoUsers := sshAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range sshAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
// Update SSH server with new authorization configuration
authConfig := &sshauth.Config{
UserIDClaim: sshAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
}
e.sshServer.UpdateSSHAuth(authConfig)
}

184
client/ssh/auth/auth.go Normal file
View File

@@ -0,0 +1,184 @@
package auth
import (
"errors"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
const (
// DefaultUserIDClaim is the default JWT claim used to extract user IDs
DefaultUserIDClaim = "sub"
// Wildcard is a special user ID that matches all users
Wildcard = "*"
)
var (
ErrEmptyUserID = errors.New("JWT user ID is empty")
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
)
// Authorizer handles SSH fine-grained access control authorization
type Authorizer struct {
// UserIDClaim is the JWT claim to extract the user ID from
userIDClaim string
// authorizedUsers is a list of hashed user IDs authorized to access this peer
authorizedUsers []sshuserhash.UserIDHash
// machineUsers maps OS login usernames to lists of authorized user indexes
machineUsers map[string][]uint32
// mu protects the list of users
mu sync.RWMutex
}
// Config contains configuration for the SSH authorizer
type Config struct {
// UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email")
UserIDClaim string
// AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer
AuthorizedUsers []sshuserhash.UserIDHash
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
// If a user wants to login as a specific OS user, their index must be in the corresponding list
MachineUsers map[string][]uint32
}
// NewAuthorizer creates a new SSH authorizer with empty configuration
func NewAuthorizer() *Authorizer {
a := &Authorizer{
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
}
return a
}
// Update updates the authorizer configuration with new values
func (a *Authorizer) Update(config *Config) {
a.mu.Lock()
defer a.mu.Unlock()
if config == nil {
// Clear authorization
a.userIDClaim = DefaultUserIDClaim
a.authorizedUsers = []sshuserhash.UserIDHash{}
a.machineUsers = make(map[string][]uint32)
log.Info("SSH authorization cleared")
return
}
userIDClaim := config.UserIDClaim
if userIDClaim == "" {
userIDClaim = DefaultUserIDClaim
}
a.userIDClaim = userIDClaim
// Store authorized users list
a.authorizedUsers = config.AuthorizedUsers
// Store machine users mapping
machineUsers := make(map[string][]uint32)
for osUser, indexes := range config.MachineUsers {
if len(indexes) > 0 {
machineUsers[osUser] = indexes
}
}
a.machineUsers = machineUsers
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
len(config.AuthorizedUsers), len(machineUsers))
}
// Authorize validates if a user is authorized to login as the specified OS user
// Returns nil if authorized, or an error describing why authorization failed
func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
if jwtUserID == "" {
log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername)
return ErrEmptyUserID
}
// Hash the JWT user ID for comparison
hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
if err != nil {
log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err)
return fmt.Errorf("failed to hash user ID: %w", err)
}
a.mu.RLock()
defer a.mu.RUnlock()
// Find the index of this user in the authorized list
userIndex, found := a.findUserIndex(hashedUserID)
if !found {
log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername)
return ErrUserNotAuthorized
}
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
}
// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
// Checks wildcard mapping first, then specific OS user mappings
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error {
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex)
return nil
}
}
// Check for specific OS username mapping
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
if !hasMachineUserMapping {
// No mapping for this OS user - deny by default (fail closed)
log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID)
return ErrNoMachineUserMapping
}
// Check if user's index is in the allowed indexes for this specific OS user
if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex)
return ErrUserNotMappedToOSUser
}
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex)
return nil
}
// GetUserIDClaim returns the JWT claim name used to extract user IDs
func (a *Authorizer) GetUserIDClaim() string {
a.mu.RLock()
defer a.mu.RUnlock()
return a.userIDClaim
}
// findUserIndex finds the index of a hashed user ID in the authorized users list
// Returns the index and true if found, 0 and false if not found
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
for i, id := range a.authorizedUsers {
if id == hashedUserID {
return i, true
}
}
return 0, false
}
// isIndexInList checks if an index exists in a list of indexes
func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool {
for _, idx := range indexes {
if idx == index {
return true
}
}
return false
}

View File

@@ -0,0 +1,612 @@
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/sshauth"
)
func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list with one user
authorizedUserHash, err := sshauth.HashUserID("authorized-user")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Try to authorize a different user
err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed)
}
authorizer.Update(config)
// All attempts should fail when no machine user mappings exist (fail closed)
err = authorizer.Authorize("user1", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1}, // user1 and user2 can access root
"postgres": {1, 2}, // user2 and user3 can access postgres
"admin": {0}, // only user1 can access admin
},
}
authorizer.Update(config)
// user1 (index 0) should access root and admin
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user1", "admin")
assert.NoError(t, err)
// user2 (index 1) should access root and postgres
err = authorizer.Authorize("user2", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
// user3 (index 2) should access postgres
err = authorizer.Authorize("user3", "postgres")
assert.NoError(t, err)
}
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1}, // user1 and user2 can access root
"postgres": {1, 2}, // user2 and user3 can access postgres
"admin": {0}, // only user1 can access admin
},
}
authorizer.Update(config)
// user1 (index 0) should NOT access postgres
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 (index 1) should NOT access admin
err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access root
err = authorizer.Authorize("user3", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access admin
err = authorizer.Authorize("user3", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
}
func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0}, // only root is mapped
},
}
authorizer.Update(config)
// user1 should NOT access an unmapped OS user (fail closed)
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Empty user ID should fail
err = authorizer.Authorize("", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrEmptyUserID)
}
func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
authorizer := NewAuthorizer()
// Set up multiple authorized users
userHashes := make([]sshauth.UserIDHash, 10)
for i := 0; i < 10; i++ {
hash, err := sshauth.HashUserID("user" + string(rune('0'+i)))
require.NoError(t, err)
userHashes[i] = hash
}
// Create machine user mapping for all users
rootIndexes := make([]uint32, 10)
for i := 0; i < 10; i++ {
rootIndexes[i] = uint32(i)
}
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: userHashes,
MachineUsers: map[string][]uint32{
"root": rootIndexes,
},
}
authorizer.Update(config)
// All users should be authorized for root
for i := 0; i < 10; i++ {
err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
assert.NoError(t, err, "user%d should be authorized", i)
}
// User not in list should fail
err := authorizer.Authorize("unknown-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
authorizer := NewAuthorizer()
// Set up initial configuration
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{"root": {0}},
}
authorizer.Update(config)
// user1 should be authorized
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// Clear configuration
authorizer.Update(nil)
// user1 should no longer be authorized
err = authorizer.Authorize("user1", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
// Machine users with empty index lists should be filtered out
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0},
"postgres": {}, // empty list - should be filtered out
"admin": nil, // nil list - should be filtered out
},
}
authorizer.Update(config)
// root should work
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// postgres should fail (no mapping)
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// admin should fail (no mapping)
err = authorizer.Authorize("user1", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
authorizer := NewAuthorizer()
// Set up with custom user ID claim
user1Hash, err := sshauth.HashUserID("user@example.com")
require.NoError(t, err)
config := &Config{
UserIDClaim: "email",
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0},
},
}
authorizer.Update(config)
// Verify the custom claim is set
assert.Equal(t, "email", authorizer.GetUserIDClaim())
// Authorize with email as user ID
err = authorizer.Authorize("user@example.com", "root")
assert.NoError(t, err)
}
func TestAuthorizer_DefaultUserIDClaim(t *testing.T) {
authorizer := NewAuthorizer()
// Verify default claim
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
assert.Equal(t, "sub", authorizer.GetUserIDClaim())
// Set up with empty user ID claim (should use default)
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: "", // empty - should use default
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Should fall back to default
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
}
func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
authorizer := NewAuthorizer()
// Create a large authorized users list
const numUsers = 1000
userHashes := make([]sshauth.UserIDHash, numUsers)
for i := 0; i < numUsers; i++ {
hash, err := sshauth.HashUserID("user" + string(rune(i)))
require.NoError(t, err)
userHashes[i] = hash
}
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: userHashes,
MachineUsers: map[string][]uint32{
"root": {0, 500, 999}, // first, middle, and last user
},
}
authorizer.Update(config)
// First user should have access
err := authorizer.Authorize("user"+string(rune(0)), "root")
assert.NoError(t, err)
// Middle user should have access
err = authorizer.Authorize("user"+string(rune(500)), "root")
assert.NoError(t, err)
// Last user should have access
err = authorizer.Authorize("user"+string(rune(999)), "root")
assert.NoError(t, err)
// User not in mapping should NOT have access
err = authorizer.Authorize("user"+string(rune(100)), "root")
assert.Error(t, err)
}
func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1},
},
}
authorizer.Update(config)
// Test concurrent authorization calls (should be safe to read concurrently)
const numGoroutines = 100
errChan := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(idx int) {
user := "user1"
if idx%2 == 0 {
user = "user2"
}
err := authorizer.Authorize(user, "root")
errChan <- err
}(i)
}
// Wait for all goroutines to complete and collect errors
for i := 0; i < numGoroutines; i++ {
err := <-errChan
assert.NoError(t, err)
}
}
func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
// Configure with wildcard - all authorized users can access any OS user
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"*": {0, 1, 2}, // wildcard with all user indexes
},
}
authorizer.Update(config)
// All authorized users should be able to access any OS user
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
err = authorizer.Authorize("user3", "admin")
assert.NoError(t, err)
err = authorizer.Authorize("user1", "ubuntu")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "nginx")
assert.NoError(t, err)
err = authorizer.Authorize("user3", "docker")
assert.NoError(t, err)
}
func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
// Configure with wildcard
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"*": {0},
},
}
authorizer.Update(config)
// user1 should have access
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// Unauthorized user should still be denied even with wildcard
err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure with both wildcard and specific mappings
// Wildcard takes precedence for users in the wildcard index list
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"*": {0, 1}, // wildcard for both users
"root": {0}, // specific mapping that would normally restrict to user1 only
},
}
authorizer.Update(config)
// Both users should be able to access root via wildcard (takes precedence over specific mapping)
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "root")
assert.NoError(t, err)
// Both users should be able to access any other OS user via wildcard
err = authorizer.Authorize("user1", "postgres")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "admin")
assert.NoError(t, err)
}
func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure WITHOUT wildcard - only specific mappings
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"root": {0}, // only user1
"postgres": {1}, // only user2
},
}
authorizer.Update(config)
// user1 can access root
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// user2 can access postgres
err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
// user1 cannot access postgres
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 cannot access root
err = authorizer.Authorize("user2", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// Neither can access unmapped OS users
err = authorizer.Authorize("user1", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
// This test covers the scenario where wildcard exists with limited indexes.
// Only users whose indexes are in the wildcard list can access any OS user via wildcard.
// Other users can only access OS users they are explicitly mapped to.
authorizer := NewAuthorizer()
// Create two authorized user hashes (simulating the base64-encoded hashes in the config)
wasmHash, err := sshauth.HashUserID("wasm")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure with wildcard having only index 0, and specific mappings for other OS users
config := &Config{
UserIDClaim: "sub",
AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash},
MachineUsers: map[string][]uint32{
"*": {0}, // wildcard with only index 0 - only wasm has wildcard access
"alice": {1}, // specific mapping for user2
"bob": {1}, // specific mapping for user2
},
}
authorizer.Update(config)
// wasm (index 0) should access any OS user via wildcard
err = authorizer.Authorize("wasm", "root")
assert.NoError(t, err, "wasm should access root via wildcard")
err = authorizer.Authorize("wasm", "alice")
assert.NoError(t, err, "wasm should access alice via wildcard")
err = authorizer.Authorize("wasm", "bob")
assert.NoError(t, err, "wasm should access bob via wildcard")
err = authorizer.Authorize("wasm", "postgres")
assert.NoError(t, err, "wasm should access postgres via wildcard")
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
err = authorizer.Authorize("user2", "alice")
assert.NoError(t, err, "user2 should access alice via explicit mapping")
err = authorizer.Authorize("user2", "bob")
assert.NoError(t, err, "user2 should access bob via explicit mapping")
err = authorizer.Authorize("user2", "root")
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "postgres")
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// Unauthorized user should still be denied
err = authorizer.Authorize("user3", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
}

View File

@@ -27,9 +27,11 @@ import (
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server" "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil" "github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@@ -137,6 +139,21 @@ func TestSSHProxy_Connect(t *testing.T) {
sshServer := server.New(serverConfig) sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true) sshServer.SetAllowRootLogin(true)
// Configure SSH authorization for the test user
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0}, // Index 0 in AuthorizedUsers
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer) sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }() defer func() { _ = sshServer.Stop() }()
@@ -150,10 +167,10 @@ func TestSSHProxy_Connect(t *testing.T) {
mockDaemon.setHostKey(host, hostPubKey) mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience) validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken) mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil) proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err) require.NoError(t, err)
clientConn, proxyConn := net.Pipe() clientConn, proxyConn := net.Pipe()
@@ -347,12 +364,12 @@ func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
return privateKey, jwksJSON return privateKey, jwksJSON
} }
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string { func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
t.Helper() t.Helper()
claims := jwt.MapClaims{ claims := jwt.MapClaims{
"iss": issuer, "iss": issuer,
"aud": audience, "aud": audience,
"sub": "test-user", "sub": user,
"exp": time.Now().Add(time.Hour).Unix(), "exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(), "iat": time.Now().Unix(),
} }

View File

@@ -23,10 +23,12 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/client" "github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil" "github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
) )
func TestJWTEnforcement(t *testing.T) { func TestJWTEnforcement(t *testing.T) {
@@ -577,6 +579,22 @@ func TestJWTAuthentication(t *testing.T) {
tc.setupServer(server) tc.setupServer(server)
} }
// Always set up authorization for test-user to ensure tests fail at JWT validation stage
testUserHash, err := sshuserhash.HashUserID("test-user")
require.NoError(t, err)
// Get current OS username for machine user mapping
currentUser := testutil.GetTestUsername(t)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
currentUser: {0}, // Allow test-user (index 0) to access current OS user
},
}
server.UpdateSSHAuth(authConfig)
serverAddr := StartTestServer(t, server) serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop()) defer require.NoError(t, server.Stop())

View File

@@ -21,6 +21,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt" "github.com/netbirdio/netbird/shared/auth/jwt"
@@ -138,6 +139,8 @@ type Server struct {
jwtExtractor *jwt.ClaimsExtractor jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig jwtConfig *JWTConfig
authorizer *sshauth.Authorizer
suSupportsPty bool suSupportsPty bool
loginIsUtilLinux bool loginIsUtilLinux bool
} }
@@ -179,6 +182,7 @@ func New(config *Config) *Server {
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState), sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
jwtEnabled: config.JWT != nil, jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT, jwtConfig: config.JWT,
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
} }
return s return s
@@ -320,6 +324,19 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.wgAddress = addr s.wgAddress = addr
} }
// UpdateSSHAuth updates the SSH fine-grained access control configuration
// This should be called when network map updates include new SSH auth configuration
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
s.mu.Lock()
defer s.mu.Unlock()
// Reset JWT validator/extractor to pick up new userIDClaim
s.jwtValidator = nil
s.jwtExtractor = nil
s.authorizer.Update(config)
}
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized // ensureJWTValidator initializes the JWT validator and extractor if not already initialized
func (s *Server) ensureJWTValidator() error { func (s *Server) ensureJWTValidator() error {
s.mu.RLock() s.mu.RLock()
@@ -328,6 +345,7 @@ func (s *Server) ensureJWTValidator() error {
return nil return nil
} }
config := s.jwtConfig config := s.jwtConfig
authorizer := s.authorizer
s.mu.RUnlock() s.mu.RUnlock()
if config == nil { if config == nil {
@@ -343,9 +361,16 @@ func (s *Server) ensureJWTValidator() error {
true, true,
) )
extractor := jwt.NewClaimsExtractor( // Use custom userIDClaim from authorizer if available
extractorOptions := []jwt.ClaimsExtractorOption{
jwt.WithAudience(config.Audience), jwt.WithAudience(config.Audience),
) }
if authorizer.GetUserIDClaim() != "" {
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim())
}
extractor := jwt.NewClaimsExtractor(extractorOptions...)
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -493,29 +518,41 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
} }
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool { func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
osUsername := ctx.User()
remoteAddr := ctx.RemoteAddr()
if err := s.ensureJWTValidator(); err != nil { if err := s.ensureJWTValidator(); err != nil {
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err)
return false return false
} }
token, err := s.validateJWTToken(password) token, err := s.validateJWTToken(password)
if err != nil { if err != nil {
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err)
return false return false
} }
userAuth, err := s.extractAndValidateUser(token) userAuth, err := s.extractAndValidateUser(token)
if err != nil { if err != nil {
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err) log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err)
return false return false
} }
key := newAuthKey(ctx.User(), ctx.RemoteAddr()) s.mu.RLock()
authorizer := s.authorizer
s.mu.RUnlock()
if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil {
log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err)
return false
}
key := newAuthKey(osUsername, remoteAddr)
s.mu.Lock() s.mu.Lock()
s.pendingAuthJWT[key] = userAuth.UserId s.pendingAuthJWT[key] = userAuth.UserId
s.mu.Unlock() s.mu.Unlock()
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr()) log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
return true return true
} }

View File

@@ -178,6 +178,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
customZone := account.GetPeersCustomZone(ctx, dnsDomain) customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.experimentalNetworkMap(accountID) { if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
@@ -224,7 +225,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
if c.experimentalNetworkMap(accountID) { if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics) remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else { } else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} }
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -320,6 +321,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
customZone := account.GetPeersCustomZone(ctx, dnsDomain) customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
postureChecks, err := c.getPeerPostureChecks(account, peerId) postureChecks, err := c.getPeerPostureChecks(account, peerId)
if err != nil { if err != nil {
@@ -338,7 +340,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
if c.experimentalNetworkMap(accountId) { if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else { } else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} }
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -445,7 +447,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
if c.experimentalNetworkMap(accountID) { if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else { } else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics) networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics, account.GetActiveGroupUsers())
} }
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -811,7 +813,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
if c.experimentalNetworkMap(peer.AccountID) { if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
} else { } else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
} }
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -158,5 +158,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
} }
} }
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil return nil
} }

View File

@@ -6,7 +6,10 @@ import (
"net/url" "net/url"
"strings" "strings"
log "github.com/sirupsen/logrus"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
@@ -16,6 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/sshauth"
) )
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
@@ -84,15 +88,15 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken
return nbConfig return nbConfig
} }
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig { func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, enableSSH bool) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size() netmask, _ := network.Net.Mask.Size()
fqdn := peer.FQDN(dnsName) fqdn := peer.FQDN(dnsName)
sshConfig := &proto.SSHConfig{ sshConfig := &proto.SSHConfig{
SshEnabled: peer.SSHEnabled, SshEnabled: peer.SSHEnabled || enableSSH,
} }
if peer.SSHEnabled { if sshConfig.SshEnabled {
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig) sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
} }
@@ -110,12 +114,12 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
response := &proto.SyncResponse{ response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(), Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes), Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
}, },
Checks: toProtocolChecks(ctx, checks), Checks: toProtocolChecks(ctx, checks),
} }
@@ -151,9 +155,45 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.ForwardingRules = forwardingRules response.NetworkMap.ForwardingRules = forwardingRules
} }
if networkMap.AuthorizedUsers != nil {
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
}
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
}
return response return response
} }
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
for machineUser, users := range authorizedUsers {
indexes := make([]uint32, 0, len(users))
for userID := range users {
idx, exists := userIDToIndex[userID]
if !exists {
hash, err := sshauth.HashUserID(userID)
if err != nil {
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
continue
}
idx = uint32(len(hashedUsers))
userIDToIndex[userID] = idx
hashedUsers = append(hashedUsers, hash[:])
}
indexes = append(indexes, idx)
}
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
}
return hashedUsers, machineUsers
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers { for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{ dst = append(dst, &proto.RemotePeerConfig{

View File

@@ -635,7 +635,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
// if peer has reached this point then it has logged in // if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{ loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow), PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
Checks: toProtocolChecks(ctx, postureChecks), Checks: toProtocolChecks(ctx, postureChecks),
} }

View File

@@ -1456,21 +1456,19 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
} }
if settings.GroupsPropagationEnabled { removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups) if err != nil {
if err != nil { return err
return err }
}
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups) newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
if err != nil { if err != nil {
return err return err
} }
if removedGroupAffectsPeers || newGroupsAffectsPeers { if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
}
} }
return nil return nil

View File

@@ -397,7 +397,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
} }
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
} }

View File

@@ -299,7 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
} }
@@ -369,6 +369,9 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
PortRanges: []types.RulePortRange{portRange}, PortRanges: []types.RulePortRange{portRange},
}}, }},
} }
if protocol == types.PolicyRuleProtocolNetbirdSSH {
policy.Rules[0].AuthorizedUser = userAuth.UserId
}
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true) _, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
if err != nil { if err != nil {
@@ -449,6 +452,18 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
SerialNumber: peer.Meta.SystemSerialNumber, SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled, InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral, Ephemeral: peer.Ephemeral,
LocalFlags: &api.PeerLocalFlags{
BlockInbound: &peer.Meta.Flags.BlockInbound,
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
DisableDns: &peer.Meta.Flags.DisableDNS,
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
},
} }
if !approved { if !approved {
@@ -463,7 +478,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
if osVersion == "" { if osVersion == "" {
osVersion = peer.Meta.Core osVersion = peer.Meta.Core
} }
return &api.PeerBatch{ return &api.PeerBatch{
CreatedAt: peer.CreatedAt, CreatedAt: peer.CreatedAt,
Id: peer.ID, Id: peer.ID,
@@ -492,6 +506,18 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
SerialNumber: peer.Meta.SystemSerialNumber, SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled, InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral, Ephemeral: peer.Ephemeral,
LocalFlags: &api.PeerLocalFlags{
BlockInbound: &peer.Meta.Flags.BlockInbound,
BlockLanAccess: &peer.Meta.Flags.BlockLANAccess,
DisableClientRoutes: &peer.Meta.Flags.DisableClientRoutes,
DisableDns: &peer.Meta.Flags.DisableDNS,
DisableFirewall: &peer.Meta.Flags.DisableFirewall,
DisableServerRoutes: &peer.Meta.Flags.DisableServerRoutes,
LazyConnectionEnabled: &peer.Meta.Flags.LazyConnectionEnabled,
RosenpassEnabled: &peer.Meta.Flags.RosenpassEnabled,
RosenpassPermissive: &peer.Meta.Flags.RosenpassPermissive,
ServerSshAllowed: &peer.Meta.Flags.ServerSSHAllowed,
},
} }
} }

View File

@@ -221,6 +221,8 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
pr.Protocol = types.PolicyRuleProtocolUDP pr.Protocol = types.PolicyRuleProtocolUDP
case api.PolicyRuleUpdateProtocolIcmp: case api.PolicyRuleUpdateProtocolIcmp:
pr.Protocol = types.PolicyRuleProtocolICMP pr.Protocol = types.PolicyRuleProtocolICMP
case api.PolicyRuleUpdateProtocolNetbirdSsh:
pr.Protocol = types.PolicyRuleProtocolNetbirdSSH
default: default:
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
return return
@@ -254,6 +256,17 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
} }
} }
if pr.Protocol == types.PolicyRuleProtocolNetbirdSSH && rule.AuthorizedGroups != nil && len(*rule.AuthorizedGroups) != 0 {
for _, sourceGroupID := range pr.Sources {
_, ok := (*rule.AuthorizedGroups)[sourceGroupID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "authorized group for netbird-ssh protocol should be specified for each source group"), w)
return
}
}
pr.AuthorizedGroups = *rule.AuthorizedGroups
}
// validate policy object // validate policy object
if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP { if pr.Protocol == types.PolicyRuleProtocolALL || pr.Protocol == types.PolicyRuleProtocolICMP {
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
@@ -380,6 +393,11 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
DestinationResource: r.DestinationResource.ToAPIResponse(), DestinationResource: r.DestinationResource.ToAPIResponse(),
} }
if len(r.AuthorizedGroups) != 0 {
authorizedGroupsCopy := r.AuthorizedGroups
rule.AuthorizedGroups = &authorizedGroupsCopy
}
if len(r.Ports) != 0 { if len(r.Ports) != 0 {
portsCopy := r.Ports portsCopy := r.Ports
rule.Ports = &portsCopy rule.Ports = &portsCopy

View File

@@ -91,7 +91,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap) aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap, account.GetActiveGroupUsers())
for _, p := range aclPeers { for _, p := range aclPeers {
peersMap[p.ID] = p peersMap[p.ID] = p
} }
@@ -1057,7 +1057,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
} }
for _, p := range userPeers { for _, p := range userPeers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap) aclPeers, _, _, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap, account.GetActiveGroupUsers())
for _, aclPeer := range aclPeers { for _, aclPeer := range aclPeers {
if aclPeer.ID == peer.ID { if aclPeer.ID == peer.ID {
return peer, nil return peer, nil

View File

@@ -246,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers { for _, p := range account.Peers {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers) peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), p, validatedPeers, account.GetActiveGroupUsers())
assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present") assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present") assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
} }
}) })
t.Run("check first peer map details", func(t *testing.T) { t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers) peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 8) assert.Len(t, peers, 8)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
@@ -509,7 +509,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
}) })
t.Run("check port ranges support for older peers", func(t *testing.T) { t.Run("check port ranges support for older peers", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers) peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 1) assert.Len(t, peers, 1)
assert.Contains(t, peers, account.Peers["peerI"]) assert.Contains(t, peers, account.Peers["peerI"])
@@ -635,7 +635,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
} }
t.Run("check first peer map", func(t *testing.T) { t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
expectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
@@ -665,7 +665,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}) })
t.Run("check second peer map", func(t *testing.T) { t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
expectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
@@ -697,7 +697,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) { t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
expectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
@@ -719,7 +719,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}) })
t.Run("check second peer map directional only", func(t *testing.T) { t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
expectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
@@ -917,7 +917,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check. // will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -927,7 +927,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 7) assert.Len(t, firewallRules, 7)
expectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
@@ -992,7 +992,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers) peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -1002,7 +1002,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers) peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -1017,19 +1017,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers) peers, firewallRules, _, _ := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers) peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@@ -1044,14 +1044,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers) peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 3) assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3) assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers) peers, firewallRules, _, _ = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers, account.GetActiveGroupUsers())
assert.Len(t, peers, 5) assert.Len(t, peers, 5)
// assert peers from Group Swarm // assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])

View File

@@ -1910,16 +1910,16 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
if len(policyIDs) == 0 { if len(policyIDs) == 0 {
return nil, nil return nil, nil
} }
const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges, authorized_groups, authorized_user FROM policy_rules WHERE policy_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, policyIDs) rows, err := s.pool.Query(ctx, query, policyIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) {
var r types.PolicyRule var r types.PolicyRule
var dest, destRes, sources, sourceRes, ports, portRanges []byte var dest, destRes, sources, sourceRes, ports, portRanges, authorizedGroups []byte
var enabled, bidirectional sql.NullBool var enabled, bidirectional sql.NullBool
err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges, &authorizedGroups, &r.AuthorizedUser)
if err == nil { if err == nil {
if enabled.Valid { if enabled.Valid {
r.Enabled = enabled.Bool r.Enabled = enabled.Bool
@@ -1945,6 +1945,9 @@ func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*t
if portRanges != nil { if portRanges != nil {
_ = json.Unmarshal(portRanges, &r.PortRanges) _ = json.Unmarshal(portRanges, &r.PortRanges)
} }
if authorizedGroups != nil {
_ = json.Unmarshal(authorizedGroups, &r.AuthorizedGroups)
}
} }
return &r, err return &r, err
}) })

View File

@@ -16,6 +16,7 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@@ -45,8 +46,10 @@ const (
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections. // nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
nativeSSHPortString = "22022" nativeSSHPortString = "22022"
nativeSSHPortNumber = 22022
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections. // defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
defaultSSHPortString = "22" defaultSSHPortString = "22"
defaultSSHPortNumber = 22
) )
type supportedFeatures struct { type supportedFeatures struct {
@@ -275,6 +278,7 @@ func (a *Account) GetPeerNetworkMap(
resourcePolicies map[string][]*Policy, resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter, routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics, metrics *telemetry.AccountManagerMetrics,
groupIDToUserIDs map[string][]string,
) *NetworkMap { ) *NetworkMap {
start := time.Now() start := time.Now()
peer := a.Peers[peerID] peer := a.Peers[peerID]
@@ -290,7 +294,7 @@ func (a *Account) GetPeerNetworkMap(
} }
} }
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap) aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
// exclude expired peers // exclude expired peers
var peersToConnect []*nbpeer.Peer var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer var expiredPeers []*nbpeer.Peer
@@ -338,6 +342,8 @@ func (a *Account) GetPeerNetworkMap(
OfflinePeers: expiredPeers, OfflinePeers: expiredPeers,
FirewallRules: firewallRules, FirewallRules: firewallRules,
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules), RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
AuthorizedUsers: authorizedUsers,
EnableSSH: enableSSH,
} }
if metrics != nil { if metrics != nil {
@@ -1009,8 +1015,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer // GetPeerConnectionResources for a given peer
// //
// This function returns the list of peers and firewall rules that are applicable to a given peer. // This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer) generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs
sshEnabled := false
for _, policy := range a.Policies { for _, policy := range a.Policies {
if !policy.Enabled { if !policy.Enabled {
@@ -1053,10 +1061,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
if peerInDestinations { if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionIN) generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
} }
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
sshEnabled = true
switch {
case len(rule.AuthorizedGroups) > 0:
for groupID, localUsers := range rule.AuthorizedGroups {
userIDs, ok := groupIDToUserIDs[groupID]
if !ok {
log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID)
continue
}
if len(localUsers) == 0 {
localUsers = []string{auth.Wildcard}
}
for _, localUser := range localUsers {
if authorizedUsers[localUser] == nil {
authorizedUsers[localUser] = make(map[string]struct{})
}
for _, userID := range userIDs {
authorizedUsers[localUser][userID] = struct{}{}
}
}
}
case rule.AuthorizedUser != "":
if authorizedUsers[auth.Wildcard] == nil {
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
}
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
default:
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
}
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
sshEnabled = true
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
}
} }
} }
return getAccumulatedResources() peers, fwRules := getAccumulatedResources()
return peers, fwRules, authorizedUsers, sshEnabled
}
func (a *Account) getAllowedUserIDs() map[string]struct{} {
users := make(map[string]struct{})
for _, nbUser := range a.Users {
if !nbUser.IsBlocked() && !nbUser.IsServiceUser {
users[nbUser.Id] = struct{}{}
}
}
return users
} }
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls // connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
@@ -1081,12 +1137,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
peersExists[peer.ID] = struct{}{} peersExists[peer.ID] = struct{}{}
} }
protocol := rule.Protocol
if protocol == PolicyRuleProtocolNetbirdSSH {
protocol = PolicyRuleProtocolTCP
}
fr := FirewallRule{ fr := FirewallRule{
PolicyID: rule.ID, PolicyID: rule.ID,
PeerIP: peer.IP.String(), PeerIP: peer.IP.String(),
Direction: direction, Direction: direction,
Action: string(rule.Action), Action: string(rule.Action),
Protocol: string(rule.Protocol), Protocol: string(protocol),
} }
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
@@ -1108,6 +1169,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
} }
} }
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
}
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
for _, pr := range portRanges {
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
return true
}
}
return false
}
func portsIncludesSSH(ports []string) bool {
for _, port := range ports {
if port == defaultSSHPortString || port == nativeSSHPortString {
return true
}
}
return false
}
// getAllPeersFromGroups for given peer ID and list of groups // getAllPeersFromGroups for given peer ID and list of groups
// //
// Returns a list of peers from specified groups that pass specified posture checks // Returns a list of peers from specified groups that pass specified posture checks
@@ -1660,6 +1743,26 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
return nil return nil
} }
func (a *Account) GetActiveGroupUsers() map[string][]string {
allGroupID := ""
group, err := a.GetGroupAll()
if err != nil {
log.Errorf("failed to get group all: %v", err)
} else {
allGroupID = group.ID
}
groups := make(map[string][]string, len(a.GroupsG))
for _, user := range a.Users {
if !user.IsBlocked() && !user.IsServiceUser {
for _, groupID := range user.AutoGroups {
groups[groupID] = append(groups[groupID], user.Id)
}
groups[allGroupID] = append(groups[allGroupID], user.Id)
}
}
return groups
}
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules // expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule { func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion) features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
@@ -1691,7 +1794,7 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer
expanded = append(expanded, &fr) expanded = append(expanded, &fr)
} }
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) { if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
expanded = addNativeSSHRule(base, expanded) expanded = addNativeSSHRule(base, expanded)
} }

View File

@@ -1105,6 +1105,193 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
} }
} }
func Test_GetActiveGroupUsers(t *testing.T) {
tests := []struct {
name string
account *Account
expected map[string][]string
}{
{
name: "all users are active",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1", "group2"},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group2", "group3"},
Blocked: false,
},
"user3": {
Id: "user3",
AutoGroups: []string{"group1"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user1", "user3"},
"group2": {"user1", "user2"},
"group3": {"user2"},
"": {"user1", "user2", "user3"},
},
},
{
name: "some users are blocked",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1", "group2"},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group2", "group3"},
Blocked: true,
},
"user3": {
Id: "user3",
AutoGroups: []string{"group1", "group3"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user1", "user3"},
"group2": {"user1"},
"group3": {"user3"},
"": {"user1", "user3"},
},
},
{
name: "all users are blocked",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1"},
Blocked: true,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group2"},
Blocked: true,
},
},
},
expected: map[string][]string{},
},
{
name: "user with no auto groups",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group1"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user2"},
"": {"user1", "user2"},
},
},
{
name: "empty account",
account: &Account{
Users: map[string]*User{},
},
expected: map[string][]string{},
},
{
name: "multiple users in same group",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1"},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group1"},
Blocked: false,
},
"user3": {
Id: "user3",
AutoGroups: []string{"group1"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user1", "user2", "user3"},
"": {"user1", "user2", "user3"},
},
},
{
name: "user in multiple groups with blocked users",
account: &Account{
Users: map[string]*User{
"user1": {
Id: "user1",
AutoGroups: []string{"group1", "group2", "group3"},
Blocked: false,
},
"user2": {
Id: "user2",
AutoGroups: []string{"group1", "group2"},
Blocked: true,
},
"user3": {
Id: "user3",
AutoGroups: []string{"group3"},
Blocked: false,
},
},
},
expected: map[string][]string{
"group1": {"user1"},
"group2": {"user1"},
"group3": {"user1", "user3"},
"": {"user1", "user3"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetActiveGroupUsers()
// Check that the number of groups matches
assert.Equal(t, len(tt.expected), len(result), "number of groups should match")
// Check each group's users
for groupID, expectedUsers := range tt.expected {
actualUsers, exists := result[groupID]
assert.True(t, exists, "group %s should exist in result", groupID)
assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID)
}
// Ensure no extra groups in result
for groupID := range result {
_, exists := tt.expected[groupID]
assert.True(t, exists, "unexpected group %s in result", groupID)
}
})
}
}
func Test_FilterZoneRecordsForPeers(t *testing.T) { func Test_FilterZoneRecordsForPeers(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -38,6 +38,8 @@ type NetworkMap struct {
FirewallRules []*FirewallRule FirewallRules []*FirewallRule
RoutesFirewallRules []*RouteFirewallRule RoutesFirewallRules []*RouteFirewallRule
ForwardingRules []*ForwardingRule ForwardingRules []*ForwardingRule
AuthorizedUsers map[string]map[string]struct{}
EnableSSH bool
} }
func (nm *NetworkMap) Merge(other *NetworkMap) { func (nm *NetworkMap) Merge(other *NetworkMap) {

View File

@@ -69,7 +69,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap) normalizeAndSortNetworkMap(networkMap)
@@ -141,7 +141,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
b.Run("old builder", func(b *testing.B) { b.Run("old builder", func(b *testing.B) {
for range b.N { for range b.N {
for _, peerID := range peerIDs { for _, peerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
} }
} }
}) })
@@ -201,7 +201,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap) normalizeAndSortNetworkMap(networkMap)
@@ -320,7 +320,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) { b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs { for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
} }
} }
}) })
@@ -395,7 +395,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap) normalizeAndSortNetworkMap(networkMap)
@@ -550,7 +550,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) { b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs { for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
} }
} }
}) })
@@ -604,7 +604,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap) normalizeAndSortNetworkMap(networkMap)
@@ -730,7 +730,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(networkMap) normalizeAndSortNetworkMap(networkMap)
@@ -847,7 +847,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
b.Run("old builder after delete", func(b *testing.B) { b.Run("old builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs { for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
} }
} }
}) })

View File

@@ -23,6 +23,8 @@ const (
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp") PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
// PolicyRuleProtocolICMP type of traffic // PolicyRuleProtocolICMP type of traffic
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp") PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
// PolicyRuleProtocolNetbirdSSH type of traffic
PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh")
) )
const ( const (
@@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error)
protocol = PolicyRuleProtocolUDP protocol = PolicyRuleProtocolUDP
case "icmp": case "icmp":
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'") return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
case "netbird-ssh":
return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil
default: default:
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr) return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
} }

View File

@@ -80,6 +80,12 @@ type PolicyRule struct {
// PortRanges a list of port ranges. // PortRanges a list of port ranges.
PortRanges []RulePortRange `gorm:"serializer:json"` PortRanges []RulePortRange `gorm:"serializer:json"`
// AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh
AuthorizedGroups map[string][]string `gorm:"serializer:json"`
// AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh
AuthorizedUser string
} }
// Copy returns a copy of a policy rule // Copy returns a copy of a policy rule
@@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule {
Protocol: pm.Protocol, Protocol: pm.Protocol,
Ports: make([]string, len(pm.Ports)), Ports: make([]string, len(pm.Ports)),
PortRanges: make([]RulePortRange, len(pm.PortRanges)), PortRanges: make([]RulePortRange, len(pm.PortRanges)),
AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)),
AuthorizedUser: pm.AuthorizedUser,
} }
copy(rule.Destinations, pm.Destinations) copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources) copy(rule.Sources, pm.Sources)
copy(rule.Ports, pm.Ports) copy(rule.Ports, pm.Ports)
copy(rule.PortRanges, pm.PortRanges) copy(rule.PortRanges, pm.PortRanges)
for k, v := range pm.AuthorizedGroups {
rule.AuthorizedGroups[k] = make([]string, len(v))
copy(rule.AuthorizedGroups[k], v)
}
return rule return rule
} }

View File

@@ -523,16 +523,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
} }
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( _, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings, ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err) return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
} }
if userHadPeers { updateAccountPeers = true
updateAccountPeers = true
}
err = transaction.SaveUser(ctx, updatedUser) err = transaction.SaveUser(ctx, updatedUser)
if err != nil { if err != nil {
@@ -581,7 +579,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
} }
} }
if settings.GroupsPropagationEnabled && updateAccountPeers { if updateAccountPeers {
if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil { if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
return nil, fmt.Errorf("failed to increment network serial: %w", err) return nil, fmt.Errorf("failed to increment network serial: %w", err)
} }

View File

@@ -1379,11 +1379,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
updateManager.CloseChannel(context.Background(), peer1.ID) updateManager.CloseChannel(context.Background(), peer1.ID)
}) })
// Creating a new regular user should not update account peers and not send peer update // Creating a new regular user should send peer update (as users are not filtered yet)
t.Run("creating new regular user with no groups", func(t *testing.T) { t.Run("creating new regular user with no groups", func(t *testing.T) {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldNotReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()
@@ -1402,11 +1402,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
} }
}) })
// updating user with no linked peers should not update account peers and not send peer update // updating user with no linked peers should update account peers and send peer update (as users are not filtered yet)
t.Run("updating user with no linked peers", func(t *testing.T) { t.Run("updating user with no linked peers", func(t *testing.T) {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldNotReceiveUpdate(t, updMsg) peerShouldReceiveUpdate(t, updMsg)
close(done) close(done)
}() }()

View File

@@ -488,6 +488,8 @@ components:
description: Indicates whether the peer is ephemeral or not description: Indicates whether the peer is ephemeral or not
type: boolean type: boolean
example: false example: false
local_flags:
$ref: '#/components/schemas/PeerLocalFlags'
required: required:
- city_name - city_name
- connected - connected
@@ -514,6 +516,49 @@ components:
- serial_number - serial_number
- extra_dns_labels - extra_dns_labels
- ephemeral - ephemeral
PeerLocalFlags:
type: object
properties:
rosenpass_enabled:
description: Indicates whether Rosenpass is enabled on this peer
type: boolean
example: true
rosenpass_permissive:
description: Indicates whether Rosenpass is in permissive mode or not
type: boolean
example: false
server_ssh_allowed:
description: Indicates whether SSH access this peer is allowed or not
type: boolean
example: true
disable_client_routes:
description: Indicates whether client routes are disabled on this peer or not
type: boolean
example: false
disable_server_routes:
description: Indicates whether server routes are disabled on this peer or not
type: boolean
example: false
disable_dns:
description: Indicates whether DNS management is disabled on this peer or not
type: boolean
example: false
disable_firewall:
description: Indicates whether firewall management is disabled on this peer or not
type: boolean
example: false
block_lan_access:
description: Indicates whether LAN access is blocked on this peer when used as a routing peer
type: boolean
example: false
block_inbound:
description: Indicates whether inbound traffic is blocked on this peer
type: boolean
example: false
lazy_connection_enabled:
description: Indicates whether lazy connection is enabled on this peer
type: boolean
example: false
PeerTemporaryAccessRequest: PeerTemporaryAccessRequest:
type: object type: object
properties: properties:
@@ -936,7 +981,7 @@ components:
protocol: protocol:
description: Policy rule type of the traffic description: Policy rule type of the traffic
type: string type: string
enum: ["all", "tcp", "udp", "icmp"] enum: ["all", "tcp", "udp", "icmp", "netbird-ssh"]
example: "tcp" example: "tcp"
ports: ports:
description: Policy rule affected ports description: Policy rule affected ports
@@ -949,6 +994,14 @@ components:
type: array type: array
items: items:
$ref: '#/components/schemas/RulePortRange' $ref: '#/components/schemas/RulePortRange'
authorized_groups:
description: Map of user group ids to a list of local users
type: object
additionalProperties:
type: array
items:
type: string
example: "group1"
required: required:
- name - name
- enabled - enabled

View File

@@ -130,10 +130,11 @@ const (
// Defines values for PolicyRuleProtocol. // Defines values for PolicyRuleProtocol.
const ( const (
PolicyRuleProtocolAll PolicyRuleProtocol = "all" PolicyRuleProtocolAll PolicyRuleProtocol = "all"
PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp" PolicyRuleProtocolIcmp PolicyRuleProtocol = "icmp"
PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp" PolicyRuleProtocolNetbirdSsh PolicyRuleProtocol = "netbird-ssh"
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp" PolicyRuleProtocolTcp PolicyRuleProtocol = "tcp"
PolicyRuleProtocolUdp PolicyRuleProtocol = "udp"
) )
// Defines values for PolicyRuleMinimumAction. // Defines values for PolicyRuleMinimumAction.
@@ -144,10 +145,11 @@ const (
// Defines values for PolicyRuleMinimumProtocol. // Defines values for PolicyRuleMinimumProtocol.
const ( const (
PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all" PolicyRuleMinimumProtocolAll PolicyRuleMinimumProtocol = "all"
PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp" PolicyRuleMinimumProtocolIcmp PolicyRuleMinimumProtocol = "icmp"
PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp" PolicyRuleMinimumProtocolNetbirdSsh PolicyRuleMinimumProtocol = "netbird-ssh"
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp" PolicyRuleMinimumProtocolTcp PolicyRuleMinimumProtocol = "tcp"
PolicyRuleMinimumProtocolUdp PolicyRuleMinimumProtocol = "udp"
) )
// Defines values for PolicyRuleUpdateAction. // Defines values for PolicyRuleUpdateAction.
@@ -158,10 +160,11 @@ const (
// Defines values for PolicyRuleUpdateProtocol. // Defines values for PolicyRuleUpdateProtocol.
const ( const (
PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all" PolicyRuleUpdateProtocolAll PolicyRuleUpdateProtocol = "all"
PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp" PolicyRuleUpdateProtocolIcmp PolicyRuleUpdateProtocol = "icmp"
PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp" PolicyRuleUpdateProtocolNetbirdSsh PolicyRuleUpdateProtocol = "netbird-ssh"
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" PolicyRuleUpdateProtocolTcp PolicyRuleUpdateProtocol = "tcp"
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
) )
// Defines values for ResourceType. // Defines values for ResourceType.
@@ -1077,7 +1080,8 @@ type Peer struct {
LastLogin time.Time `json:"last_login"` LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service // LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"` LastSeen time.Time `json:"last_seen"`
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not // LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"` LoginExpirationEnabled bool `json:"login_expiration_enabled"`
@@ -1167,7 +1171,8 @@ type PeerBatch struct {
LastLogin time.Time `json:"last_login"` LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service // LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"` LastSeen time.Time `json:"last_seen"`
LocalFlags *PeerLocalFlags `json:"local_flags,omitempty"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not // LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"` LoginExpirationEnabled bool `json:"login_expiration_enabled"`
@@ -1197,6 +1202,39 @@ type PeerBatch struct {
Version string `json:"version"` Version string `json:"version"`
} }
// PeerLocalFlags defines model for PeerLocalFlags.
type PeerLocalFlags struct {
// BlockInbound Indicates whether inbound traffic is blocked on this peer
BlockInbound *bool `json:"block_inbound,omitempty"`
// BlockLanAccess Indicates whether LAN access is blocked on this peer when used as a routing peer
BlockLanAccess *bool `json:"block_lan_access,omitempty"`
// DisableClientRoutes Indicates whether client routes are disabled on this peer or not
DisableClientRoutes *bool `json:"disable_client_routes,omitempty"`
// DisableDns Indicates whether DNS management is disabled on this peer or not
DisableDns *bool `json:"disable_dns,omitempty"`
// DisableFirewall Indicates whether firewall management is disabled on this peer or not
DisableFirewall *bool `json:"disable_firewall,omitempty"`
// DisableServerRoutes Indicates whether server routes are disabled on this peer or not
DisableServerRoutes *bool `json:"disable_server_routes,omitempty"`
// LazyConnectionEnabled Indicates whether lazy connection is enabled on this peer
LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"`
// RosenpassEnabled Indicates whether Rosenpass is enabled on this peer
RosenpassEnabled *bool `json:"rosenpass_enabled,omitempty"`
// RosenpassPermissive Indicates whether Rosenpass is in permissive mode or not
RosenpassPermissive *bool `json:"rosenpass_permissive,omitempty"`
// ServerSshAllowed Indicates whether SSH access this peer is allowed or not
ServerSshAllowed *bool `json:"server_ssh_allowed,omitempty"`
}
// PeerMinimum defines model for PeerMinimum. // PeerMinimum defines model for PeerMinimum.
type PeerMinimum struct { type PeerMinimum struct {
// Id Peer ID // Id Peer ID
@@ -1349,6 +1387,9 @@ type PolicyRule struct {
// Action Policy rule accept or drops packets // Action Policy rule accept or drops packets
Action PolicyRuleAction `json:"action"` Action PolicyRuleAction `json:"action"`
// AuthorizedGroups Map of user group ids to a list of local users
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations. // Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
Bidirectional bool `json:"bidirectional"` Bidirectional bool `json:"bidirectional"`
@@ -1393,6 +1434,9 @@ type PolicyRuleMinimum struct {
// Action Policy rule accept or drops packets // Action Policy rule accept or drops packets
Action PolicyRuleMinimumAction `json:"action"` Action PolicyRuleMinimumAction `json:"action"`
// AuthorizedGroups Map of user group ids to a list of local users
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations. // Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
Bidirectional bool `json:"bidirectional"` Bidirectional bool `json:"bidirectional"`
@@ -1426,6 +1470,9 @@ type PolicyRuleUpdate struct {
// Action Policy rule accept or drops packets // Action Policy rule accept or drops packets
Action PolicyRuleUpdateAction `json:"action"` Action PolicyRuleUpdateAction `json:"action"`
// AuthorizedGroups Map of user group ids to a list of local users
AuthorizedGroups *map[string][]string `json:"authorized_groups,omitempty"`
// Bidirectional Define if the rule is applicable in both directions, sources, and destinations. // Bidirectional Define if the rule is applicable in both directions, sources, and destinations.
Bidirectional bool `json:"bidirectional"` Bidirectional bool `json:"bidirectional"`

File diff suppressed because it is too large Load Diff

View File

@@ -332,6 +332,24 @@ message NetworkMap {
bool routesFirewallRulesIsEmpty = 11; bool routesFirewallRulesIsEmpty = 11;
repeated ForwardingRule forwardingRules = 12; repeated ForwardingRule forwardingRules = 12;
// SSHAuth represents SSH authorization configuration
SSHAuth sshAuth = 13;
}
message SSHAuth {
// UserIDClaim is the JWT claim to be used to get the users ID
string UserIDClaim = 1;
// AuthorizedUsers is a list of hashed user IDs authorized to access this peer via SSH
repeated bytes AuthorizedUsers = 2;
// MachineUsers is a map of machine user names to their corresponding indexes in the AuthorizedUsers list
map<string, MachineUserIndexes> machine_users = 3;
}
message MachineUserIndexes {
repeated uint32 indexes = 1;
} }
// RemotePeerConfig represents a configuration of a remote peer. // RemotePeerConfig represents a configuration of a remote peer.

View File

@@ -0,0 +1,28 @@
package sshauth
import (
"encoding/hex"
"golang.org/x/crypto/blake2b"
)
// UserIDHash represents a hashed user ID (BLAKE2b-128)
type UserIDHash [16]byte
// HashUserID hashes a user ID using BLAKE2b-128 and returns the hash value
// This function must produce the same hash on both client and management server
func HashUserID(userID string) (UserIDHash, error) {
hash, err := blake2b.New(16, nil)
if err != nil {
return UserIDHash{}, err
}
hash.Write([]byte(userID))
var result UserIDHash
copy(result[:], hash.Sum(nil))
return result, nil
}
// String returns the hexadecimal string representation of the hash
func (h UserIDHash) String() string {
return hex.EncodeToString(h[:])
}

View File

@@ -0,0 +1,210 @@
package sshauth
import (
"testing"
)
func TestHashUserID(t *testing.T) {
tests := []struct {
name string
userID string
}{
{
name: "simple user ID",
userID: "user@example.com",
},
{
name: "UUID format",
userID: "550e8400-e29b-41d4-a716-446655440000",
},
{
name: "numeric ID",
userID: "12345",
},
{
name: "empty string",
userID: "",
},
{
name: "special characters",
userID: "user+test@domain.com",
},
{
name: "unicode characters",
userID: "用户@example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hash, err := HashUserID(tt.userID)
if err != nil {
t.Errorf("HashUserID() error = %v, want nil", err)
return
}
// Verify hash is non-zero for non-empty inputs
if tt.userID != "" && hash == [16]byte{} {
t.Errorf("HashUserID() returned zero hash for non-empty input")
}
})
}
}
func TestHashUserID_Consistency(t *testing.T) {
userID := "test@example.com"
hash1, err1 := HashUserID(userID)
if err1 != nil {
t.Fatalf("First HashUserID() error = %v", err1)
}
hash2, err2 := HashUserID(userID)
if err2 != nil {
t.Fatalf("Second HashUserID() error = %v", err2)
}
if hash1 != hash2 {
t.Errorf("HashUserID() is not consistent: got %v and %v for same input", hash1, hash2)
}
}
func TestHashUserID_Uniqueness(t *testing.T) {
tests := []struct {
userID1 string
userID2 string
}{
{"user1@example.com", "user2@example.com"},
{"alice@domain.com", "bob@domain.com"},
{"test", "test1"},
{"", "a"},
}
for _, tt := range tests {
hash1, err1 := HashUserID(tt.userID1)
if err1 != nil {
t.Fatalf("HashUserID(%s) error = %v", tt.userID1, err1)
}
hash2, err2 := HashUserID(tt.userID2)
if err2 != nil {
t.Fatalf("HashUserID(%s) error = %v", tt.userID2, err2)
}
if hash1 == hash2 {
t.Errorf("HashUserID() collision: %s and %s produced same hash %v", tt.userID1, tt.userID2, hash1)
}
}
}
func TestUserIDHash_String(t *testing.T) {
tests := []struct {
name string
hash UserIDHash
expected string
}{
{
name: "zero hash",
hash: [16]byte{},
expected: "00000000000000000000000000000000",
},
{
name: "small value",
hash: [16]byte{15: 0xff},
expected: "000000000000000000000000000000ff",
},
{
name: "large value",
hash: [16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe},
expected: "0000000000000000deadbeefcafebabe",
},
{
name: "max value",
hash: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
expected: "ffffffffffffffffffffffffffffffff",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.hash.String()
if result != tt.expected {
t.Errorf("UserIDHash.String() = %v, want %v", result, tt.expected)
}
})
}
}
func TestUserIDHash_String_Length(t *testing.T) {
// Test that String() always returns 32 hex characters (16 bytes * 2)
userID := "test@example.com"
hash, err := HashUserID(userID)
if err != nil {
t.Fatalf("HashUserID() error = %v", err)
}
result := hash.String()
if len(result) != 32 {
t.Errorf("UserIDHash.String() length = %d, want 32", len(result))
}
// Verify it's valid hex
for i, c := range result {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("UserIDHash.String() contains non-hex character at position %d: %c", i, c)
}
}
}
func TestHashUserID_KnownValues(t *testing.T) {
// Test with known BLAKE2b-128 values to ensure correct implementation
tests := []struct {
name string
userID string
expected UserIDHash
}{
{
name: "empty string",
userID: "",
// BLAKE2b-128 of empty string
expected: [16]byte{0xca, 0xe6, 0x69, 0x41, 0xd9, 0xef, 0xbd, 0x40, 0x4e, 0x4d, 0x88, 0x75, 0x8e, 0xa6, 0x76, 0x70},
},
{
name: "single character 'a'",
userID: "a",
// BLAKE2b-128 of "a"
expected: [16]byte{0x27, 0xc3, 0x5e, 0x6e, 0x93, 0x73, 0x87, 0x7f, 0x29, 0xe5, 0x62, 0x46, 0x4e, 0x46, 0x49, 0x7e},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hash, err := HashUserID(tt.userID)
if err != nil {
t.Errorf("HashUserID() error = %v", err)
return
}
if hash != tt.expected {
t.Errorf("HashUserID(%q) = %x, want %x",
tt.userID, hash, tt.expected)
}
})
}
}
func BenchmarkHashUserID(b *testing.B) {
userID := "user@example.com"
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = HashUserID(userID)
}
}
func BenchmarkUserIDHash_String(b *testing.B) {
hash := UserIDHash([16]byte{8: 0xde, 9: 0xad, 10: 0xbe, 11: 0xef, 12: 0xca, 13: 0xfe, 14: 0xba, 15: 0xbe})
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = hash.String()
}
}