mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-23 17:19:54 +00:00
Compare commits
3 Commits
windows-dn
...
fix-ssh-au
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
486cd4c0e3 | ||
|
|
b59382d4f2 | ||
|
|
e916f12cca |
@@ -394,6 +394,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
for _, m := range mappings[i:end] {
|
||||
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
|
||||
}
|
||||
m.AuthToken = token
|
||||
}
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
@@ -425,18 +432,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
return nil, fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
|
||||
oidcCfg := s.GetOIDCValidationConfig()
|
||||
var mappings []*proto.ProxyMapping
|
||||
for _, service := range services {
|
||||
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||
m := service.ToProtoMapping(rpservice.Create, "", oidcCfg)
|
||||
if !proxyAcceptsMapping(conn, m) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -172,3 +173,55 @@ func TestSendSnapshot_EmptySnapshot(t *testing.T) {
|
||||
assert.Empty(t, stream.messages[0].Mapping)
|
||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||
}
|
||||
|
||||
type hookingStream struct {
|
||||
grpc.ServerStream
|
||||
onSend func(*proto.GetMappingUpdateResponse)
|
||||
}
|
||||
|
||||
func (s *hookingStream) Send(m *proto.GetMappingUpdateResponse) error {
|
||||
if s.onSend != nil {
|
||||
s.onSend(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *hookingStream) Context() context.Context { return context.Background() }
|
||||
func (s *hookingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *hookingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *hookingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *hookingStream) SendMsg(any) error { return nil }
|
||||
func (s *hookingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestSendSnapshot_TokensRemainValidUnderSlowSend(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 2
|
||||
const totalServices = 6
|
||||
const ttl = 100 * time.Millisecond
|
||||
const sendDelay = 200 * time.Millisecond
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
s.tokenTTL = ttl
|
||||
|
||||
var validateErrs []error
|
||||
stream := &hookingStream{
|
||||
onSend: func(resp *proto.GetMappingUpdateResponse) {
|
||||
for _, m := range resp.Mapping {
|
||||
if err := s.tokenStore.ValidateAndConsume(m.AuthToken, m.AccountId, m.Id); err != nil {
|
||||
validateErrs = append(validateErrs, fmt.Errorf("svc %s: %w", m.Id, err))
|
||||
}
|
||||
}
|
||||
time.Sleep(sendDelay)
|
||||
},
|
||||
}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Empty(t, validateErrs,
|
||||
"tokens must remain valid even when batches are sent slowly: lazy per-batch generation guarantees freshness")
|
||||
}
|
||||
|
||||
@@ -326,17 +326,25 @@ func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context,
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testValidateSessionProxyManager struct{}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error {
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _, _ string, _ *string, _ *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ *proxy.Proxy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -892,7 +892,12 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
|
||||
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
// Auth is collected when this peer serves the rule. For bidirectional
|
||||
// rules the peer-in-sources side also serves inbound traffic, so it
|
||||
// must be treated as a destination too.
|
||||
peerServesAuth := peerInDestinations || (rule.Bidirectional && peerInSources)
|
||||
|
||||
if peerServesAuth && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
sshEnabled = true
|
||||
switch {
|
||||
case len(rule.AuthorizedGroups) > 0:
|
||||
@@ -924,7 +929,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
|
||||
default:
|
||||
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||
}
|
||||
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
|
||||
} else if peerServesAuth && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
|
||||
sshEnabled = true
|
||||
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||
}
|
||||
|
||||
@@ -341,7 +341,12 @@ func (a *Account) getPeersGroupsPoliciesRoutes(
|
||||
for _, srcGroupID := range rule.Sources {
|
||||
relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
// SSH auth requirements are gathered whenever this peer serves
|
||||
// the rule. For bidirectional rules the peer-in-sources side
|
||||
// also serves inbound traffic and must be treated as a destination.
|
||||
if peerInDestinations || (rule.Bidirectional && peerInSources) {
|
||||
if rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
switch {
|
||||
case len(rule.AuthorizedGroups) > 0:
|
||||
|
||||
@@ -221,7 +221,12 @@ func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) (
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
|
||||
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
// Auth is collected when this peer serves the rule. For bidirectional
|
||||
// rules the peer-in-sources side also serves inbound traffic, so it
|
||||
// must be treated as a destination too.
|
||||
peerServesAuth := peerInDestinations || (rule.Bidirectional && peerInSources)
|
||||
|
||||
if peerServesAuth && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
sshEnabled = true
|
||||
switch {
|
||||
case len(rule.AuthorizedGroups) > 0:
|
||||
@@ -252,7 +257,7 @@ func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) (
|
||||
default:
|
||||
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
|
||||
}
|
||||
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled {
|
||||
} else if peerServesAuth && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled {
|
||||
sshEnabled = true
|
||||
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
|
||||
}
|
||||
@@ -557,7 +562,6 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
|
||||
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
|
||||
@@ -980,6 +980,44 @@ func TestComponents_SSHAuthorizedUsersContent(t *testing.T) {
|
||||
assert.True(t, hasRoot || hasAdmin, "AuthorizedUsers should contain 'root' or 'admin' machine user mapping")
|
||||
}
|
||||
|
||||
// TestComponents_SSHAuthorizedUsersBidirectionalSource verifies that a peer
|
||||
// on the sources side of a bidirectional NetbirdSSH rule receives the rule's
|
||||
// authorized users. The reverse direction (destinations -> sources) makes
|
||||
// the source-side peer a destination too, so it must be able to authorize
|
||||
// inbound SSH from the rule's destinations.
|
||||
func TestComponents_SSHAuthorizedUsersBidirectionalSource(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2)
|
||||
|
||||
account.Users["user-dev"] = &types.User{Id: "user-dev", Role: types.UserRoleUser, AccountID: "test-account", AutoGroups: []string{"ssh-users"}}
|
||||
account.Groups["ssh-users"] = &types.Group{ID: "ssh-users", Name: "SSH Users", Peers: []string{}}
|
||||
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-ssh-bidir", Name: "Bidirectional SSH", Enabled: true, AccountID: "test-account",
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-ssh-bidir", Name: "SSH both ways", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"group-0"}, Destinations: []string{"group-1"},
|
||||
AuthorizedGroups: map[string][]string{"ssh-users": {"root"}},
|
||||
}},
|
||||
})
|
||||
|
||||
nmSrc := componentsNetworkMap(account, "peer-0", validatedPeers)
|
||||
require.NotNil(t, nmSrc)
|
||||
assert.True(t, nmSrc.EnableSSH, "source-side peer of bidirectional SSH rule should have SSH enabled")
|
||||
require.NotEmpty(t, nmSrc.AuthorizedUsers, "source-side peer should receive authorized users from bidirectional rule")
|
||||
rootUsers, hasRoot := nmSrc.AuthorizedUsers["root"]
|
||||
require.True(t, hasRoot, "source-side peer should map the 'root' local user")
|
||||
_, hasDev := rootUsers["user-dev"]
|
||||
assert.True(t, hasDev, "source-side peer should include 'user-dev' under 'root'")
|
||||
|
||||
nmDst := componentsNetworkMap(account, "peer-10", validatedPeers)
|
||||
require.NotNil(t, nmDst)
|
||||
assert.True(t, nmDst.EnableSSH, "destination-side peer should also have SSH enabled")
|
||||
_, hasRoot = nmDst.AuthorizedUsers["root"]
|
||||
assert.True(t, hasRoot, "destination-side peer should also map the 'root' local user")
|
||||
}
|
||||
|
||||
// TestComponents_SSHLegacyImpliedSSH verifies that a non-SSH ALL protocol policy with
|
||||
// SSHEnabled peer implies legacy SSH access.
|
||||
func TestComponents_SSHLegacyImpliedSSH(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user