mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-28 10:49:55 +00:00
Compare commits
10 Commits
feature/io
...
pascal-fil
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
62ffa08744 | ||
|
|
d8e7f2e9e6 | ||
|
|
1205641b44 | ||
|
|
56e8215ebe | ||
|
|
9b768d1773 | ||
|
|
33954ea15e | ||
|
|
4c4434a871 | ||
|
|
7873f337df | ||
|
|
07101c59ac | ||
|
|
51b6f6291b |
@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: string(activeProf.ID),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -85,9 +85,7 @@ type Client struct {
|
||||
|
||||
stateMu sync.RWMutex
|
||||
connectClient *internal.ConnectClient
|
||||
// config holds the active configuration once Run has loaded it. Consumed by
|
||||
// the in-app SSH client for the NetBird SSH key and the OAuth flow.
|
||||
config *profilemanager.Config
|
||||
config *profilemanager.Config
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
@@ -172,7 +170,6 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.onHostDnsFn = func([]string) {}
|
||||
cfg.WgIface = interfaceName
|
||||
c.config = cfg
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.setState(cfg, connectClient)
|
||||
@@ -701,13 +698,6 @@ func (c *Client) stateSnapshot() (*profilemanager.Config, *internal.ConnectClien
|
||||
return c.config, c.connectClient
|
||||
}
|
||||
|
||||
// sshState returns the active config and the running connect client for the
|
||||
// in-app SSH client. Both are nil until Run has loaded the config and started
|
||||
// the tunnel.
|
||||
func (c *Client) sshState() (*profilemanager.Config, *internal.ConnectClient) {
|
||||
return c.stateSnapshot()
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
ds := d.String()
|
||||
dotIndex := strings.Index(ds, ".")
|
||||
|
||||
@@ -1,434 +0,0 @@
|
||||
//go:build ios
|
||||
|
||||
package NetBirdSDK
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
)
|
||||
|
||||
const (
|
||||
sshDialTimeout = 30 * time.Second
|
||||
sshDetectionTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// SSHTerminalListener receives SSH session events. It is implemented in Swift.
|
||||
//
|
||||
// All callbacks are invoked from goroutines and may run concurrently with each
|
||||
// other; the implementation must be safe to call from any thread.
|
||||
type SSHTerminalListener interface {
|
||||
OnConnected()
|
||||
OnData(data []byte)
|
||||
OnClose(reason string)
|
||||
OnError(message string)
|
||||
}
|
||||
|
||||
// SSHClient is a NetBird-aware SSH client exposed to Swift via gomobile.
|
||||
//
|
||||
// It dials through the running NetBird tunnel and runs a standard SSH session
|
||||
// on top with PTY enabled. Host-key verification uses the NetBird-provided
|
||||
// peer SSH host keys, identical to the desktop client.
|
||||
type SSHClient struct {
|
||||
nb *Client
|
||||
mu sync.Mutex
|
||||
listener SSHTerminalListener
|
||||
urlOpener URLOpener
|
||||
|
||||
sshClient *gossh.Client
|
||||
session *gossh.Session
|
||||
stdin io.WriteCloser
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewSSHClient creates a new SSH client bound to the running NetBird Client.
|
||||
func NewSSHClient(c *Client) *SSHClient {
|
||||
return &SSHClient{nb: c}
|
||||
}
|
||||
|
||||
// SetListener registers the Swift listener. Must be called before Connect to
|
||||
// receive any events.
|
||||
func (s *SSHClient) SetListener(l SSHTerminalListener) {
|
||||
s.mu.Lock()
|
||||
s.listener = l
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetURLOpener registers the Swift URL opener used to display the device-code
|
||||
// authorization page in an in-app browser when the target peer requires JWT
|
||||
// authentication. Must be set before Connect to be effective.
|
||||
func (s *SSHClient) SetURLOpener(opener URLOpener) {
|
||||
s.mu.Lock()
|
||||
s.urlOpener = opener
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// Connect dials the SSH server through the NetBird tunnel and performs the
|
||||
// SSH handshake. It auto-detects the server type via SSH banner inspection
|
||||
// and selects the appropriate authentication path:
|
||||
//
|
||||
// - NetBird-SSH server requiring JWT: launches the OAuth 2.0 device-code
|
||||
// flow, opens the verification URL through the registered URLOpener, and
|
||||
// uses the resulting token as the SSH password. Host-key verification
|
||||
// uses the NetBird peer registry.
|
||||
// - NetBird-SSH server without JWT: authenticates with the NetBird SSH
|
||||
// private key. Host-key verification uses the NetBird peer registry.
|
||||
// - Regular SSH server (e.g. OpenSSH): authenticates with the NetBird key
|
||||
// first (so a user-installed NetBird public key works), then falls back
|
||||
// to the supplied password if non-empty. Host-key verification is
|
||||
// disabled (TOFU pending).
|
||||
//
|
||||
// The password parameter is only consulted for regular SSH servers.
|
||||
func (s *SSHClient) Connect(host string, port int, user, password string) error {
|
||||
cfg, cc := s.nb.sshState()
|
||||
if cc == nil {
|
||||
return errors.New("netbird client not running")
|
||||
}
|
||||
if cfg == nil {
|
||||
return errors.New("netbird config not loaded")
|
||||
}
|
||||
engine := cc.Engine()
|
||||
if engine == nil {
|
||||
return errors.New("netbird engine not available")
|
||||
}
|
||||
|
||||
serverType := detectServerType(host, port)
|
||||
log.Infof("SSH server type for %s:%d: %s", host, port, serverType)
|
||||
|
||||
authMethods, hostKeyCallback, err := s.buildAuth(cfg, engine, serverType, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientConfig := &gossh.ClientConfig{
|
||||
User: user,
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
Timeout: sshDialTimeout,
|
||||
}
|
||||
return s.dialAndHandshake(host, port, clientConfig)
|
||||
}
|
||||
|
||||
// StartSession requests a PTY and starts an interactive shell. Output from
|
||||
// the session is forwarded to the listener via OnData.
|
||||
func (s *SSHClient) StartSession(cols, rows int) error {
|
||||
log.Debugf("SSH: starting session %dx%d", cols, rows)
|
||||
s.mu.Lock()
|
||||
sshClient := s.sshClient
|
||||
s.mu.Unlock()
|
||||
|
||||
if sshClient == nil {
|
||||
return errors.New("ssh client not connected")
|
||||
}
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
|
||||
modes := gossh.TerminalModes{
|
||||
gossh.ECHO: 1,
|
||||
gossh.TTY_OP_ISPEED: 14400,
|
||||
gossh.TTY_OP_OSPEED: 14400,
|
||||
gossh.VINTR: 3,
|
||||
gossh.VQUIT: 28,
|
||||
gossh.VERASE: 127,
|
||||
}
|
||||
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
|
||||
closeQuiet(session, "session after pty error")
|
||||
return fmt.Errorf("request pty: %w", err)
|
||||
}
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
closeQuiet(session, "session after stdin error")
|
||||
return fmt.Errorf("stdin pipe: %w", err)
|
||||
}
|
||||
stdout, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
closeQuiet(session, "session after stdout error")
|
||||
return fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := session.StderrPipe()
|
||||
if err != nil {
|
||||
closeQuiet(session, "session after stderr error")
|
||||
return fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
closeQuiet(session, "session after shell error")
|
||||
return fmt.Errorf("start shell: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.session = session
|
||||
s.stdin = stdin
|
||||
s.mu.Unlock()
|
||||
|
||||
go s.readLoop(stdout, "stdout")
|
||||
go s.readLoop(stderr, "stderr")
|
||||
log.Debug("SSH: session started, shell running")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write sends data to the SSH session stdin.
|
||||
func (s *SSHClient) Write(data []byte) error {
|
||||
s.mu.Lock()
|
||||
stdin := s.stdin
|
||||
s.mu.Unlock()
|
||||
if stdin == nil {
|
||||
return errors.New("ssh session not started")
|
||||
}
|
||||
if _, err := stdin.Write(data); err != nil {
|
||||
return fmt.Errorf("write stdin: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resize updates the PTY window size.
|
||||
func (s *SSHClient) Resize(cols, rows int) error {
|
||||
s.mu.Lock()
|
||||
session := s.session
|
||||
s.mu.Unlock()
|
||||
if session == nil {
|
||||
return errors.New("ssh session not started")
|
||||
}
|
||||
return session.WindowChange(rows, cols)
|
||||
}
|
||||
|
||||
// Close terminates the SSH session and underlying connection. Safe to call
|
||||
// multiple times.
|
||||
func (s *SSHClient) Close() error {
|
||||
s.mu.Lock()
|
||||
sshClient := s.sshClient
|
||||
session := s.session
|
||||
stdin := s.stdin
|
||||
s.sshClient = nil
|
||||
s.session = nil
|
||||
s.stdin = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
if stdin != nil {
|
||||
if err := stdin.Close(); err != nil {
|
||||
log.Debugf("ssh: stdin close: %v", err)
|
||||
}
|
||||
}
|
||||
if session != nil {
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Debugf("ssh: session close: %v", err)
|
||||
}
|
||||
}
|
||||
var firstErr error
|
||||
if sshClient != nil {
|
||||
if err := sshClient.Close(); err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
s.notifyClose("closed by client")
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (s *SSHClient) buildAuth(cfg *profilemanager.Config, engine *internal.Engine,
|
||||
serverType detection.ServerType, password string) ([]gossh.AuthMethod, gossh.HostKeyCallback, error) {
|
||||
|
||||
switch serverType {
|
||||
case detection.ServerTypeNetBirdJWT:
|
||||
token, err := s.requestJWTToken(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("jwt: %w", err)
|
||||
}
|
||||
auths := []gossh.AuthMethod{gossh.Password(token)}
|
||||
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
|
||||
|
||||
case detection.ServerTypeNetBirdNoJWT:
|
||||
if cfg.SSHKey == "" {
|
||||
return nil, nil, errors.New("no NetBird SSH key available")
|
||||
}
|
||||
signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parse netbird ssh key: %w", err)
|
||||
}
|
||||
auths := []gossh.AuthMethod{gossh.PublicKeys(signer)}
|
||||
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
|
||||
|
||||
default: // regular SSH
|
||||
var auths []gossh.AuthMethod
|
||||
if cfg.SSHKey != "" {
|
||||
if signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey)); err == nil {
|
||||
auths = append(auths, gossh.PublicKeys(signer))
|
||||
} else {
|
||||
log.Debugf("ssh: parse netbird key for regular auth: %v", err)
|
||||
}
|
||||
}
|
||||
if password != "" {
|
||||
pw := password
|
||||
auths = append(auths, gossh.Password(pw))
|
||||
auths = append(auths, gossh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) {
|
||||
answers := make([]string, len(questions))
|
||||
for i := range questions {
|
||||
answers[i] = pw
|
||||
}
|
||||
return answers, nil
|
||||
}))
|
||||
}
|
||||
if len(auths) == 0 {
|
||||
return nil, nil, errors.New("no auth method available: provide a password or configure NetBird SSH key")
|
||||
}
|
||||
return auths, gossh.InsecureIgnoreHostKey(), nil // nolint:gosec // TOFU not yet implemented
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHClient) requestJWTToken(cfg *profilemanager.Config) (string, error) {
|
||||
s.mu.Lock()
|
||||
urlOpener := s.urlOpener
|
||||
s.mu.Unlock()
|
||||
if urlOpener == nil {
|
||||
return "", errors.New("URL opener not configured for JWT auth")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
flow, err := auth.NewOAuthFlow(ctx, cfg, false, true, profilemanager.GetLoginHint())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create oauth flow: %w", err)
|
||||
}
|
||||
|
||||
flowInfo, err := flow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request auth info: %w", err)
|
||||
}
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
tokenInfo, err := flow.WaitToken(ctx, flowInfo)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("wait for token: %w", err)
|
||||
}
|
||||
|
||||
token := tokenInfo.GetTokenToUse()
|
||||
if token == "" {
|
||||
return "", errors.New("empty token returned by IdP")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *SSHClient) dialAndHandshake(host string, port int, clientConfig *gossh.ClientConfig) error {
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
log.Infof("SSH: connecting to %s as %s", addr, clientConfig.User)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
var dialer net.Dialer
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
sshConn, chans, reqs, err := gossh.NewClientConn(conn, addr, clientConfig)
|
||||
if err != nil {
|
||||
if cerr := conn.Close(); cerr != nil {
|
||||
log.Debugf("ssh: close after handshake error: %v", cerr)
|
||||
}
|
||||
return fmt.Errorf("ssh handshake: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.sshClient = gossh.NewClient(sshConn, chans, reqs)
|
||||
listener := s.listener
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Infof("SSH: connected to %s", addr)
|
||||
if listener != nil {
|
||||
listener.OnConnected()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SSHClient) readLoop(r io.Reader, name string) {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
s.mu.Lock()
|
||||
listener := s.listener
|
||||
s.mu.Unlock()
|
||||
if listener != nil {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
listener.OnData(chunk)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Debugf("ssh %s read: %v", name, err)
|
||||
}
|
||||
s.notifyClose(err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHClient) notifyClose(reason string) {
|
||||
s.mu.Lock()
|
||||
if s.closed {
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
s.closed = true
|
||||
listener := s.listener
|
||||
s.mu.Unlock()
|
||||
if listener != nil {
|
||||
listener.OnClose(reason)
|
||||
}
|
||||
}
|
||||
|
||||
// engineHostKeyVerifier adapts *internal.Engine to nbssh.HostKeyVerifier.
|
||||
type engineHostKeyVerifier struct {
|
||||
engine *internal.Engine
|
||||
}
|
||||
|
||||
func (v *engineHostKeyVerifier) VerifySSHHostKey(peerAddress string, presented []byte) error {
|
||||
storedKey, found := v.engine.GetPeerSSHKey(peerAddress)
|
||||
if !found {
|
||||
return nbssh.ErrPeerNotFound
|
||||
}
|
||||
return nbssh.VerifyHostKey(storedKey, presented, peerAddress)
|
||||
}
|
||||
|
||||
func detectServerType(host string, port int) detection.ServerType {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshDetectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
if err != nil {
|
||||
log.Debugf("ssh: server detection for %s:%d failed: %v (assuming regular SSH)", host, port, err)
|
||||
return detection.ServerTypeRegular
|
||||
}
|
||||
return serverType
|
||||
}
|
||||
|
||||
func closeQuiet(c io.Closer, label string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if err := c.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Debugf("ssh: close %s: %v", label, err)
|
||||
}
|
||||
}
|
||||
@@ -1916,6 +1916,117 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerDisconnected_SchedulesInactivityExpiration(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
InactivityExpirationEnabled: true,
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
// Establish a session so the matching-token disconnect is actually applied.
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
// Install the mock only now, so the assertion observes the disconnect, not
|
||||
// the earlier connect.
|
||||
scheduled := make(chan struct{}, 1)
|
||||
manager.peerInactivityExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
select {
|
||||
case scheduled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer disconnected")
|
||||
|
||||
select {
|
||||
case <-scheduled:
|
||||
// expected: disconnect re-armed the inactivity expiry timer
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected inactivity expiration to be rescheduled when an eligible peer disconnects")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerDisconnected_SkipsInactivityExpirationWhenDisabled(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
InactivityExpirationEnabled: true,
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
// Peer is eligible (SSO + inactivity enabled) but the account-level setting
|
||||
// stays disabled, so disconnect must not schedule anything.
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
scheduled := make(chan struct{}, 1)
|
||||
manager.peerInactivityExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
select {
|
||||
case scheduled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer disconnected")
|
||||
|
||||
select {
|
||||
case <-scheduled:
|
||||
t.Fatal("inactivity expiration must not be scheduled while the account-level setting is disabled")
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
// expected: nothing scheduled
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
|
||||
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -106,12 +106,8 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
change, mustContain, mustExclude := r.build(t, s, ctx)
|
||||
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
|
||||
|
||||
for _, id := range mustContain {
|
||||
assert.Contains(t, affected, id, "expected peer to be affected")
|
||||
}
|
||||
for _, id := range mustExclude {
|
||||
assert.NotContains(t, affected, id, "peer must not be affected")
|
||||
}
|
||||
assert.ElementsMatch(t, affected, mustContain, "expected peer to be affected")
|
||||
assert.NotContains(t, affected, mustExclude, "peer must not be affected")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,33 +96,54 @@ func affectedGroupID(i int) string { return fmt.Sprintf("affected-grp-%d", i)
|
||||
func affectedGroupName(i int) string { return fmt.Sprintf("AffectedGroup%d", i) }
|
||||
|
||||
func TestCollectGroupChange_PolicyLinked(t *testing.T) {
|
||||
manager, s, accountID, _, groupIDs := setupAffectedPeersTest(t)
|
||||
manager, s, accountID, peerIDs, groupIDs := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := manager.SavePolicy(ctx, accountID, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: peerIDs[0], Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypePeer},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
|
||||
DestinationResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypeHost},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, _ := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[1]})
|
||||
|
||||
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[0]})
|
||||
|
||||
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
assert.Empty(t, groups)
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
|
||||
@@ -133,20 +154,44 @@ func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: peerIDs[4], Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypeHost},
|
||||
DestinationResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.Contains(t, directPeers, peerIDs[3])
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[4]})
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[3]})
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
assert.Empty(t, groups)
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T) {
|
||||
@@ -168,8 +213,7 @@ func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.Empty(t, directPeers, "non-peer resources should not produce direct peer IDs")
|
||||
}
|
||||
|
||||
@@ -373,17 +417,11 @@ func TestCollectGroupChange_MultipleEntities(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.NotContains(t, groups, groupIDs[2])
|
||||
assert.NotContains(t, groups, groupIDs[3])
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.Empty(t, directPeers)
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[3]})
|
||||
assert.Contains(t, groups, groupIDs[2])
|
||||
assert.Contains(t, groups, groupIDs[3])
|
||||
assert.NotContains(t, groups, groupIDs[0])
|
||||
assert.NotContains(t, groups, groupIDs[1])
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[2], groupIDs[3]})
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
@@ -474,7 +512,7 @@ func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
|
||||
@@ -659,9 +697,13 @@ func TestResolveAffectedPeers_PeerInMultipleGroups(t *testing.T) {
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// peer0 is in group0 AND group1, so both policies apply
|
||||
// peer0 is in group0 AND group1, so both policies apply. A peer change folds
|
||||
// only the changed peer plus the opposite side of each rule: group2 (peer2) via
|
||||
// the group0 policy and group3 (peer3) via the group1 policy. peer1, a co-member
|
||||
// of group1, is a sibling of the changed peer and must NOT refresh.
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[3]}, result)
|
||||
assert.NotContains(t, result, peerIDs[1], "co-member of the changed peer's group must not refresh")
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
|
||||
@@ -697,7 +739,7 @@ func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0], peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[1], peerIDs[3]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_SharedGroupAcrossPolicyAndRoute(t *testing.T) {
|
||||
|
||||
@@ -221,15 +221,17 @@ func Collect(ctx context.Context, s store.Store, accountID string, c Change) (gr
|
||||
|
||||
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
|
||||
r := &resolver{
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
changedGroupSet: toSet(c.ChangedGroupIDs),
|
||||
changedPeerSet: toSet(c.ChangedPeerIDs),
|
||||
groupSet: make(map[string]struct{}),
|
||||
peerSet: make(map[string]struct{}),
|
||||
networkIDs: make(map[string]struct{}),
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
changedGroupSet: toSet(c.ChangedGroupIDs),
|
||||
changedPeerSet: toSet(c.ChangedPeerIDs),
|
||||
groupSet: make(map[string]struct{}),
|
||||
peerSet: make(map[string]struct{}),
|
||||
networkIDs: make(map[string]struct{}),
|
||||
sourceOriginatedNetworkIDs: make(map[string]struct{}),
|
||||
changedGroupIDs: toSet(c.ChangedGroupIDs),
|
||||
}
|
||||
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
|
||||
r.seedChangedGroupsFromPeers()
|
||||
@@ -239,6 +241,9 @@ func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change
|
||||
|
||||
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
|
||||
// the group-driven walkers fire for memberships, not just direct peer references.
|
||||
// These seeded groups are for MATCHING only — folding the changed entity's own
|
||||
// side is gated on changedGroupIDs (the caller-reported groups), so a seeded group
|
||||
// never folds its whole membership; only the changed peer itself folds in.
|
||||
func (r *resolver) seedChangedGroupsFromPeers() {
|
||||
if len(r.changedPeerSet) == 0 {
|
||||
return
|
||||
@@ -292,6 +297,18 @@ type resolver struct {
|
||||
|
||||
matchedPolicies []*types.Policy
|
||||
networkIDs map[string]struct{}
|
||||
// sourceOriginatedNetworkIDs are networks marked affected only because a
|
||||
// source-side change targets a resource on them (bridgeSourceToRouters). Their
|
||||
// routers must refresh, but the policy sources must not be folded back: a
|
||||
// changed source propagates only to the opposite (router) side, never to its
|
||||
// co-sources. Networks marked by a router/resource/network change are absent
|
||||
// here and do fold sources, since the destination side itself changed.
|
||||
sourceOriginatedNetworkIDs map[string]struct{}
|
||||
|
||||
// changedGroupIDs are the groups the caller reported as changed via
|
||||
// Change.ChangedGroupIDs (NOT the peer-seeded ones in changedGroupSet). Only
|
||||
// these fold their whole membership; a peer-seeded group folds the peer alone.
|
||||
changedGroupIDs map[string]struct{}
|
||||
}
|
||||
|
||||
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
|
||||
@@ -445,21 +462,88 @@ func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromPolicies folds, for every policy a changed group or peer touches:
|
||||
// the opposite side of the matching rule, the changed entity's own side (the
|
||||
// changed group itself, or the changed peer alone — never the changed side's
|
||||
// sibling groups or co-members), and records the policy for the resource<->router
|
||||
// bridge. A changed peer is mapped to its groups in changedGroupSet up front (see
|
||||
// seedChangedGroupsFromPeers); changedGroupIDs holds only the caller-reported
|
||||
// groups, so a peer-seeded group does not fold its whole membership.
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
|
||||
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
if !r.collectPolicyDirectional(policy) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched directionally", policy.ID, policy.Name)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
// collectPolicyDirectional folds one policy's affected groups/peers and reports
|
||||
// whether it matched a changed group or peer at all (so the caller can record it
|
||||
// for the bridge even when the opposite side is a resource, not a group).
|
||||
func (r *resolver) collectPolicyDirectional(policy *types.Policy) bool {
|
||||
matched := false
|
||||
for _, rule := range policy.Rules {
|
||||
matched = r.foldRuleSide(rule.Sources, rule.Destinations, rule.DestinationResource) || matched
|
||||
matched = r.foldRuleSide(rule.Destinations, rule.Sources, rule.SourceResource) || matched
|
||||
|
||||
if isDirectPeerInSet(rule.SourceResource, r.changedPeerSet) {
|
||||
r.peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
addAll(r.groupSet, rule.Destinations)
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
r.peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
matched = true
|
||||
}
|
||||
if isDirectPeerInSet(rule.DestinationResource, r.changedPeerSet) {
|
||||
r.peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
addAll(r.groupSet, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
r.peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
matched = true
|
||||
}
|
||||
}
|
||||
return matched
|
||||
}
|
||||
|
||||
// foldRuleSide handles a changed group on `near` (Sources or Destinations): it
|
||||
// folds the `far` (opposite) groups and far resource peer, the changed group(s)
|
||||
// themselves (caller-reported groups only — not seeded ones, so a changed peer's
|
||||
// group does not pull in its members), and the changed peers seeded from those
|
||||
// groups (the peer alone). Returns whether the side matched.
|
||||
func (r *resolver) foldRuleSide(near, far []string, farResource types.Resource) bool {
|
||||
if !anyInSet(near, r.changedGroupSet) {
|
||||
return false
|
||||
}
|
||||
addAll(r.groupSet, far)
|
||||
if farResource.Type == types.ResourceTypePeer && farResource.ID != "" {
|
||||
r.peerSet[farResource.ID] = struct{}{}
|
||||
}
|
||||
for _, gID := range near {
|
||||
if _, ok := r.changedGroupIDs[gID]; ok {
|
||||
r.groupSet[gID] = struct{}{} // changed group itself -> its members
|
||||
}
|
||||
r.foldChangedPeersInGroup(gID) // a changed peer in this group -> the peer alone
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// foldChangedPeersInGroup folds changed peers that belong to groupID directly into
|
||||
// peerSet (the peer only, never its co-members).
|
||||
func (r *resolver) foldChangedPeersInGroup(groupID string) {
|
||||
if len(r.changedPeerSet) == 0 {
|
||||
return
|
||||
}
|
||||
members := r.snap.groupPeers[groupID]
|
||||
for pID := range r.changedPeerSet {
|
||||
if _, ok := members[pID]; ok {
|
||||
r.peerSet[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromRoutes() {
|
||||
for _, rt := range r.snap.routes {
|
||||
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
|
||||
@@ -588,6 +672,11 @@ func (r *resolver) bridgeSourceToRouters() {
|
||||
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
|
||||
setToSlice(resourceIDs), setToSlice(networkIDs))
|
||||
for id := range networkIDs {
|
||||
// Mark source-originated unless a router/resource/network change already
|
||||
// marked this network directly (then it folds sources back).
|
||||
if _, ok := r.networkIDs[id]; !ok {
|
||||
r.sourceOriginatedNetworkIDs[id] = struct{}{}
|
||||
}
|
||||
r.networkIDs[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
@@ -602,11 +691,19 @@ func (r *resolver) bridgeRoutersToSources() {
|
||||
|
||||
r.foldRoutersOnNetworks(r.networkIDs)
|
||||
|
||||
// Sources are folded back only for networks the destination side itself changed
|
||||
// (router/resource/network change). Networks reached only because a source-side
|
||||
// change targets their resource must not refresh the policy's sources — the
|
||||
// changed source propagates to the router side, not back to its co-sources.
|
||||
resourceIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if _, ok := r.networkIDs[resource.NetworkID]; ok {
|
||||
resourceIDs[resource.ID] = struct{}{}
|
||||
if _, ok := r.networkIDs[resource.NetworkID]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, sourceOriginated := r.sourceOriginatedNetworkIDs[resource.NetworkID]; sourceOriginated {
|
||||
continue
|
||||
}
|
||||
resourceIDs[resource.ID] = struct{}{}
|
||||
}
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
@@ -734,24 +831,6 @@ func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]str
|
||||
}
|
||||
}
|
||||
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
|
||||
for _, id := range policy.SourcePostureChecks {
|
||||
if _, ok := ids[id]; ok {
|
||||
|
||||
@@ -80,26 +80,6 @@ func TestChangeIsEmpty(t *testing.T) {
|
||||
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
|
||||
}
|
||||
|
||||
func TestPolicyReferencesGroups(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
|
||||
|
||||
|
||||
@@ -188,6 +188,15 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
|
||||
}
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed getting account settings to schedule inactivity expiration for peer %s: %v", peer.ID, err)
|
||||
} else if settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user