Compare commits

...

16 Commits

Author SHA1 Message Date
Pascal Fischer
4d2c774378 refactor networm map generation 2025-03-13 14:29:59 +01:00
Pascal Fischer
ab2e3fec72 expose resource type consts 2025-03-12 13:49:24 +01:00
Hakan Sariman
47f88f7057 Refactor routeIDLookup methods to use Addr() for resolved IP operations 2025-03-11 19:43:58 +08:00
Hakan Sariman
ee33a6ed7c Refactor RemoveLocalPeerStateRoute to eliminate resourceId parameter 2025-03-11 13:19:30 +08:00
Hakan Sariman
da662cfd08 Add source and destination resource IDs to FlowFields 2025-03-11 13:12:54 +08:00
Hakan Sariman
ed2ee1ee9d Merge branch 'feature/flow' into feat/flow-resid 2025-03-11 13:08:11 +08:00
Hakan Sariman
92286b2541 Implement routeIDLookup for managing local and remote route IDs 2025-03-10 15:58:45 +08:00
Hakan Sariman
1ffe48f0d4 Add nil check in CheckRoutes to prevent potential panic 2025-03-08 12:54:33 +03:00
Hakan Sariman
a3b8a21385 Refactor CheckRoutes to return resource IDs for matching source and destination addresses 2025-03-08 12:26:53 +03:00
Hakan Sariman
86492b88c4 Refactor route handling to simplify route information and improve state management 2025-03-08 12:25:35 +03:00
Hakan Sariman
d08a629f9e Merge branch 'feature/flow' into feat/flow-resid 2025-03-08 12:18:02 +03:00
Hakan Sariman
268e3404d3 Merge branch 'feature/flow' into feat/flow-resid 2025-03-07 18:52:11 +03:00
Hakan Sariman
54d0591833 Refactor route handling to use RouteWithResourceId for improved state management 2025-03-07 18:43:49 +03:00
Hakan Sariman
de3b5c78d7 Fix nil pointer dereference in CheckRoutes method 2025-03-06 14:10:31 +03:00
Hakan Sariman
0b42f40cf6 Refactor route management to include resource IDs in state handling 2025-03-06 13:51:46 +03:00
Hakan Sariman
e7f921d787 [client] add resource id fields to netflow events 2025-03-05 20:35:52 +03:00
26 changed files with 463 additions and 192 deletions

View File

@@ -12,7 +12,7 @@ import (
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger() var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {

View File

@@ -24,7 +24,7 @@ import (
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger() var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error

View File

@@ -15,7 +15,7 @@ import (
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger() var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) { func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{ networkMap := &mgmProto.NetworkMap{

View File

@@ -31,7 +31,7 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
) )
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger() var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
type mocWGIface struct { type mocWGIface struct {
filter device.PacketFilter filter device.PacketFilter

View File

@@ -353,7 +353,7 @@ func (e *Engine) Start() error {
// start flow manager right after interface creation // start flow manager right after interface creation
publicKey := e.config.WgPrivateKey.PublicKey() publicKey := e.config.WgPrivateKey.PublicKey()
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:]) e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
if e.config.RosenpassEnabled { if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled") log.Infof("rosenpass is enabled")

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/store"
"github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
) )
type rcvChan chan *types.EventFields type rcvChan chan *types.EventFields
@@ -21,14 +22,16 @@ type Logger struct {
enabled atomic.Bool enabled atomic.Bool
rcvChan atomic.Pointer[rcvChan] rcvChan atomic.Pointer[rcvChan]
cancelReceiver context.CancelFunc cancelReceiver context.CancelFunc
statusRecorder *peer.Status
Store types.Store Store types.Store
} }
func New(ctx context.Context) *Logger { func New(ctx context.Context, statusRecorder *peer.Status) *Logger {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
return &Logger{ return &Logger{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
statusRecorder: statusRecorder,
Store: store.NewMemoryStore(), Store: store.NewMemoryStore(),
} }
} }
@@ -80,6 +83,9 @@ func (l *Logger) startReceiver() {
EventFields: *eventFields, EventFields: *eventFields,
Timestamp: time.Now(), Timestamp: time.Now(),
} }
srcResId, dstResId := l.statusRecorder.CheckRoutes(event.SourceIP, event.DestIP, event.Direction)
event.SourceResourceID = []byte(srcResId)
event.DestResourceID = []byte(dstResId)
l.Store.StoreEvent(&event) l.Store.StoreEvent(&event)
} }
} }

View File

