Compare commits

...

7 Commits

Author SHA1 Message Date
Maycon Santos
03fd656344 [management] Fix policy tests (#3135)
- Add firewall rule isEqual method
- Fix tests
2024-12-31 18:45:40 +01:00
Pascal Fischer
18b049cd24 [management] remove sorting from network map generation (#3126) 2024-12-31 18:10:40 +01:00
Bethuel Mmbaga
2bdb4cb44a [management] Preserve jwt groups when accessing API with PAT (#3128)
* Skip JWT group sync for token-based authentication

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-12-31 18:59:37 +03:00
Viktor Liu
abbdf20f65 [client] Allow inbound rosenpass port (#3109) 2024-12-31 14:08:48 +01:00
Viktor Liu
43ef64cf67 [client] Ignore case when matching domains in handler chain (#3133) 2024-12-31 14:07:21 +01:00
Pascal Fischer
18316be09a [management] add selfhosted metrics for networks (#3118) 2024-12-30 12:53:51 +01:00
Maycon Santos
1a623943c8 [management] Fix networks net map generation with posture checks (#3124) 2024-12-30 12:40:24 +01:00
21 changed files with 874 additions and 148 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd ignore_words_list: erro,clienta,hastable,iif,groupd,testin
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1 only_warn: 1
golangci: golangci:

View File

@@ -197,7 +197,7 @@ func (m *Manager) AllowNetbird() error {
} }
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
net.ParseIP("0.0.0.0"), net.IP{0, 0, 0, 0},
"all", "all",
nil, nil,
nil, nil,

View File

@@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
pattern = strings.ToLower(dns.Fqdn(pattern))
origPattern := pattern origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.") isWildcard := strings.HasPrefix(pattern, "*.")
if isWildcard { if isWildcard {
pattern = pattern[2:] pattern = pattern[2:]
} }
pattern = dns.Fqdn(pattern)
origPattern = dns.Fqdn(origPattern)
// First remove any existing handler with same original pattern and priority // First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority { if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil { if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop() c.handlers[i].StopHandler.stop()
} }
@@ -126,10 +125,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
pattern = dns.Fqdn(pattern) pattern = dns.Fqdn(pattern)
// Find and remove handlers matching both original pattern and priority // Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i] entry := c.handlers[i]
if entry.OrigPattern == pattern && entry.Priority == priority { if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
if entry.StopHandler != nil { if entry.StopHandler != nil {
entry.StopHandler.stop() entry.StopHandler.stop()
} }
@@ -144,9 +143,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
pattern = dns.Fqdn(pattern) pattern = strings.ToLower(dns.Fqdn(pattern))
for _, entry := range c.handlers { for _, entry := range c.handlers {
if entry.Pattern == pattern { if strings.EqualFold(entry.Pattern, pattern) {
return true return true
} }
} }
@@ -158,7 +157,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
qname := r.Question[0].Name qname := strings.ToLower(r.Question[0].Name)
log.Tracef("handling DNS request for domain=%s", qname) log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock() c.mu.RLock()
@@ -187,9 +186,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants subdomain matching, allow suffix match // If handler wants subdomain matching, allow suffix match
// Otherwise require exact match // Otherwise require exact match
if entry.MatchSubdomains { if entry.MatchSubdomains {
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern) matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else { } else {
matched = qname == entry.Pattern matched = strings.EqualFold(qname, entry.Pattern)
} }
} }

View File

@@ -507,5 +507,173 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 4: Remove last handler // Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault) chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
assert.False(t, chain.HasHandlers(testDomain)) assert.False(t, chain.HasHandlers(testDomain))
} }
func TestHandlerChain_CaseSensitivity(t *testing.T) {
tests := []struct {
name string
scenario string
addHandlers []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}
query string
expectedCalls int
}{
{
name: "case insensitive exact match",
scenario: "handler registered lowercase, query uppercase",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"example.com.", nbdns.PriorityDefault, false, true},
},
query: "EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "case insensitive wildcard match",
scenario: "handler registered mixed case wildcard, query different case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"*.Example.Com.", nbdns.PriorityDefault, false, true},
},
query: "sub.EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "multiple handlers different case same domain",
scenario: "second handler should replace first despite case difference",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityDefault, false, true},
},
query: "ExAmPlE.cOm.",
expectedCalls: 1,
},
{
name: "subdomain matching case insensitive",
scenario: "handler with MatchSubdomains true should match regardless of case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"example.com.", nbdns.PriorityDefault, true, true},
},
query: "SUB.EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "root zone case insensitive",
scenario: "root zone handler should match regardless of case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{".", nbdns.PriorityDefault, false, true},
},
query: "EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "multiple handlers different priority",
scenario: "should call higher priority handler despite case differences",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
},
query: "example.com.",
expectedCalls: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlerCalls := make(map[string]bool) // track which patterns were called
// Add handlers according to test case
for _, h := range tt.addHandlers {
var handler dns.Handler
pattern := h.pattern // capture pattern for closure
if h.subdomains {
subHandler := &nbdns.MockSubdomainHandler{
Subdomains: true,
}
if h.shouldMatch {
subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
handlerCalls[pattern] = true
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
handler = subHandler
} else {
mockHandler := &nbdns.MockHandler{}
if h.shouldMatch {
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
handlerCalls[pattern] = true
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
handler = mockHandler
}
chain.AddHandler(pattern, handler, h.priority, nil)
}
// Execute request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
chain.ServeDNS(&mockResponseWriter{}, r)
// Verify each handler was called exactly as expected
for _, h := range tt.addHandlers {
wasCalled := handlerCalls[h.pattern]
assert.Equal(t, h.shouldMatch, wasCalled,
"Handler for pattern %q was %s when it should%s have been",
h.pattern,
map[bool]string{true: "called", false: "not called"}[wasCalled],
map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
}
// Verify total number of calls
assert.Equal(t, tt.expectedCalls, len(handlerCalls),
"Wrong number of total handler calls")
})
}
}