@@ -12,7 +12,7 @@ import (
) )
func TestStore(t *testing.T) { func TestStore(t *testing.T) {
logger := logger.New(context.Background()) logger := logger.New(context.Background(), nil)
logger.Enable() logger.Enable()
event := types.EventFields{ event := types.EventFields{

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal/netflow/conntrack" "github.com/netbirdio/netbird/client/internal/netflow/conntrack"
"github.com/netbirdio/netbird/client/internal/netflow/logger" "github.com/netbirdio/netbird/client/internal/netflow/logger"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/flow/client" "github.com/netbirdio/netbird/flow/client"
"github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/flow/proto"
) )
@@ -31,8 +32,8 @@ type Manager struct {
} }
// NewManager creates a new netflow manager // NewManager creates a new netflow manager
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte) *Manager { func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
flowLogger := logger.New(ctx) flowLogger := logger.New(ctx, statusRecorder)
var ct nftypes.ConnTracker var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() { if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
@@ -208,6 +209,8 @@ func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
TxPackets: event.TxPackets, TxPackets: event.TxPackets,
RxBytes: event.RxBytes, RxBytes: event.RxBytes,
TxBytes: event.TxBytes, TxBytes: event.TxBytes,
SourceResourceId: event.SourceResourceID,
DestResourceId: event.DestResourceID,
}, },
} }

View File

@@ -77,6 +77,8 @@ type EventFields struct {
Protocol Protocol Protocol Protocol
SourceIP netip.Addr SourceIP netip.Addr
DestIP netip.Addr DestIP netip.Addr
SourceResourceID []byte
DestResourceID []byte
SourcePort uint16 SourcePort uint16
DestPort uint16 DestPort uint16
ICMPType uint8 ICMPType uint8

View File

@@ -0,0 +1,100 @@
package peer
import (
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
type routeIDLookup struct {
localMap sync.Map
remoteMap sync.Map
resolvedIPs sync.Map
}
func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) {
_, exists := r.localMap.LoadOrStore(route, resourceID)
if exists {
log.Tracef("resourceID %s already exists in local map", resourceID)
}
}
func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) {
r.localMap.Delete(route)
}
func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) {
_, exists := r.remoteMap.LoadOrStore(route, resourceID)
if exists {
log.Tracef("resourceID %s already exists in remote map", resourceID)
}
}
func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) {
r.remoteMap.Delete(route)
}
func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) {
r.resolvedIPs.Store(route.Addr(), resourceID)
}
func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) {
r.resolvedIPs.Delete(route.Addr())
}
func (r *routeIDLookup) Lookup(src, dst netip.Addr, direction nftypes.Direction) (srcResourceID, dstResourceID string) {
// check resolved ip's first
resId, ok := r.resolvedIPs.Load(src)
if ok {
srcResourceID = resId.(string)
} else {
resId, ok := r.resolvedIPs.Load(dst)
if ok {
dstResourceID = resId.(string)
}
}
switch direction {
case nftypes.Ingress:
if srcResourceID == "" || dstResourceID == "" {
r.localMap.Range(func(key, value interface{}) bool {
if srcResourceID == "" && key.(netip.Prefix).Contains(src) {
srcResourceID = value.(string)
} else if dstResourceID == "" && key.(netip.Prefix).Contains(dst) {
dstResourceID = value.(string)
}
if srcResourceID != "" && dstResourceID != "" {
return false
}
return true
})
}
case nftypes.Egress:
if srcResourceID == "" || dstResourceID == "" {
r.remoteMap.Range(func(key, value interface{}) bool {
if srcResourceID == "" && key.(netip.Prefix).Contains(src) {
srcResourceID = value.(string)
} else if dstResourceID == "" && key.(netip.Prefix).Contains(dst) {
dstResourceID = value.(string)
}
if srcResourceID != "" && dstResourceID != "" {
return false
}
return true
})
}
}
return srcResourceID, dstResourceID
}

View File