View File

@@ -83,7 +83,7 @@ func (h *Manager) allowDNSFirewall() error {
IsRange: false, IsRange: false,
Values: []int{ListenPort}, Values: []int{ListenPort},
} }
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
if err != nil { if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err) log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err return err

View File

@@ -406,13 +406,9 @@ func (e *Engine) Start() error {
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
if err != nil { if err != nil {
log.Errorf("failed creating firewall manager: %s", err) log.Errorf("failed creating firewall manager: %s", err)
} } else if e.firewall != nil {
if err := e.initFirewall(err); err != nil {
if e.firewall != nil && e.firewall.IsServerRouteSupported() { return err
err = e.routeManager.EnableServerRouter(e.firewall)
if err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
} }
} }
@@ -455,6 +451,41 @@ func (e *Engine) Start() error {
return nil return nil
} }
func (e *Engine) initFirewall(error) error {
if e.firewall.IsServerRouteSupported() {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
}
}
if e.rpManager == nil || !e.config.RosenpassEnabled {
return nil
}
rosenpassPort := e.rpManager.GetAddress().Port
port := manager.Port{Values: []int{rosenpassPort}}
// this rule is static and will be torn down on engine down by the firewall manager
if _, err := e.firewall.AddPeerFiltering(
net.IP{0, 0, 0, 0},
manager.ProtocolUDP,
nil,
&port,
manager.RuleDirectionIN,
manager.ActionAccept,
"",
"",
); err != nil {
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
return nil
}
log.Infof("rosenpass interface traffic allowed on port %d", rosenpassPort)
return nil
}
// modifyPeers updates peers that have been modified (e.g. IP address has been changed). // modifyPeers updates peers that have been modified (e.g. IP address has been changed).
// It closes the existing connection, removes it from the peerConns map, and creates a new one. // It closes the existing connection, removes it from the peerConns map, and creates a new one.
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {

View File

@@ -1252,6 +1252,12 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled. // and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
if claim, exists := claims.Raw[jwtclaims.IsToken]; exists {
if isToken, ok := claim.(bool); ok && isToken {
return nil
}
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err

View File

@@ -2729,6 +2729,19 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
t.Run("skip sync for token auth type", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
})
t.Run("empty jwt groups", func(t *testing.T) { t.Run("empty jwt groups", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{ claims := jwtclaims.AuthorizationClaims{
UserId: "user1", UserId: "user1",
@@ -2822,7 +2835,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.Len(t, user.AutoGroups, 1, "new group should be added") assert.Len(t, user.AutoGroups, 1, "new group should be added")
}) })
t.Run("remove all JWT groups", func(t *testing.T) { t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{ claims := jwtclaims.AuthorizationClaims{
UserId: "user1", UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{}}, Raw: jwt.MapClaims{"groups": []interface{}{}},
@@ -2833,7 +2846,20 @@ func TestAccount_SetJWTGroups(t *testing.T) {
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
})
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user2",
Raw: jwt.MapClaims{},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
}) })
} }
@@ -3037,9 +3063,9 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
minMsPerOpCICD float64 minMsPerOpCICD float64
maxMsPerOpCICD float64 maxMsPerOpCICD float64
}{ }{
{"Small", 50, 5, 1, 3, 3, 10}, {"Small", 50, 5, 1, 3, 3, 11},
{"Medium", 500, 100, 7, 13, 10, 70}, {"Medium", 500, 100, 7, 13, 10, 70},
{"Large", 5000, 200, 65, 80, 60, 200}, {"Large", 5000, 200, 65, 80, 60, 220},
{"Small single", 50, 10, 1, 3, 3, 70}, {"Small single", 50, 10, 1, 3, 3, 70},
{"Medium single", 500, 10, 7, 13, 10, 26}, {"Medium single", 500, 10, 7, 13, 10, 26},
{"Large 5", 5000, 15, 65, 80, 60, 200}, {"Large 5", 5000, 15, 65, 80, 60, 200},
@@ -3179,7 +3205,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
maxMsPerOpCICD float64 maxMsPerOpCICD float64
}{ }{
{"Small", 50, 5, 107, 120, 107, 160}, {"Small", 50, 5, 107, 120, 107, 160},
{"Medium", 500, 100, 105, 140, 105, 190}, {"Medium", 500, 100, 105, 140, 105, 220},
{"Large", 5000, 200, 180, 220, 180, 350}, {"Large", 5000, 200, 180, 220, 180, 350},
{"Small single", 50, 10, 107, 120, 105, 160}, {"Small single", 50, 10, 107, 120, 105, 160},
{"Medium single", 500, 10, 105, 140, 105, 170}, {"Medium single", 500, 10, 105, 140, 105, 170},

View File

@@ -175,6 +175,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
claimMaps[jwtclaims.IsToken] = true
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information. // Update the current request with the new context information.

View File

@@ -22,6 +22,8 @@ const (
LastLoginSuffix = "nb_last_login" LastLoginSuffix = "nb_last_login"
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation // Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
Invited = "nb_invited" Invited = "nb_invited"
// IsToken claim indicates that auth type from the user is a token
IsToken = "is_token"
) )
// ExtractClaims Extract function type // ExtractClaims Extract function type

View File

@@ -195,6 +195,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
groups int groups int
routes int routes int
routesWithRGGroups int routesWithRGGroups int
networks int
networkResources int
networkRouters int
networkRoutersWithPG int
nameservers int nameservers int
uiClient int uiClient int
version string version string
@@ -219,6 +223,16 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
} }
groups += len(account.Groups) groups += len(account.Groups)
networks += len(account.Networks)
networkResources += len(account.NetworkResources)
networkRouters += len(account.NetworkRouters)
for _, router := range account.NetworkRouters {
if len(router.PeerGroups) > 0 {
networkRoutersWithPG++
}
}
routes += len(account.Routes) routes += len(account.Routes)
for _, route := range account.Routes { for _, route := range account.Routes {
if len(route.PeerGroups) > 0 { if len(route.PeerGroups) > 0 {
@@ -312,6 +326,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks
metricsProperties["posture_checks"] = postureChecks metricsProperties["posture_checks"] = postureChecks
metricsProperties["groups"] = groups metricsProperties["groups"] = groups
metricsProperties["networks"] = networks
metricsProperties["network_resources"] = networkResources
metricsProperties["network_routers"] = networkRouters
metricsProperties["network_routers_with_groups"] = networkRoutersWithPG
metricsProperties["routes"] = routes metricsProperties["routes"] = routes
metricsProperties["routes_with_routing_groups"] = routesWithRGGroups metricsProperties["routes_with_routing_groups"] = routesWithRGGroups
metricsProperties["nameservers"] = nameservers metricsProperties["nameservers"] = nameservers

View File

@@ -5,6 +5,9 @@ import (
"testing" "testing"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -172,6 +175,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
}, },
}, },
}, },
Networks: []*networkTypes.Network{
{
ID: "1",
AccountID: "1",
},
},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: "1",
AccountID: "1",
NetworkID: "1",
},
{
ID: "2",
AccountID: "1",
NetworkID: "1",
},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: "1",
AccountID: "1",
NetworkID: "1",
},
},
}, },
} }
} }
@@ -200,6 +228,15 @@ func TestGenerateProperties(t *testing.T) {
if properties["routes"] != 2 { if properties["routes"] != 2 {
t.Errorf("expected 2 routes, got %d", properties["routes"]) t.Errorf("expected 2 routes, got %d", properties["routes"])
} }
if properties["networks"] != 1 {
t.Errorf("expected 1 networks, got %d", properties["networks"])
}
if properties["network_resources"] != 2 {
t.Errorf("expected 2 network_resources, got %d", properties["network_resources"])
}
if properties["network_routers"] != 1 {
t.Errorf("expected 1 network_routers, got %d", properties["network_routers"])
}
if properties["rules"] != 4 { if properties["rules"] != 4 {
t.Errorf("expected 4 rules, got %d", properties["rules"]) t.Errorf("expected 4 rules, got %d", properties["rules"])
} }

View File

@@ -111,6 +111,7 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network
NetID: route.NetID(n.Name), NetID: route.NetID(n.Name),
Description: n.Description, Description: n.Description,
Peer: peer.Key, Peer: peer.Key,
PeerID: peer.ID,
PeerGroups: nil, PeerGroups: nil,
Masquerade: router.Masquerade, Masquerade: router.Masquerade,
Metric: router.Metric, Metric: router.Metric,

View File

@@ -932,11 +932,11 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}{ }{
{"Small", 50, 5, 90, 120, 90, 120}, {"Small", 50, 5, 90, 120, 90, 120},
{"Medium", 500, 100, 110, 150, 120, 260}, {"Medium", 500, 100, 110, 150, 120, 260},
{"Large", 5000, 200, 800, 1390, 2500, 4600}, {"Large", 5000, 200, 800, 1700, 2500, 5000},
{"Small single", 50, 10, 90, 120, 90, 120}, {"Small single", 50, 10, 90, 120, 90, 120},
{"Medium single", 500, 10, 110, 170, 120, 200}, {"Medium single", 500, 10, 110, 170, 120, 200},
{"Large 5", 5000, 15, 1300, 2100, 5000, 7000}, {"Large 5", 5000, 15, 1300, 2100, 4900, 7000},
{"Extra Large", 2000, 2000, 1300, 2100, 4000, 6000}, {"Extra Large", 2000, 2000, 1300, 2400, 4000, 6400},
} }
log.SetOutput(io.Discard) log.SetOutput(io.Discard)