@@ -17,6 +17,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
@@ -176,6 +177,8 @@ type Status struct {
eventQueue *EventQueue eventQueue *EventQueue
ingressGwMgr *ingressgw.Manager ingressGwMgr *ingressgw.Manager
routeIDLookup routeIDLookup
} }
// NewRecorder returns a new Status instance // NewRecorder returns a new Status instance
@@ -311,7 +314,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil return nil
} }
func (d *Status) AddPeerStateRoute(peer string, route string) error { func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@@ -323,6 +326,14 @@ func (d *Status) AddPeerStateRoute(peer string, route string) error {
peerState.AddRoute(route) peerState.AddRoute(route)
d.peers[peer] = peerState d.peers[peer] = peerState
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
} else {
d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
}
// todo: consider to make sense of this notification or not // todo: consider to make sense of this notification or not
d.notifyPeerListChanged() d.notifyPeerListChanged()
return nil return nil
@@ -340,11 +351,28 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
peerState.DeleteRoute(route) peerState.DeleteRoute(route)
d.peers[peer] = peerState d.peers[peer] = peerState
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
} else {
d.routeIDLookup.RemoveRemoteRouteID(pref)
}
// todo: consider to make sense of this notification or not // todo: consider to make sense of this notification or not
d.notifyPeerListChanged() d.notifyPeerListChanged()
return nil return nil
} }
// CheckRoutes checks if the source and destination addresses are within the same route
// and returns the resource ID of the route that contains the addresses
func (d *Status) CheckRoutes(src, dst netip.Addr, direction nftypes.Direction) (srcResId string, dstResId string) {
if d == nil {
return
}
return d.routeIDLookup.Lookup(src, dst, direction)
}
func (d *Status) UpdatePeerICEState(receivedState State) error { func (d *Status) UpdatePeerICEState(receivedState State) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@@ -558,6 +586,50 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
d.notifyAddressChanged() d.notifyAddressChanged()
} }
// AddLocalPeerStateRoute adds a route to the local peer state
func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
d.mux.Lock()
defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
return
}
if d.localPeer.Routes == nil {
d.localPeer.Routes = map[string]struct{}{}
}
d.localPeer.Routes[route] = struct{}{}
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
}
// RemoveLocalPeerStateRoute removes a route from the local peer state
func (d *Status) RemoveLocalPeerStateRoute(route string) {
d.mux.Lock()
defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
return
}
delete(d.localPeer.Routes, route)
d.routeIDLookup.RemoveLocalRouteID(pref)
}
// CleanLocalPeerStateRoutes cleans all routes from the local peer state
func (d *Status) CleanLocalPeerStateRoutes() {
d.mux.Lock()
defer d.mux.Unlock()
d.localPeer.Routes = map[string]struct{}{}
}
// CleanLocalPeerState cleans local peer status // CleanLocalPeerState cleans local peer status
func (d *Status) CleanLocalPeerState() { func (d *Status) CleanLocalPeerState() {
d.mux.Lock() d.mux.Lock()
@@ -641,7 +713,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates d.nsGroupStates = dnsStates
} }
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) { func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId string) {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@@ -650,6 +722,10 @@ func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resol
Prefixes: prefixes, Prefixes: prefixes,
ParentDomain: originalDomain, ParentDomain: originalDomain,
} }
for _, prefix := range prefixes {
d.routeIDLookup.AddResolvedIP(resourceId, prefix)
}
} }
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) { func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
@@ -660,6 +736,10 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
for k, v := range d.resolvedDomainsStates { for k, v := range d.resolvedDomainsStates {
if v.ParentDomain == domain { if v.ParentDomain == domain {
delete(d.resolvedDomainsStates, k) delete(d.resolvedDomainsStates, k)
for _, prefix := range v.Prefixes {
d.routeIDLookup.RemoveResolvedIP(prefix)
}
} }
} }
} }

View File

@@ -330,7 +330,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error
c.connectEvent() c.connectEvent()
} }
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String()) err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String(), c.currentChosen.GetResourceID())
if err != nil { if err != nil {
return fmt.Errorf("add peer state route: %w", err) return fmt.Errorf("add peer state route: %w", err)
} }

View File

@@ -321,7 +321,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
if len(toAdd) > 0 || len(toRemove) > 0 { if len(toAdd) > 0 || len(toRemove) > 0 {
d.interceptedDomains[resolvedDomain] = newPrefixes d.interceptedDomains[resolvedDomain] = newPrefixes
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes) d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
if len(toAdd) > 0 { if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",

View File

@@ -288,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes) updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
r.dynamicDomains[domain] = updatedPrefixes r.dynamicDomains[domain] = updatedPrefixes
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes) r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes, r.route.GetResourceID())
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)

View File

@@ -103,9 +103,7 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
delete(m.routes, route.ID) delete(m.routes, route.ID)
state := m.statusRecorder.GetLocalPeerState() m.statusRecorder.RemoveLocalPeerStateRoute(route.Network.String())
delete(state.Routes, route.Network.String())
m.statusRecorder.UpdateLocalPeerState(state)
return nil return nil
} }
@@ -131,18 +129,12 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error {
m.routes[route.ID] = route m.routes[route.ID] = route
state := m.statusRecorder.GetLocalPeerState()
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
routeStr := route.Network.String() routeStr := route.Network.String()
if route.IsDynamic() { if route.IsDynamic() {
routeStr = route.Domains.SafeString() routeStr = route.Domains.SafeString()
} }
state.Routes[routeStr] = struct{}{}
m.statusRecorder.UpdateLocalPeerState(state) m.statusRecorder.AddLocalPeerStateRoute(routeStr, route.GetResourceID())
return nil return nil
} }
@@ -164,9 +156,7 @@ func (m *serverRouter) cleanUp() {
} }
state := m.statusRecorder.GetLocalPeerState() m.statusRecorder.CleanLocalPeerStateRoutes()
state.Routes = nil
m.statusRecorder.UpdateLocalPeerState(state)
} }
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {

View File

@@ -278,6 +278,9 @@ type FlowFields struct {
// Number of bytes // Number of bytes
RxBytes uint64 `protobuf:"varint,12,opt,name=rx_bytes,json=rxBytes,proto3" json:"rx_bytes,omitempty"` RxBytes uint64 `protobuf:"varint,12,opt,name=rx_bytes,json=rxBytes,proto3" json:"rx_bytes,omitempty"`
TxBytes uint64 `protobuf:"varint,13,opt,name=tx_bytes,json=txBytes,proto3" json:"tx_bytes,omitempty"` TxBytes uint64 `protobuf:"varint,13,opt,name=tx_bytes,json=txBytes,proto3" json:"tx_bytes,omitempty"`
// Resource ID
SourceResourceId []byte `protobuf:"bytes,14,opt,name=source_resource_id,json=sourceResourceId,proto3" json:"source_resource_id,omitempty"`
DestResourceId []byte `protobuf:"bytes,15,opt,name=dest_resource_id,json=destResourceId,proto3" json:"dest_resource_id,omitempty"`
} }
func (x *FlowFields) Reset() { func (x *FlowFields) Reset() {
@@ -410,6 +413,20 @@ func (x *FlowFields) GetTxBytes() uint64 {
return 0 return 0
} }
func (x *FlowFields) GetSourceResourceId() []byte {
if x != nil {
return x.SourceResourceId
}
return nil
}
func (x *FlowFields) GetDestResourceId() []byte {
if x != nil {
return x.DestResourceId
}
return nil
}
type isFlowFields_ConnectionInfo interface { type isFlowFields_ConnectionInfo interface {
isFlowFields_ConnectionInfo() isFlowFields_ConnectionInfo()
} }
@@ -560,7 +577,7 @@ var file_flow_proto_rawDesc = []byte{
0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x22, 0x29, 0x0a, 0x0c, 0x46, 0x6c, 0x6f, 0x77, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x22, 0x29, 0x0a, 0x0c, 0x46, 0x6c, 0x6f, 0x77,
0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e,
0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e,
0x74, 0x49, 0x64, 0x22, 0xc4, 0x03, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x74, 0x49, 0x64, 0x22, 0x9c, 0x04, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c,
0x64, 0x73, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x64, 0x73, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49, 0x64, 0x12, 0x1e, 0x0a, 0x04, 0x74, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49, 0x64, 0x12, 0x1e, 0x0a, 0x04, 0x74,
0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0a, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0a, 0x2e, 0x66, 0x6c, 0x6f, 0x77,
@@ -587,31 +604,36 @@ var file_flow_proto_rawDesc = []byte{
0x73, 0x12, 0x19, 0x0a, 0x08, 0x72, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x72, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0c, 0x20,
0x01, 0x28, 0x04, 0x52, 0x07, 0x72, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x01, 0x28, 0x04, 0x52, 0x07, 0x72, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x19, 0x0a, 0x08,
0x74, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, 0x74, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07,
0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x42, 0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x12, 0x73, 0x6f, 0x75, 0x72, 0x63,
0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50, 0x6f, 0x65, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0e, 0x20,
0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x01, 0x28, 0x0c, 0x52, 0x10, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x65, 0x73, 0x6f, 0x75,
0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x12, 0x28, 0x0a, 0x10, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x72, 0x65,
0x72, 0x63, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0c, 0x52,
0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x64, 0x65, 0x73, 0x74, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x42,
0x50, 0x6f, 0x72, 0x74, 0x22, 0x44, 0x0a, 0x08, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x6e,
0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1f,
0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20,
0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x12,
0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x2a, 0x45, 0x0a, 0x04, 0x54, 0x79, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01,
0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x28, 0x0d, 0x52, 0x08, 0x64, 0x65, 0x73, 0x74, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x44, 0x0a, 0x08,
0x57, 0x4e, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x41, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70,
0x52, 0x54, 0x10, 0x01, 0x12, 0x0c, 0x0a, 0x08, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e, 0x44, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d,
0x10, 0x02, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f,
0x03, 0x2a, 0x3b, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x15, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f,
0x0a, 0x11, 0x44, 0x49, 0x52, 0x45, 0x43, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x64, 0x65, 0x2a, 0x45, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59,
0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x49, 0x4e, 0x47, 0x52, 0x45, 0x53, 0x53, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a,
0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x02, 0x32, 0x42, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x41, 0x52, 0x54, 0x10, 0x01, 0x12, 0x0c, 0x0a, 0x08,
0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x77, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e, 0x44, 0x10, 0x02, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59,
0x06, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x50, 0x45, 0x5f, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x03, 0x2a, 0x3b, 0x0a, 0x09, 0x44, 0x69, 0x72,
0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x1a, 0x12, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x15, 0x0a, 0x11, 0x44, 0x49, 0x52, 0x45, 0x43, 0x54,
0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28, 0x01, 0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x0b, 0x0a,
0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x07, 0x49, 0x4e, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x45, 0x47,
0x6f, 0x74, 0x6f, 0x33, 0x52, 0x45, 0x53, 0x53, 0x10, 0x02, 0x32, 0x42, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x77, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x06, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x12,
0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74,
0x1a, 0x12, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
0x74, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (

View File

@@ -67,6 +67,11 @@ message FlowFields {
// Number of bytes // Number of bytes
uint64 rx_bytes = 12; uint64 rx_bytes = 12;
uint64 tx_bytes = 13; uint64 tx_bytes = 13;
// Resource ID
bytes source_resource_id = 14;
bytes dest_resource_id = 15;
} }
// Flow event types // Flow event types

View File

@@ -385,7 +385,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
} }
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), nil)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
} }

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
@@ -635,14 +636,13 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
response.NetworkMap.PeerConfig = response.PeerConfig response.NetworkMap.PeerConfig = response.PeerConfig
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) allPeers := appendRemotePeerConfig(networkMap.Peers, dnsName)
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
response.RemotePeers = allPeers response.RemotePeers = allPeers
response.NetworkMap.RemotePeers = allPeers response.NetworkMap.RemotePeers = allPeers
response.RemotePeersIsEmpty = len(allPeers) == 0 response.RemotePeersIsEmpty = len(allPeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) response.NetworkMap.OfflinePeers = appendRemotePeerConfig(networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
response.NetworkMap.FirewallRules = firewallRules response.NetworkMap.FirewallRules = firewallRules
@@ -663,15 +663,18 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
return response return response
} }
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { func appendRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers { dst := make([]*proto.RemotePeerConfig, len(peers))
dst = append(dst, &proto.RemotePeerConfig{
for i, rPeer := range peers {
dst[i] = &proto.RemotePeerConfig{
WgPubKey: rPeer.Key, WgPubKey: rPeer.Key,
AllowedIps: []string{rPeer.IP.String() + "/32"}, AllowedIps: []string{rPeer.IP.String() + "/32"},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName), Fqdn: rPeer.FQDN(dnsName),
})
} }
}
return dst return dst
} }