View File

@@ -74,6 +74,19 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
"peerH", "peerH",
}, },
}, },
"GroupWorkstations": {
ID: "GroupWorkstations",
Name: "GroupWorkstations",
Peers: []string{
"peerB",
"peerA",
"peerD",
"peerE",
"peerF",
"peerG",
"peerH",
},
},
"GroupSwarm": { "GroupSwarm": {
ID: "GroupSwarm", ID: "GroupSwarm",
Name: "swarm", Name: "swarm",
@@ -127,7 +140,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
Action: types.PolicyTrafficActionAccept, Action: types.PolicyTrafficActionAccept,
Sources: []string{ Sources: []string{
"GroupSwarm", "GroupSwarm",
"GroupAll", "GroupWorkstations",
}, },
Destinations: []string{ Destinations: []string{
"GroupSwarm", "GroupSwarm",
@@ -159,6 +172,8 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])
assert.Contains(t, peers, account.Peers["peerE"]) assert.Contains(t, peers, account.Peers["peerE"])
assert.Contains(t, peers, account.Peers["peerF"]) assert.Contains(t, peers, account.Peers["peerF"])
assert.Contains(t, peers, account.Peers["peerG"])
assert.Contains(t, peers, account.Peers["peerH"])
epectedFirewallRules := []*types.FirewallRule{ epectedFirewallRules := []*types.FirewallRule{
{ {
@@ -189,21 +204,6 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
Protocol: "all", Protocol: "all",
Port: "", Port: "",
}, },
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
},
{ {
PeerIP: "100.65.62.5", PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionOUT, Direction: types.FirewallRuleDirectionOUT,
@@ -280,10 +280,16 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(epectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) for _, rule := range firewallRules {
for i := range firewallRules { contains := false
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) for _, expectedRule := range epectedFirewallRules {
if rule.IsEqual(expectedRule) {
contains = true
break
}
}
assert.True(t, contains, "rule not found in expected rules %#v", rule)
} }
}) })
} }