View File

@@ -281,7 +281,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
} }

View File

@@ -20,9 +20,9 @@ import (
type NetworkResourceType string type NetworkResourceType string
const ( const (
host NetworkResourceType = "host" Host NetworkResourceType = "host"
subnet NetworkResourceType = "subnet" Subnet NetworkResourceType = "subnet"
domain NetworkResourceType = "domain" Domain NetworkResourceType = "domain"
) )
func (p NetworkResourceType) String() string { func (p NetworkResourceType) String() string {
@@ -66,7 +66,7 @@ func NewNetworkResource(accountID, networkID, name, description, address string,
func (n *NetworkResource) ToAPIResponse(groups []api.GroupMinimum) *api.NetworkResource { func (n *NetworkResource) ToAPIResponse(groups []api.GroupMinimum) *api.NetworkResource {
addr := n.Prefix.String() addr := n.Prefix.String()
if n.Type == domain { if n.Type == Domain {
addr = n.Domain addr = n.Domain
} }
@@ -125,7 +125,7 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network
AccessControlGroups: nil, AccessControlGroups: nil,
} }
if n.Type == host || n.Type == subnet { if n.Type == Host || n.Type == Subnet {
r.Network = n.Prefix r.Network = n.Prefix
r.NetworkType = route.IPv4Network r.NetworkType = route.IPv4Network
@@ -134,7 +134,7 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network
} }
} }
if n.Type == domain { if n.Type == Domain {
domainList, err := nbDomain.FromStringList([]string{n.Domain}) domainList, err := nbDomain.FromStringList([]string{n.Domain})
if err != nil { if err != nil {
return nil return nil
@@ -157,18 +157,18 @@ func (n *NetworkResource) EventMeta(network *networkTypes.Network) map[string]an
func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, error) { func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, error) {
if prefix, err := netip.ParsePrefix(address); err == nil { if prefix, err := netip.ParsePrefix(address); err == nil {
if prefix.Bits() == 32 || prefix.Bits() == 128 { if prefix.Bits() == 32 || prefix.Bits() == 128 {
return host, "", prefix, nil return Host, "", prefix, nil
} }
return subnet, "", prefix, nil return Subnet, "", prefix, nil
} }
if ip, err := netip.ParseAddr(address); err == nil { if ip, err := netip.ParseAddr(address); err == nil {
return host, "", netip.PrefixFrom(ip, ip.BitLen()), nil return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil
} }
domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
if domainRegex.MatchString(address) { if domainRegex.MatchString(address) {
return domain, address, netip.Prefix{}, nil return Domain, address, netip.Prefix{}, nil
} }
return "", "", netip.Prefix{}, errors.New("not a valid host, subnet, or domain") return "", "", netip.Prefix{}, errors.New("not a valid host, subnet, or domain")

View File

@@ -14,15 +14,15 @@ func TestGetResourceType(t *testing.T) {
expectedPrefix netip.Prefix expectedPrefix netip.Prefix
}{ }{
// Valid host IPs // Valid host IPs
{"1.1.1.1", host, false, "", netip.MustParsePrefix("1.1.1.1/32")}, {"1.1.1.1", Host, false, "", netip.MustParsePrefix("1.1.1.1/32")},
{"1.1.1.1/32", host, false, "", netip.MustParsePrefix("1.1.1.1/32")}, {"1.1.1.1/32", Host, false, "", netip.MustParsePrefix("1.1.1.1/32")},
// Valid subnets // Valid subnets
{"192.168.1.0/24", subnet, false, "", netip.MustParsePrefix("192.168.1.0/24")}, {"192.168.1.0/24", Subnet, false, "", netip.MustParsePrefix("192.168.1.0/24")},
{"10.0.0.0/16", subnet, false, "", netip.MustParsePrefix("10.0.0.0/16")}, {"10.0.0.0/16", Subnet, false, "", netip.MustParsePrefix("10.0.0.0/16")},
// Valid domains // Valid domains
{"example.com", domain, false, "example.com", netip.Prefix{}}, {"example.com", Domain, false, "example.com", netip.Prefix{}},
{"*.example.com", domain, false, "*.example.com", netip.Prefix{}}, {"*.example.com", Domain, false, "*.example.com", netip.Prefix{}},
{"sub.example.com", domain, false, "sub.example.com", netip.Prefix{}}, {"sub.example.com", Domain, false, "sub.example.com", netip.Prefix{}},
// Invalid inputs // Invalid inputs
{"invalid", "", true, "", netip.Prefix{}}, {"invalid", "", true, "", netip.Prefix{}},
{"1.1.1.1/abc", "", true, "", netip.Prefix{}}, {"1.1.1.1/abc", "", true, "", netip.Prefix{}},

View File

@@ -83,7 +83,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap) aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
for _, p := range aclPeers { for _, p := range aclPeers {
peersMap[p.ID] = p peersMap[p.ID] = p
} }
@@ -418,7 +418,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
return nil, err return nil, err
} }
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), nil)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {
@@ -1029,7 +1029,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return nil, nil, nil, err return nil, nil, nil, err
} }
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), am.metrics.AccountManagerMetrics())
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {
@@ -1140,7 +1140,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
} }
for _, p := range userPeers { for _, p := range userPeers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap) aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
for _, aclPeer := range aclPeers { for _, aclPeer := range aclPeers {
if aclPeer.ID == peerID { if aclPeer.ID == peerID {
return peer, nil return peer, nil
@@ -1175,6 +1175,8 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
peersGroups := account.GetPeersGroupsMap()
groupsPolicies := account.GetGroupsPolicyMap()
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountID) proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountID)
if err != nil { if err != nil {
@@ -1200,7 +1202,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return return
} }
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, peersGroups, groupsPolicies, am.metrics.AccountManagerMetrics())
proxyNetworkMap, ok := proxyNetworkMaps[p.ID] proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok { if ok {
@@ -1269,7 +1271,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return return
} }
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), am.metrics.AccountManagerMetrics())
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {

View File

@@ -934,13 +934,13 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
minMsPerOpCICD float64 minMsPerOpCICD float64
maxMsPerOpCICD float64 maxMsPerOpCICD float64
}{ }{
{"Small", 50, 5, 90, 120, 90, 120}, // {"Small", 50, 5, 90, 120, 90, 120},
{"Medium", 500, 100, 110, 150, 120, 260}, // {"Medium", 500, 100, 110, 150, 120, 260},
{"Large", 5000, 200, 800, 1700, 2500, 5000}, // {"Large", 5000, 200, 800, 1700, 2500, 5000},
{"Small single", 50, 10, 90, 120, 90, 120}, // {"Small single", 50, 10, 90, 120, 90, 120},
{"Medium single", 500, 10, 110, 170, 120, 200}, // {"Medium single", 500, 10, 110, 170, 120, 200},
{"Large 5", 5000, 15, 1300, 2100, 4900, 7000}, {"Large 5", 5000, 15, 1300, 2100, 4900, 7000},
{"Extra Large", 2000, 2000, 1300, 2400, 3000, 6400}, // {"Extra Large", 5000, 2000, 1300, 2400, 3000, 6400},
} }
log.SetOutput(io.Discard) log.SetOutput(io.Discard)
@@ -948,6 +948,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
for _, bc := range benchCases { for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) { b.Run(bc.name, func(b *testing.B) {
b.Setenv("NB_GET_ACCOUNT_BUFFER_INTERVAL", "0")
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil { if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err) b.Fatalf("Failed to setup test account manager: %v", err)

View File

@@ -158,14 +158,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers { for _, p := range account.Peers {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
} }
}) })
t.Run("check first peer map details", func(t *testing.T) { t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 7) assert.Len(t, peers, 7)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
@@ -394,7 +394,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
} }
t.Run("check first peer map", func(t *testing.T) { t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{ epectedFirewallRules := []*types.FirewallRule{
@@ -422,7 +422,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}) })
t.Run("check second peer map", func(t *testing.T) { t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{ epectedFirewallRules := []*types.FirewallRule{
@@ -452,7 +452,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) { t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{ epectedFirewallRules := []*types.FirewallRule{
@@ -473,7 +473,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}) })
t.Run("check second peer map directional only", func(t *testing.T) { t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{ epectedFirewallRules := []*types.FirewallRule{
@@ -670,7 +670,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check. // will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -680,7 +680,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1) assert.Len(t, firewallRules, 1)
expectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
@@ -696,7 +696,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -706,7 +706,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -721,19 +721,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@@ -748,14 +748,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 3) assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3) assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 5) assert.Len(t, peers, 5)
// assert peers from Group Swarm // assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])

View File

@@ -225,6 +225,8 @@ func (a *Account) GetPeerNetworkMap(
validatedPeersMap map[string]struct{}, validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy, resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter, routers map[string]map[string]*routerTypes.NetworkRouter,
peersGroups map[string][]string,
groupsPolicies map[string]map[string]*Policy,
metrics *telemetry.AccountManagerMetrics, metrics *telemetry.AccountManagerMetrics,
) *NetworkMap { ) *NetworkMap {
start := time.Now() start := time.Now()
@@ -242,7 +244,7 @@ func (a *Account) GetPeerNetworkMap(
} }
} }
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap) aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap, peersGroups, groupsPolicies)
// exclude expired peers // exclude expired peers
var peersToConnect []*nbpeer.Peer var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer var expiredPeers []*nbpeer.Peer
@@ -945,9 +947,25 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer // GetPeerConnectionResources for a given peer
// //
// This function returns the list of peers and firewall rules that are applicable to a given peer. // This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { func (a *Account) GetPeerConnectionResources(
ctx context.Context,
peerID string,
validatedPeersMap map[string]struct{},
peersGroups map[string][]string,
groupsPolicies map[string]map[string]*Policy,
) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies { groups, ok := peersGroups[peerID]
if !ok {
return nil, nil
}
for _, group := range groups {
policiesPerGroup, ok := groupsPolicies[group]
if !ok {
continue
}
for _, policy := range policiesPerGroup {
if !policy.Enabled { if !policy.Enabled {
continue continue
} }
@@ -978,6 +996,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
} }
} }
} }
}
return getAccumulatedResources() return getAccumulatedResources()
} }
@@ -987,7 +1006,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. // The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be // It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls. // generated. The accumulator function returns the result of all the generator calls.
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, map[string]*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{}) rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{}) peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0) rules := make([]*FirewallRule, 0)
@@ -999,7 +1018,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
all = &Group{} all = &Group{}
} }
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { return func(rule *PolicyRule, groupPeers map[string]*nbpeer.Peer, direction int) {
isAll := (len(all.Peers) - 1) == len(groupPeers) isAll := (len(all.Peers) - 1) == len(groupPeers)
for _, peer := range groupPeers { for _, peer := range groupPeers {
if peer == nil { if peer == nil {
@@ -1052,16 +1071,21 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// //
// Important: Posture checks are applicable only to source group peers, // Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs // for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) (map[string]*nbpeer.Peer, bool) {
peerInGroups := false peerInGroups := false
uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups) filteredPeers := make(map[string]*nbpeer.Peer)
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs)) for _, groupID := range groups {
for _, p := range uniquePeerIDs { group := a.GetGroup(groupID)
for _, p := range group.Peers {
peer, ok := a.Peers[p] peer, ok := a.Peers[p]
if !ok || peer == nil { if !ok || peer == nil {
continue continue
} }
if _, ok := filteredPeers[p]; ok {
continue
}
// validate the peer based on policy posture checks applied // validate the peer based on policy posture checks applied
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid { if !isValid {
@@ -1077,7 +1101,8 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe
continue continue
} }
filteredPeers = append(filteredPeers, peer) filteredPeers[p] = peer
}
} }
return filteredPeers, peerInGroups return filteredPeers, peerInGroups
@@ -1318,7 +1343,7 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) { func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool var isRoutingPeer bool
var routes []*route.Route var routes []*route.Route
allSourcePeers := make(map[string]struct{}, len(a.Peers)) allSourcePeers := make(map[string]struct{})
for _, resource := range a.NetworkResources { for _, resource := range a.NetworkResources {
if !resource.Enabled { if !resource.Enabled {
@@ -1342,7 +1367,7 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
allSourcePeers[pID] = struct{}{} allSourcePeers[pID] = struct{}{}
} }
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { } else if _, ok := peers[peerID]; ok && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
// add routes for the resource if the peer is in the distribution group // add routes for the resource if the peer is in the distribution group
for peerId, router := range networkRoutingPeers { for peerId, router := range networkRoutingPeers {
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
@@ -1358,9 +1383,9 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
return isRoutingPeer, routes, allSourcePeers return isRoutingPeer, routes, allSourcePeers
} }
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { func (a *Account) getPostureValidPeers(inputPeers map[string]struct{}, postureChecksIDs []string) []string {
var dest []string var dest []string
for _, peerID := range inputPeers { for peerID := range inputPeers {
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) { if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
dest = append(dest, peerID) dest = append(dest, peerID)
} }
@@ -1368,7 +1393,7 @@ func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []s
return dest return dest
} }
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) map[string]struct{} {
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
for _, groupID := range groups { for _, groupID := range groups {
group := a.GetGroup(groupID) group := a.GetGroup(groupID)
@@ -1377,21 +1402,21 @@ func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []st
continue continue
} }
if group.IsGroupAll() || len(groups) == 1 { // if group.IsGroupAll() || len(groups) == 1 {
return group.Peers // return group.Peers
} // }
for _, peerID := range group.Peers { for _, peerID := range group.Peers {
peerIDs[peerID] = struct{}{} peerIDs[peerID] = struct{}{}
} }
} }
ids := make([]string, 0, len(peerIDs)) // ids := make([]string, 0, len(peerIDs))
for peerID := range peerIDs { // for peerID := range peerIDs {
ids = append(ids, peerID) // ids = append(ids, peerID)
} // }
return ids return peerIDs
} }
// getNetworkResources filters and returns a list of network resources associated with the given network ID. // getNetworkResources filters and returns a list of network resources associated with the given network ID.
@@ -1566,3 +1591,35 @@ func (a *Account) AddAllGroup() error {
} }
return nil return nil
} }
func (a *Account) GetPeersGroupsMap() map[string][]string {
groups := make(map[string][]string, len(a.Groups))
for _, group := range a.Groups {
for _, peerID := range group.Peers {
groups[peerID] = append(groups[peerID], group.ID)
}
}
return groups
}
func (a *Account) GetGroupsPolicyMap() map[string]map[string]*Policy {
policies := make(map[string]map[string]*Policy, len(a.Groups))
for _, policy := range a.Policies {
for _, rules := range policy.Rules {
for _, src := range rules.Sources {
if policies[src] == nil {
policies[src] = make(map[string]*Policy)
}
policies[src][policy.ID] = policy
}
for _, dest := range rules.Destinations {
if policies[dest] == nil {
policies[dest] = make(map[string]*Policy)
}
policies[dest][policy.ID] = policy
}
}
}
return policies
}