View File

@@ -364,7 +364,7 @@ func toProtocolRoute(route *route.Route) *proto.Route {
} }
func toProtocolRoutes(routes []*route.Route) []*proto.Route { func toProtocolRoutes(routes []*route.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0) protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes { for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r)) protoRoutes = append(protoRoutes, toProtocolRoute(r))
} }

View File

@@ -303,55 +303,47 @@ func (a *Account) GetPeerNetworkMap(
return nm return nm
} }
func (a *Account) addNetworksRoutingPeers(networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers []string) []*nbpeer.Peer { func (a *Account) addNetworksRoutingPeers(
missingPeers := map[string]struct{}{} networkResourcesRoutes []*route.Route,
for _, r := range networkResourcesRoutes { peer *nbpeer.Peer,
if r.Peer == peer.Key { peersToConnect []*nbpeer.Peer,
continue expiredPeers []*nbpeer.Peer,
} isRouter bool,
sourcePeers map[string]struct{},
) []*nbpeer.Peer {
missing := true networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
for _, p := range slices.Concat(peersToConnect, expiredPeers) { for _, r := range networkResourcesRoutes {
if r.Peer == p.Key { networkRoutesPeers[r.PeerID] = struct{}{}
missing = false
break
}
}
if missing {
missingPeers[r.Peer] = struct{}{}
}
} }
if isRouter { delete(sourcePeers, peer.ID)
for _, s := range sourcePeers {
if s == peer.ID {
continue
}
missing := true for _, existingPeer := range peersToConnect {
for _, p := range slices.Concat(peersToConnect, expiredPeers) { delete(sourcePeers, existingPeer.ID)
if s == p.ID { delete(networkRoutesPeers, existingPeer.ID)
missing = false }
break for _, expPeer := range expiredPeers {
} delete(sourcePeers, expPeer.ID)
} delete(networkRoutesPeers, expPeer.ID)
if missing { }
p, ok := a.Peers[s]
if ok { missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
missingPeers[p.Key] = struct{}{} if isRouter {
} for p := range sourcePeers {
} missingPeers[p] = struct{}{}
} }
} }
for p := range networkRoutesPeers {
missingPeers[p] = struct{}{}
}
for p := range missingPeers { for p := range missingPeers {
for _, p2 := range a.Peers { if missingPeer := a.Peers[p]; missingPeer != nil {
if p2.Key == p { peersToConnect = append(peersToConnect, missingPeer)
peersToConnect = append(peersToConnect, p2)
break
}
} }
} }
return peersToConnect return peersToConnect
} }
@@ -1045,37 +1037,32 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs // for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups)
for _, g := range groups { filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
group, ok := a.Groups[g] for _, p := range uniquePeerIDs {
if !ok { peer, ok := a.Peers[p]
if !ok || peer == nil {
continue continue
} }
for _, p := range group.Peers { // validate the peer based on policy posture checks applied
peer, ok := a.Peers[p] isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !ok || peer == nil { if !isValid {
continue continue
}
// validate the peer based on policy posture checks applied
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peer)
} }
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peer)
} }
return filteredPeers, peerInGroups return filteredPeers, peerInGroups
} }
@@ -1151,7 +1138,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
continue continue
} }
rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap)
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
fwRules = append(fwRules, rules...) fwRules = append(fwRules, rules...)
} }
@@ -1159,7 +1146,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
return fwRules return fwRules
} }
func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
distPeersWithPolicy := make(map[string]struct{}) distPeersWithPolicy := make(map[string]struct{})
for _, id := range rule.Sources { for _, id := range rule.Sources {
group := a.Groups[id] group := a.Groups[id]
@@ -1173,7 +1160,7 @@ func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeer
} }
_, distPeer := distributionPeers[pID] _, distPeer := distributionPeers[pID]
_, valid := validatedPeersMap[pID] _, valid := validatedPeersMap[pID]
if distPeer && valid { if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) {
distPeersWithPolicy[pID] = struct{}{} distPeersWithPolicy[pID] = struct{}{}
} }
} }
@@ -1271,7 +1258,11 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)
routesFirewallRules = append(routesFirewallRules, rules...) for _, rule := range rules {
if len(rule.SourceRanges) > 0 {
routesFirewallRules = append(routesFirewallRules, rule)
}
}
} }
return routesFirewallRules return routesFirewallRules
@@ -1303,10 +1294,10 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
} }
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. // GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) { func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool var isRoutingPeer bool
var routes []*route.Route var routes []*route.Route
var allSourcePeers []string allSourcePeers := make(map[string]struct{}, len(a.Peers))
for _, resource := range a.NetworkResources { for _, resource := range a.NetworkResources {
var addSourcePeers bool var addSourcePeers bool
@@ -1319,23 +1310,22 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
} }
} }
addedResourceRoute := false
for _, policy := range resourcePolicies[resource.ID] { for _, policy := range resourcePolicies[resource.ID] {
for _, sourceGroup := range policy.SourceGroups() { peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
group := a.GetGroup(sourceGroup) if addSourcePeers {
if group == nil { for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
log.WithContext(ctx).Warnf("policy %s has source group %s that doesn't exist under account %s, will continue map generation without it", policy.ID, sourceGroup, a.Id) allSourcePeers[pID] = struct{}{}
continue
} }
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
// routing peer should be able to connect with all source peers // add routes for the resource if the peer is in the distribution group
if addSourcePeers { for peerId, router := range networkRoutingPeers {
allSourcePeers = append(allSourcePeers, group.Peers...) routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
} else if slices.Contains(group.Peers, peerID) {
// add routes for the resource if the peer is in the distribution group
for peerId, router := range networkRoutingPeers {
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
}
} }
addedResourceRoute = true
}
if addedResourceRoute {
break
} }
} }
} }
@@ -1343,6 +1333,42 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
return isRoutingPeer, routes, allSourcePeers return isRoutingPeer, routes, allSourcePeers
} }
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
var dest []string
for _, peerID := range inputPeers {
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
dest = append(dest, peerID)
}
}
return dest
}
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
for _, groupID := range groups {
group := a.GetGroup(groupID)
if group == nil {
log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id)
continue
}
if group.IsGroupAll() || len(groups) == 1 {
return group.Peers
}
for _, peerID := range group.Peers {
peerIDs[peerID] = struct{}{}
}
}
ids := make([]string, 0, len(peerIDs))
for peerID := range peerIDs {
ids = append(ids, peerID)
}
return ids
}
// getNetworkResources filters and returns a list of network resources associated with the given network ID. // getNetworkResources filters and returns a list of network resources associated with the given network ID.
func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource { func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource {
var resources []*resourceTypes.NetworkResource var resources []*resourceTypes.NetworkResource

View File

@@ -1,14 +1,20 @@
package types package types
import ( import (
"context"
"net"
"net/netip"
"slices"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -310,19 +316,19 @@ func Test_GetResourcePoliciesMap(t *testing.T) {
func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) { func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) {
account := setupTestAccount() account := setupTestAccount()
peer := &nbpeer.Peer{Key: "peer1"} peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"}
networkResourcesRoutes := []*route.Route{ networkResourcesRoutes := []*route.Route{
{Peer: "peer2Key"}, {Peer: "peer2Key", PeerID: "peer2"},
{Peer: "peer3Key"}, {Peer: "peer3Key", PeerID: "peer3"},
} }
peersToConnect := []*nbpeer.Peer{ peersToConnect := []*nbpeer.Peer{
{Key: "peer2Key"}, {Key: "peer2Key", ID: "peer2"},
} }
expiredPeers := []*nbpeer.Peer{ expiredPeers := []*nbpeer.Peer{
{Key: "peer4Key"}, {Key: "peer4Key", ID: "peer4"},
} }
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 2) require.Len(t, result, 2)
require.Equal(t, "peer2Key", result[0].Key) require.Equal(t, "peer2Key", result[0].Key)
require.Equal(t, "peer3Key", result[1].Key) require.Equal(t, "peer3Key", result[1].Key)
@@ -339,7 +345,7 @@ func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) {
} }
expiredPeers := []*nbpeer.Peer{} expiredPeers := []*nbpeer.Peer{}
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 1) require.Len(t, result, 1)
require.Equal(t, "peer2Key", result[0].Key) require.Equal(t, "peer2Key", result[0].Key)
} }
@@ -358,7 +364,7 @@ func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) {
{Key: "peer3Key"}, {Key: "peer3Key"},
} }
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 1) require.Len(t, result, 1)
require.Equal(t, "peer2Key", result[0].Key) require.Equal(t, "peer2Key", result[0].Key)
} }
@@ -370,6 +376,382 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
peersToConnect := []*nbpeer.Peer{} peersToConnect := []*nbpeer.Peer{}
expiredPeers := []*nbpeer.Peer{} expiredPeers := []*nbpeer.Peer{}
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 0) require.Len(t, result, 0)
} }
const (
accID = "accountID"
network1ID = "network1ID"
group1ID = "group1"
accNetResourcePeer1ID = "peer1"
accNetResourcePeer2ID = "peer2"
accNetResourceRouter1ID = "router1"
accNetResource1ID = "resource1ID"
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
)
var (
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
)
func getBasicAccountsWithResource() *Account {
return &Account{
Id: accID,
Peers: map[string]*nbpeer.Peer{
accNetResourcePeer1ID: {
ID: accNetResourcePeer1ID,
AccountID: accID,
Key: "peer1Key",
IP: accNetResourcePeer1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
accNetResourcePeer2ID: {
ID: accNetResourcePeer2ID,
AccountID: accID,
Key: "peer2Key",
IP: accNetResourcePeer2IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "windows",
WtVersion: "0.34.1",
KernelVersion: "4.4.0",
},
},
accNetResourceRouter1ID: {
ID: accNetResourceRouter1ID,
AccountID: accID,
Key: "router1Key",
IP: accNetResourceRouter1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
},
Groups: map[string]*Group{
group1ID: {
ID: group1ID,
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
},
},
Networks: []*networkTypes.Network{
{
ID: network1ID,
AccountID: accID,
Name: "network1",
},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: accNetResourceRouter1ID,
NetworkID: network1ID,
AccountID: accID,
Peer: accNetResourceRouter1ID,
PeerGroups: []string{},
Masquerade: false,
Metric: 100,
},
},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: accNetResource1ID,
AccountID: accID,
NetworkID: network1ID,
Address: "10.10.10.0/24",
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
Type: resourceTypes.NetworkResourceType("subnet"),
},
},
Policies: []*Policy{
{
ID: "policy1ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule1ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"80"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: nil,
},
},
PostureChecks: []*posture.Checks{
{
ID: accNetResourceRestrictPostureCheckID,
Name: accNetResourceRestrictPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.35.0",
},
},
},
{
ID: accNetResourceRelaxedPostureCheckID,
Name: accNetResourceRelaxedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
{
ID: accNetResourceLockedPostureCheckID,
Name: accNetResourceLockedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "7.7.7",
},
},
},
{
ID: accNetResourceLinuxPostureCheckID,
Name: accNetResourceLinuxPostureCheckID,
Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "0.0.0"},
},
},
},
},
}
}
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// all peers should match the policy
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should not match any peer
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 0, "expected rules count don't match")
}
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// should allow peer1 and peer2 to match the policy
newPolicy := &Policy{
ID: "policy2ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "policy2ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"22"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
}
account.Policies = append(account.Policies, newPolicy)
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 2, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// two posture checks should match only the peers that match both checks
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}

View File

@@ -35,6 +35,15 @@ type FirewallRule struct {
Port string Port string
} }
// IsEqual checks if two firewall rules are equal.
func (r *FirewallRule) IsEqual(other *FirewallRule) bool {
return r.PeerIP == other.PeerIP &&
r.Direction == other.Direction &&
r.Action == other.Action &&
r.Protocol == other.Protocol &&
r.Port == other.Port
}
// generateRouteFirewallRules generates a list of firewall rules for a given route. // generateRouteFirewallRules generates a list of firewall rules for a given route.
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
rulesExists := make(map[string]struct{}) rulesExists := make(map[string]struct{})

View File

@@ -117,9 +117,20 @@ func (p *Policy) RuleGroups() []string {
// SourceGroups returns a slice of all unique source groups referenced in the policy's rules. // SourceGroups returns a slice of all unique source groups referenced in the policy's rules.
func (p *Policy) SourceGroups() []string { func (p *Policy) SourceGroups() []string {
groups := make([]string, 0) if len(p.Rules) == 1 {
for _, rule := range p.Rules { return p.Rules[0].Sources
groups = append(groups, rule.Sources...)
} }
return groups groups := make(map[string]struct{}, len(p.Rules))
for _, rule := range p.Rules {
for _, source := range rule.Sources {
groups[source] = struct{}{}
}
}
groupIDs := make([]string, 0, len(groups))
for groupID := range groups {
groupIDs = append(groupIDs, groupID)
}
return groupIDs
} }

View File

@@ -95,6 +95,7 @@ type Route struct {
NetID NetID NetID NetID
Description string Description string
Peer string Peer string
PeerID string `gorm:"-"`
PeerGroups []string `gorm:"serializer:json"` PeerGroups []string `gorm:"serializer:json"`
NetworkType NetworkType NetworkType NetworkType
Masquerade bool Masquerade bool
@@ -120,6 +121,7 @@ func (r *Route) Copy() *Route {
KeepRoute: r.KeepRoute, KeepRoute: r.KeepRoute,
NetworkType: r.NetworkType, NetworkType: r.NetworkType,
Peer: r.Peer, Peer: r.Peer,
PeerID: r.PeerID,
PeerGroups: slices.Clone(r.PeerGroups), PeerGroups: slices.Clone(r.PeerGroups),
Metric: r.Metric, Metric: r.Metric,
Masquerade: r.Masquerade, Masquerade: r.Masquerade,
@@ -146,6 +148,7 @@ func (r *Route) IsEqual(other *Route) bool {
other.KeepRoute == r.KeepRoute && other.KeepRoute == r.KeepRoute &&
other.NetworkType == r.NetworkType && other.NetworkType == r.NetworkType &&
other.Peer == r.Peer && other.Peer == r.Peer &&
other.PeerID == r.PeerID &&
other.Metric == r.Metric && other.Metric == r.Metric &&
other.Masquerade == r.Masquerade && other.Masquerade == r.Masquerade &&
other.Enabled == r.Enabled && other.Enabled == r.Enabled &&