Compare commits

..

3 Commits

Author SHA1 Message Date
jnfrati
520370a8b0 Merge branch 'main' of github.com:netbirdio/netbird into feat/admin-cli 2026-06-22 17:17:47 +02:00
jnfrati
b5a16a1898 chore: move token commands under admin CLI 2026-06-04 12:49:48 +02:00
jnfrati
449b5cbb80 feat: add self-hosted admin CLI 2026-06-04 11:41:57 +02:00
35 changed files with 1103 additions and 1367 deletions

View File

@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
ProfileName: string(activeProf.ID),
ProfileName: activeProf.Name,
Username: currUser.Username,
})
if err != nil {

View File

@@ -27,7 +27,7 @@ type Logger struct {
wgIfaceNetV6 netip.Prefix
dnsCollection atomic.Bool
exitNodeCollection atomic.Bool
Store types.AggregatingStore
Store types.Store
}
func New(statusRecorder *peer.Status, wgIfaceIPNet, wgIfaceIPNetV6 netip.Prefix) *Logger {
@@ -35,7 +35,7 @@ func New(statusRecorder *peer.Status, wgIfaceIPNet, wgIfaceIPNetV6 netip.Prefix)
statusRecorder: statusRecorder,
wgIfaceNet: wgIfaceIPNet,
wgIfaceNetV6: wgIfaceIPNetV6,
Store: store.NewAggregatingMemoryStore(),
Store: store.NewMemoryStore(),
}
}
@@ -125,10 +125,6 @@ func (l *Logger) stop() {
l.mux.Unlock()
}
func (l *Logger) ResetAggregationWindow() types.FlowEventAggregator {
return l.Store.ResetAggregationWindow()
}
func (l *Logger) GetEvents() []*types.Event {
return l.Store.GetEvents()
}

View File

@@ -9,14 +9,12 @@ import (
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/netflow/conntrack"
"github.com/netbirdio/netbird/client/internal/netflow/logger"
"github.com/netbirdio/netbird/client/internal/netflow/store"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/flow/client"
@@ -25,16 +23,14 @@ import (
// Manager handles netflow tracking and logging
type Manager struct {
mux sync.Mutex
shutdownWg sync.WaitGroup
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
receiverClient *client.GRPCClient
eventsWithoutAcks nftypes.Store
publicKey []byte
cancel context.CancelFunc
retryInterval time.Duration
mux sync.Mutex
shutdownWg sync.WaitGroup
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
receiverClient *client.GRPCClient
publicKey []byte
cancel context.CancelFunc
}
// NewManager creates a new netflow manager
@@ -52,11 +48,9 @@ func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *pee
}
return &Manager{
logger: flowLogger,
conntrack: ct,
publicKey: publicKey,
retryInterval: time.Second,
eventsWithoutAcks: store.NewMemoryStore(),
logger: flowLogger,
conntrack: ct,
publicKey: publicKey,
}
}
@@ -72,7 +66,6 @@ func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
}
// enableFlow starts components for flow tracking
// must be called under m.mux lock
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
// first make sender ready so events don't pile up
if m.needsNewClient(previous) {
@@ -92,7 +85,6 @@ func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
return nil
}
// must be called under m.mux lock
func (m *Manager) resetClient() error {
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
@@ -115,19 +107,14 @@ func (m *Manager) resetClient() error {
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
m.shutdownWg.Add(3)
flowConfigInterval := m.flowConfig.Interval
m.shutdownWg.Add(2)
go func() {
defer m.shutdownWg.Done()
m.receiveACKs(ctx, flowClient, flowConfigInterval)
m.receiveACKs(ctx, flowClient)
}()
go func() {
defer m.shutdownWg.Done()
m.startSender(ctx, flowConfigInterval)
}()
go func() {
defer m.shutdownWg.Done()
m.startRetries(ctx, flowConfigInterval)
m.startSender(ctx)
}()
return nil
@@ -211,8 +198,8 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
return m.logger
}
func (m *Manager) startSender(ctx context.Context, flowConfigInterval time.Duration) {
ticker := time.NewTicker(flowConfigInterval)
func (m *Manager) startSender(ctx context.Context) {
ticker := time.NewTicker(m.flowConfig.Interval)
defer ticker.Stop()
for {
@@ -220,29 +207,27 @@ func (m *Manager) startSender(ctx context.Context, flowConfigInterval time.Durat
case <-ctx.Done():
return
case <-ticker.C:
collectedEvents := m.logger.ResetAggregationWindow()
events := collectedEvents.GetAggregatedEvents()
events := m.logger.GetEvents()
for _, event := range events {
if err := m.send(event); err != nil {
log.Errorf("failed to send flow event to server: %v", err)
} else {
log.Tracef("sent flow event: %s", event.ID)
continue
}
m.eventsWithoutAcks.StoreEvent(event)
log.Tracef("sent flow event: %s", event.ID)
}
}
}
}
func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient, flowConfigInterval time.Duration) {
err := client.Receive(ctx, flowConfigInterval, func(ack *proto.FlowEventAck) error {
func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
err := client.Receive(ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
id, err := uuid.FromBytes(ack.EventId)
if err != nil {
log.Warnf("failed to convert ack event id to uuid: %v", err)
return nil
}
log.Tracef("received flow event ack: %s", id)
m.eventsWithoutAcks.DeleteEvents([]uuid.UUID{id})
m.logger.DeleteEvents([]uuid.UUID{id})
return nil
})
@@ -251,43 +236,6 @@ func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient, fl
}
}
// We effectively never drop events (see MaxInterval), which makes eventsWithoutAcks unbounded.
// We may want to limit the max size of the store, and start dropping oldest events when the threshold is reached.
func (m *Manager) startRetries(ctx context.Context, flowConfigInterval time.Duration) {
timer := time.NewTimer(m.retryInterval)
retryBackoff := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 1 * time.Second,
RandomizationFactor: 0.5,
Multiplier: 1.7,
MaxInterval: flowConfigInterval / 2,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
defer timer.Stop()
for {
select {
case <-ctx.Done():
return
case <-timer.C:
for _, e := range m.eventsWithoutAcks.GetEvents() {
if e.Timestamp.Add(time.Second).After(time.Now()) {
// grace period on retries to avoid early retries
// do not retry if the event is less than 1 sec old
continue
}
if err := m.send(e); err != nil {
timer = time.NewTimer(retryBackoff.NextBackOff()) //nolint:staticcheck,wastedassign
break
}
}
retryBackoff.Reset()
timer = time.NewTimer(m.retryInterval)
}
}
}
func (m *Manager) send(event *nftypes.Event) error {
m.mux.Lock()
client := m.receiverClient
@@ -302,11 +250,9 @@ func (m *Manager) send(event *nftypes.Event) error {
func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
protoEvent := &proto.FlowEvent{
EventId: event.ID[:],
Timestamp: timestamppb.New(event.Timestamp),
PublicKey: publicKey,
WindowStart: timestamppb.New(event.WindowStart),
WindowEnd: timestamppb.New(event.WindowEnd),
EventId: event.ID[:],
Timestamp: timestamppb.New(event.Timestamp),
PublicKey: publicKey,
FlowFields: &proto.FlowFields{
FlowId: event.FlowID[:],
RuleId: event.RuleID,
@@ -321,9 +267,6 @@ func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
TxBytes: event.TxBytes,
SourceResourceId: event.SourceResourceID,
DestResourceId: event.DestResourceID,
NumOfStarts: event.NumOfStarts,
NumOfEnds: event.NumOfEnds,
NumOfDrops: event.NumOfDrops,
},
}

View File

@@ -1,291 +0,0 @@
package netflow
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"testing"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/flow/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
type testServer struct {
proto.UnimplementedFlowServiceServer
events chan *proto.FlowEvent
acks chan *proto.FlowEventAck
grpcSrv *grpc.Server
addr string
handlerDone chan struct{} // signaled each time Events() exits
handlerStarted chan struct{} // signaled each time Events() begins
}
func newTestServer(t *testing.T) *testServer {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
s := &testServer{
events: make(chan *proto.FlowEvent, 100),
acks: make(chan *proto.FlowEventAck, 100),
grpcSrv: grpc.NewServer(),
addr: listener.Addr().String(),
handlerDone: make(chan struct{}, 10),
handlerStarted: make(chan struct{}, 10),
}
proto.RegisterFlowServiceServer(s.grpcSrv, s)
go func() {
if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
t.Logf("server error: %v", err)
}
}()
t.Cleanup(func() {
s.grpcSrv.Stop()
})
return s
}
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
defer func() {
select {
case s.handlerDone <- struct{}{}:
default:
}
}()
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
if err != nil {
return err
}
select {
case s.handlerStarted <- struct{}{}:
default:
}
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()
go func() {
defer cancel()
for {
event, err := stream.Recv()
if err != nil {
return
}
if !event.IsInitiator {
select {
case s.events <- event:
case <-ctx.Done():
return
}
}
}
}()
for {
select {
case ack := <-s.acks:
if err := stream.Send(ack); err != nil {
return err
}
case <-ctx.Done():
return ctx.Err()
}
}
}
func TestSendEventReceiveAck(t *testing.T) {
_, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
server := newTestServer(t)
manager := createManager(t, server.addr, 60*time.Second) // set high to prevent retries in this test
defer manager.Close()
assert.Eventually(t, func() bool {
select {
case <-server.handlerStarted:
return true
default:
return false
}
}, 3*time.Second, 100*time.Millisecond)
event1 := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
DestIP: ipAddr("172.16.1.2"),
DestPort: 2345,
Protocol: 6,
}
manager.logger.StoreEvent(event1)
event2 := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
DestIP: ipAddr("172.16.1.1"),
DestPort: 1234,
Protocol: 6,
}
manager.logger.StoreEvent(event2)
// verify the server received logged events
serverSideEvents := make([]*proto.FlowEvent, 0)
assert.Eventually(t, func() bool {
select {
case event := <-server.events:
serverSideEvents = append(serverSideEvents, event)
if len(serverSideEvents) == 2 {
return true
}
default:
if len(serverSideEvents) == 2 {
return true
}
}
return false
}, 5*time.Second, 100*time.Millisecond)
serverSideFlowIds := make([]uuid.UUID, 0, 2)
slices.Values(serverSideEvents)(func(e *proto.FlowEvent) bool {
id, err := uuid.FromBytes(e.FlowFields.FlowId)
assert.NoError(t, err)
serverSideFlowIds = append(serverSideFlowIds, id)
return true
})
assert.ElementsMatch(t, []uuid.UUID{event1.FlowID, event2.FlowID}, serverSideFlowIds)
// verify the manager tracks un-acked events
unackedEvents := manager.eventsWithoutAcks.GetEvents()
assert.Len(t, unackedEvents, 2)
flowIds := make([]uuid.UUID, 0)
slices.Values(unackedEvents)(func(e *types.Event) bool {
flowIds = append(flowIds, e.FlowID)
return true
})
assert.ElementsMatch(t, flowIds, []uuid.UUID{event1.FlowID, event2.FlowID})
}
// verify handling of retries:
// - unacked events are retried
// - when acks arrive, events are removed from the un-acked event tracker
func TestRetryEvents(t *testing.T) {
_, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
server := newTestServer(t)
manager := createManager(t, server.addr, time.Second) // set low to start retries sooner
defer manager.Close()
assert.Eventually(t, func() bool {
select {
case <-server.handlerStarted:
return true
default:
return false
}
}, 3*time.Second, 100*time.Millisecond)
event1 := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
DestIP: ipAddr("172.16.1.2"),
DestPort: 2345,
Protocol: 6,
}
manager.logger.StoreEvent(event1)
event2 := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
DestIP: ipAddr("172.16.1.1"),
DestPort: 1234,
Protocol: 6,
}
manager.logger.StoreEvent(event2)
// verify the server received retries of logged events
serverSideEvents := make([]*proto.FlowEvent, 0)
func() {
c := time.After(2500 * time.Millisecond)
for {
select {
case event := <-server.events:
serverSideEvents = append(serverSideEvents, event)
case <-c:
return
}
}
}()
assert.True(t, len(serverSideEvents) > 2) // must see retries
uniqueServerSideEvents := make(map[uuid.UUID]*proto.FlowEvent)
slices.Values(serverSideEvents)(func(e *proto.FlowEvent) bool {
id, err := uuid.FromBytes(e.FlowFields.FlowId)
assert.NoError(t, err)
uniqueServerSideEvents[id] = e
return true
})
assert.Contains(t, uniqueServerSideEvents, event1.FlowID)
assert.Contains(t, uniqueServerSideEvents, event2.FlowID)
// ack events
server.acks <- &proto.FlowEventAck{EventId: uniqueServerSideEvents[event1.FlowID].EventId}
server.acks <- &proto.FlowEventAck{EventId: uniqueServerSideEvents[event2.FlowID].EventId}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
unackedEvents := manager.eventsWithoutAcks.GetEvents()
assert.Empty(c, unackedEvents)
}, 3*time.Second, 100*time.Millisecond)
}
func createManager(t *testing.T, serverAddr string, retryInterval time.Duration) *Manager {
t.Helper()
mockIFace := &mockIFaceMapper{
address: wgaddr.Address{
Network: netip.MustParsePrefix("192.168.1.1/32"),
},
isUserspaceBind: true,
}
publicKey := []byte("test-public-key")
manager := NewManager(mockIFace, publicKey, nil)
manager.retryInterval = retryInterval
initialConfig := &types.FlowConfig{
Enabled: true,
URL: fmt.Sprintf("http://%s", serverAddr),
TokenPayload: "initial-payload",
TokenSignature: "initial-signature",
Interval: 500 * time.Millisecond,
}
err := manager.Update(initialConfig)
require.NoError(t, err)
return manager
}
func ipAddr(a string) netip.Addr {
addr, _ := netip.ParseAddr(a)
return addr
}

View File

@@ -1,334 +0,0 @@
package store
import (
"math/rand"
"net/netip"
"testing"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/stretchr/testify/assert"
)
var random = rand.New(rand.NewSource(time.Now().UnixNano()))
func TestFlowAggregation(t *testing.T) {
var protocols = []types.Protocol{types.ICMP, types.ICMPv6, types.TCP, types.UDP}
var tests = []struct {
description string
addresses [][]netip.Addr
dstPort uint16
eventTypes []types.Type
}{
{
description: "start and stop",
addresses: [][]netip.Addr{{ipAddr("1.1.1.1"), ipAddr("2.2.2.2")}, {ipAddr("3.3.3.3"), ipAddr("2.2.2.2")}},
dstPort: uint16(random.Uint32() >> 16),
eventTypes: []types.Type{types.TypeStart, types.TypeEnd},
},
{
description: "start and drop",
addresses: [][]netip.Addr{{ipAddr("1.1.1.1"), ipAddr("2.2.2.2")}, {ipAddr("3.3.3.3"), ipAddr("2.2.2.2")}},
dstPort: uint16(random.Uint32() >> 16),
eventTypes: []types.Type{types.TypeStart, types.TypeDrop},
},
{
description: "start only",
addresses: [][]netip.Addr{{ipAddr("1.1.1.1"), ipAddr("2.2.2.2")}, {ipAddr("3.3.3.3"), ipAddr("2.2.2.2")}},
dstPort: uint16(random.Uint32() >> 16),
eventTypes: []types.Type{types.TypeStart},
},
{
description: "drop only",
addresses: [][]netip.Addr{{ipAddr("1.1.1.1"), ipAddr("2.2.2.2")}, {ipAddr("3.3.3.3"), ipAddr("2.2.2.2")}},
dstPort: uint16(random.Uint32() >> 16),
eventTypes: []types.Type{types.TypeDrop},
}}
for _, protocol := range protocols {
for _, tt := range tests {
t.Run(tt.description+" "+protocol.String(), func(t *testing.T) {
store := NewAggregatingMemoryStore()
store.WindowEnd = time.Now().Add(5 * time.Second)
allExpected := make([]*types.Event, 0)
for _, srcAndDst := range tt.addresses {
inEvents, expected := generateEvents(srcAndDst[0], srcAndDst[1], tt.dstPort, tt.eventTypes, protocol, types.Ingress, 0, store.WindowStart, store.WindowEnd)
for _, e := range inEvents {
store.StoreEvent(e)
}
allExpected = append(allExpected, expected)
}
events := store.GetAggregatedEvents()
assert.ElementsMatch(t, events, allExpected)
})
}
}
}
func TestIcmpEventAggregation(t *testing.T) {
var protocols = []types.Protocol{types.ICMP, types.ICMPv6}
var icmpTypes = []uint8{1, 2, 3}
var tests = []struct {
description string
addresses [][]netip.Addr
eventTypes []types.Type
}{
{
description: "start and stop",
addresses: [][]netip.Addr{{ipAddr("1.1.1.1"), ipAddr("2.2.2.2")}},
eventTypes: []types.Type{types.TypeStart, types.TypeEnd},
},
{
description: "start and drop",
addresses: [][]netip.Addr{{ipAddr("1.1.1.1"), ipAddr("2.2.2.2")}},
eventTypes: []types.Type{types.TypeStart, types.TypeDrop},
},
{
description: "start only",
addresses: [][]netip.Addr{{ipAddr("1.1.1.1"), ipAddr("2.2.2.2")}},
eventTypes: []types.Type{types.TypeStart},
},
{
description: "drop only",
addresses: [][]netip.Addr{{ipAddr("1.1.1.1"), ipAddr("2.2.2.2")}},
eventTypes: []types.Type{types.TypeDrop},
}}
for _, protocol := range protocols {
for _, tt := range tests {
t.Run(tt.description+" "+protocol.String(), func(t *testing.T) {
store := NewAggregatingMemoryStore()
store.WindowEnd = time.Now().Add(5 * time.Second)
allExpected := make([]*types.Event, 0)
for _, icmpType := range icmpTypes {
events, expected := generateEvents(tt.addresses[0][0], tt.addresses[0][1], 0, tt.eventTypes, protocol, types.Ingress, icmpType, store.WindowStart, store.WindowEnd)
for _, e := range events {
store.StoreEvent(e)
}
allExpected = append(allExpected, expected)
}
aggregatedEvents := store.GetAggregatedEvents()
assert.Len(t, aggregatedEvents, len(allExpected))
assert.ElementsMatch(t, aggregatedEvents, allExpected)
})
}
}
}
func TestFlowAggregationOfUnknownProtocols(t *testing.T) {
var tests = []struct {
description string
addresses [][]netip.Addr
dstPort uint16
eventTypes []types.Type
}{
{
description: "start and stop",
addresses: [][]netip.Addr{{ipAddr("1.1.1.1"), ipAddr("2.2.2.2")}, {ipAddr("3.3.3.3"), ipAddr("2.2.2.2")}},
dstPort: uint16(random.Uint32() >> 16),
eventTypes: []types.Type{types.TypeStart, types.TypeEnd},
},
{
description: "start and drop",
addresses: [][]netip.Addr{{ipAddr("1.1.1.1"), ipAddr("2.2.2.2")}, {ipAddr("3.3.3.3"), ipAddr("2.2.2.2")}},
dstPort: uint16(random.Uint32() >> 16),
eventTypes: []types.Type{types.TypeStart, types.TypeDrop},
},
{
description: "start only",
addresses: [][]netip.Addr{{ipAddr("1.1.1.1"), ipAddr("2.2.2.2")}, {ipAddr("3.3.3.3"), ipAddr("2.2.2.2")}},
dstPort: uint16(random.Uint32() >> 16),
eventTypes: []types.Type{types.TypeStart},
},
{
description: "drop only",
addresses: [][]netip.Addr{{ipAddr("1.1.1.1"), ipAddr("2.2.2.2")}, {ipAddr("3.3.3.3"), ipAddr("2.2.2.2")}},
dstPort: uint16(random.Uint32() >> 16),
eventTypes: []types.Type{types.TypeDrop},
}}
for _, tt := range tests {
t.Run(tt.description+" "+types.ProtocolUnknown.String(), func(t *testing.T) {
store := NewAggregatingMemoryStore()
store.WindowEnd = time.Now().Add(5 * time.Second)
allExpected := make([]*types.Event, 0)
for _, srcAndDst := range tt.addresses {
inEvents, expected := generateEventsForUnknownProtocol(srcAndDst[0], srcAndDst[1], tt.dstPort, tt.eventTypes, types.ProtocolUnknown, types.Ingress, store.WindowStart, store.WindowEnd)
for _, e := range inEvents {
store.StoreEvent(e)
}
allExpected = append(allExpected, expected...)
}
events := store.GetAggregatedEvents()
assert.ElementsMatch(t, events, allExpected)
})
}
}
func ipAddr(a string) netip.Addr {
addr, _ := netip.ParseAddr(a)
return addr
}
func generateEvents(srcIp, dstIp netip.Addr, dstPort uint16, eventTypes []types.Type, protocol types.Protocol,
direction types.Direction, icmpType uint8, windowStart, windowEnd time.Time) ([]*types.Event, *types.Event) {
var rxPackets, txPackets, rxBytes, txBytes uint64
inEvents := make([]*types.Event, 0)
ts := time.Now()
flowId := uuid.New()
srcPort := uint16(random.Uint32() >> 16)
for idx, eventType := range eventTypes {
e := &types.Event{
ID: uuid.New(),
Timestamp: ts.Add(time.Duration(idx) * time.Second),
EventFields: types.EventFields{
FlowID: flowId,
Type: eventType,
Protocol: protocol,
RuleID: []byte("rule-id-1"),
Direction: direction,
SourceIP: srcIp,
SourcePort: srcPort,
DestIP: dstIp,
DestPort: dstPort,
SourceResourceID: []byte("source-resource-id"),
DestResourceID: []byte("dest-resource-id"),
RxPackets: random.Uint64(),
TxPackets: random.Uint64(),
RxBytes: random.Uint64(),
TxBytes: random.Uint64(),
}}
rxBytes += e.RxBytes
txBytes += e.TxBytes
rxPackets += e.RxPackets
txPackets += e.TxPackets
inEvents = append(inEvents, e)
if protocol == types.ICMP || protocol == types.ICMPv6 {
e.ICMPType = icmpType
}
}
var start, end, drop uint64
for _, eventType := range eventTypes {
switch eventType {
case types.TypeStart:
start += 1
case types.TypeDrop:
drop += 1
case types.TypeEnd:
end += 1
}
}
aggregatedEvent := &types.Event{
ID: inEvents[0].ID,
Timestamp: inEvents[0].Timestamp,
WindowStart: windowStart,
WindowEnd: windowEnd,
EventFields: types.EventFields{
FlowID: flowId,
Type: types.TypeUnknown,
Protocol: inEvents[0].Protocol,
RuleID: []byte("rule-id-1"),
Direction: inEvents[0].Direction,
SourceIP: srcIp,
SourcePort: srcPort,
DestIP: dstIp,
DestPort: dstPort,
SourceResourceID: []byte("source-resource-id"),
DestResourceID: []byte("dest-resource-id"),
RxPackets: rxPackets,
TxPackets: txPackets,
RxBytes: rxBytes,
TxBytes: txBytes,
NumOfStarts: start,
NumOfEnds: end,
NumOfDrops: drop,
}}
if protocol == types.ICMP || protocol == types.ICMPv6 {
aggregatedEvent.ICMPType = icmpType
}
return inEvents, aggregatedEvent
}
func generateEventsForUnknownProtocol(srcIp, dstIp netip.Addr, dstPort uint16, eventTypes []types.Type, protocol types.Protocol,
direction types.Direction, windowStart, windowEnd time.Time) ([]*types.Event, []*types.Event) {
inEvents := make([]*types.Event, 0)
expectedEvents := make([]*types.Event, 0)
ts := time.Now()
flowId := uuid.New()
srcPort := uint16(random.Uint32() >> 16)
for idx, eventType := range eventTypes {
e := &types.Event{
ID: uuid.New(),
Timestamp: ts.Add(time.Duration(idx) * time.Second),
EventFields: types.EventFields{
FlowID: flowId,
Type: eventType,
Protocol: protocol,
RuleID: []byte("rule-id-1"),
Direction: direction,
SourceIP: srcIp,
SourcePort: srcPort,
DestIP: dstIp,
DestPort: dstPort,
SourceResourceID: []byte("source-resource-id"),
DestResourceID: []byte("dest-resource-id"),
RxPackets: random.Uint64(),
TxPackets: random.Uint64(),
RxBytes: random.Uint64(),
TxBytes: random.Uint64(),
}}
inEvents = append(inEvents, e)
var start, end, drop uint64
switch eventType {
case types.TypeStart:
start = 1
case types.TypeDrop:
drop = 1
case types.TypeEnd:
end = 1
}
expectedEvents = append(expectedEvents, &types.Event{
ID: e.ID,
Timestamp: e.Timestamp,
WindowStart: windowStart,
WindowEnd: windowEnd,
EventFields: types.EventFields{
FlowID: flowId,
Type: types.TypeUnknown,
Protocol: e.Protocol,
RuleID: []byte("rule-id-1"),
Direction: e.Direction,
SourceIP: srcIp,
SourcePort: srcPort,
DestIP: dstIp,
DestPort: dstPort,
SourceResourceID: []byte("source-resource-id"),
DestResourceID: []byte("dest-resource-id"),
RxPackets: e.RxPackets,
TxPackets: e.TxPackets,
RxBytes: e.RxBytes,
TxBytes: e.TxBytes,
NumOfStarts: start,
NumOfEnds: end,
NumOfDrops: drop,
}})
}
return inEvents, expectedEvents
}

View File

@@ -1,15 +1,10 @@
package store
import (
"maps"
"math/rand"
v2 "math/rand/v2"
"net/netip"
"slices"
"sync"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -24,13 +19,6 @@ type Memory struct {
events map[uuid.UUID]*types.Event
}
type AggregatingMemory struct {
Memory
WindowStart time.Time
WindowEnd time.Time
rnd *v2.PCG
}
func (m *Memory) StoreEvent(event *types.Event) {
m.mux.Lock()
defer m.mux.Unlock()
@@ -60,92 +48,3 @@ func (m *Memory) DeleteEvents(ids []uuid.UUID) {
delete(m.events, id)
}
}
func NewAggregatingMemoryStore() *AggregatingMemory {
return &AggregatingMemory{WindowStart: time.Now(), Memory: Memory{events: make(map[uuid.UUID]*types.Event)}, rnd: v2.NewPCG(rand.Uint64(), rand.Uint64())}
}
func (am *AggregatingMemory) ResetAggregationWindow() types.FlowEventAggregator {
am.mux.Lock()
defer am.mux.Unlock()
now := time.Now()
toret := AggregatingMemory{WindowStart: am.WindowStart, WindowEnd: now, Memory: Memory{events: am.events}}
am.events = make(map[uuid.UUID]*types.Event)
am.WindowStart = now
return &toret
}
type aggregationKey struct {
srcAddr netip.Addr
destAddr netip.Addr
destPort uint16
direction int
protocol uint8
icmpType uint8
unique uint64 // used to prevent aggregation on non icmp/udp/tcp events
}
func (am *AggregatingMemory) GetAggregatedEvents() []*types.Event {
am.mux.Lock()
defer am.mux.Unlock()
aggregated := make(map[aggregationKey]*types.Event)
for _, v := range am.events {
lookupKey := aggregationKey{srcAddr: v.SourceIP, destAddr: v.DestIP, destPort: v.DestPort, direction: int(v.Direction), protocol: uint8(v.Protocol), icmpType: v.ICMPType}
if _, ok := aggregated[lookupKey]; !ok {
event := v.Clone()
switch event.Type {
case types.TypeStart:
event.NumOfStarts += 1
case types.TypeDrop:
event.NumOfDrops += 1
case types.TypeEnd:
event.NumOfEnds += 1
}
event.Type = types.TypeUnknown
// Please note that ICMPCode field isn't propagated by the manager (see flow/proto/flow.pb.go, FlowFields struct)
// so the field value in an icmp event in the "aggregated" doesn't matter
event.WindowStart = am.WindowStart
event.WindowEnd = am.WindowEnd
if event.Protocol != types.ICMP && event.Protocol != types.ICMPv6 && event.Protocol != types.UDP && event.Protocol != types.TCP {
lookupKey.unique = am.rnd.Uint64() // to make the lookup key unique so we don't aggregate on it
}
aggregated[lookupKey] = event
continue
}
aggregatedEvent := aggregated[lookupKey]
if aggregatedEvent.Protocol != types.ICMP && aggregatedEvent.Protocol != types.ICMPv6 && aggregatedEvent.Protocol != types.UDP && aggregatedEvent.Protocol != types.TCP {
continue // we don't aggregate this type of events; shouldn't ever get here
}
// track the number of connections, duration?, open and close events?
aggregatedEvent.RxBytes += v.RxBytes
aggregatedEvent.RxPackets += v.RxPackets
aggregatedEvent.TxBytes += v.TxBytes
aggregatedEvent.TxPackets += v.TxPackets
switch v.Type {
case types.TypeStart:
aggregatedEvent.NumOfStarts += 1
case types.TypeDrop:
aggregatedEvent.NumOfDrops += 1
case types.TypeEnd:
aggregatedEvent.NumOfEnds += 1
}
if aggregatedEvent.Timestamp.Compare(v.Timestamp) > 0 {
aggregatedEvent.Timestamp = v.Timestamp
aggregatedEvent.ID = v.ID
aggregatedEvent.SourcePort = v.SourcePort
}
}
return slices.Collect(maps.Values(aggregated)) // could return an iterator instead here
}

View File

@@ -2,7 +2,6 @@ package types
import (
"net/netip"
"slices"
"strconv"
"time"
@@ -70,10 +69,8 @@ const (
)
type Event struct {
ID uuid.UUID
Timestamp time.Time
WindowStart time.Time
WindowEnd time.Time
ID uuid.UUID
Timestamp time.Time
EventFields
}
@@ -95,17 +92,6 @@ type EventFields struct {
TxPackets uint64
RxBytes uint64
TxBytes uint64
NumOfStarts uint64
NumOfEnds uint64
NumOfDrops uint64
}
func (e *Event) Clone() *Event {
toret := *e
toret.RuleID = slices.Clone(e.RuleID)
toret.SourceResourceID = slices.Clone(e.SourceResourceID)
toret.DestResourceID = slices.Clone(e.DestResourceID)
return &toret
}
type FlowConfig struct {
@@ -128,15 +114,13 @@ type FlowManager interface {
GetLogger() FlowLogger
}
type FlowEventAggregator interface {
ResetAggregationWindow() FlowEventAggregator
GetAggregatedEvents() []*Event
}
type FlowLogger interface {
ResetAggregationWindow() FlowEventAggregator
// StoreEvent stores a flow event
StoreEvent(flowEvent EventFields)
// GetEvents returns all stored events
GetEvents() []*Event
// DeleteEvents deletes events from the store
DeleteEvents([]uuid.UUID)
// Close closes the logger
Close()
// Enable enables the flow logger receiver
@@ -156,11 +140,6 @@ type Store interface {
Close()
}
type AggregatingStore interface {
FlowEventAggregator
Store
}
// ConnTracker defines the interface for connection tracking functionality
type ConnTracker interface {
// Start begins tracking connections by listening for conntrack events.

91
combined/cmd/admin.go Normal file
View File

@@ -0,0 +1,91 @@
package cmd
import (
"context"
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
// newAdminCommands creates the admin command tree with combined-specific resource openers.
func newAdminCommands() *cobra.Command {
cmd := admincmd.NewCommands(withAdminResources)
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
return cmd
}
// withAdminResources loads the combined YAML config, initializes stores, and calls fn.
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, cfg *CombinedConfig) error {
mgmtConfig, err := cfg.ToManagementConfig()
if err != nil {
return fmt.Errorf("create management config: %w", err)
}
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(mgmtConfig.EmbeddedIdP)
if err != nil {
return err
}
defer func() {
if err := idpStorage.Close(); err != nil {
log.Debugf("close embedded IdP storage: %v", err)
}
}()
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
})
}
// withAdminTokenStore opens only the management store for admin token commands.
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *CombinedConfig) error {
return fn(ctx, managementStore)
})
}
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, cfg *CombinedConfig) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
if dsn := cfg.Server.Store.DSN; dsn != "" {
switch strings.ToLower(cfg.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
managementStore, err := store.NewStore(ctx, types.Engine(cfg.Management.Store.Engine), cfg.Management.DataDir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := managementStore.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, managementStore, cfg)
}

View File

@@ -64,7 +64,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
_ = rootCmd.MarkPersistentFlagRequired("config")
rootCmd.AddCommand(newTokenCommands())
rootCmd.AddCommand(newAdminCommands())
}
func RootCmd() *cobra.Command {

View File

@@ -1,63 +0,0 @@
package cmd
import (
"context"
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
// newTokenCommands creates the token command tree with combined-specific store opener.
func newTokenCommands() *cobra.Command {
return tokencmd.NewCommands(withTokenStore)
}
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
if dsn := cfg.Server.Store.DSN; dsn != "" {
switch strings.ToLower(cfg.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
datadir := cfg.Management.DataDir
engine := types.Engine(cfg.Management.Store.Engine)
s, err := store.NewStore(ctx, engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}

View File

@@ -134,11 +134,9 @@ type FlowEvent struct {
// When the event occurred
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
// Public key of the sending peer
PublicKey []byte `protobuf:"bytes,3,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
FlowFields *FlowFields `protobuf:"bytes,4,opt,name=flow_fields,json=flowFields,proto3" json:"flow_fields,omitempty"`
IsInitiator bool `protobuf:"varint,5,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
WindowStart *timestamppb.Timestamp `protobuf:"bytes,6,opt,name=window_start,json=windowStart,proto3" json:"window_start,omitempty"`
WindowEnd *timestamppb.Timestamp `protobuf:"bytes,7,opt,name=window_end,json=windowEnd,proto3" json:"window_end,omitempty"`
PublicKey []byte `protobuf:"bytes,3,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
FlowFields *FlowFields `protobuf:"bytes,4,opt,name=flow_fields,json=flowFields,proto3" json:"flow_fields,omitempty"`
IsInitiator bool `protobuf:"varint,5,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
}
func (x *FlowEvent) Reset() {
@@ -208,20 +206,6 @@ func (x *FlowEvent) GetIsInitiator() bool {
return false
}
func (x *FlowEvent) GetWindowStart() *timestamppb.Timestamp {
if x != nil {
return x.WindowStart
}
return nil
}
func (x *FlowEvent) GetWindowEnd() *timestamppb.Timestamp {
if x != nil {
return x.WindowEnd
}
return nil
}
type FlowEventAck struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@@ -300,6 +284,7 @@ type FlowFields struct {
// Layer 4 -specific information
//
// Types that are assignable to ConnectionInfo:
//
// *FlowFields_PortInfo
// *FlowFields_IcmpInfo
ConnectionInfo isFlowFields_ConnectionInfo `protobuf_oneof:"connection_info"`
@@ -312,9 +297,6 @@ type FlowFields struct {
// Resource ID
SourceResourceId []byte `protobuf:"bytes,14,opt,name=source_resource_id,json=sourceResourceId,proto3" json:"source_resource_id,omitempty"`
DestResourceId []byte `protobuf:"bytes,15,opt,name=dest_resource_id,json=destResourceId,proto3" json:"dest_resource_id,omitempty"`
NumOfStarts uint64 `protobuf:"varint,16,opt,name=num_of_starts,json=numOfStarts,proto3" json:"num_of_starts,omitempty"`
NumOfEnds uint64 `protobuf:"varint,17,opt,name=num_of_ends,json=numOfEnds,proto3" json:"num_of_ends,omitempty"`
NumOfDrops uint64 `protobuf:"varint,18,opt,name=num_of_drops,json=numOfDrops,proto3" json:"num_of_drops,omitempty"`
}
func (x *FlowFields) Reset() {
@@ -461,27 +443,6 @@ func (x *FlowFields) GetDestResourceId() []byte {
return nil
}
func (x *FlowFields) GetNumOfStarts() uint64 {
if x != nil {
return x.NumOfStarts
}
return 0
}
func (x *FlowFields) GetNumOfEnds() uint64 {
if x != nil {
return x.NumOfEnds
}
return 0
}
func (x *FlowFields) GetNumOfDrops() uint64 {
if x != nil {
return x.NumOfDrops
}
return 0
}
type isFlowFields_ConnectionInfo interface {
isFlowFields_ConnectionInfo()
}
@@ -618,7 +579,7 @@ var file_flow_proto_rawDesc = []byte{
0x0a, 0x0a, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x66, 0x6c,
0x6f, 0x77, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x22, 0xce, 0x02, 0x0a, 0x09, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
0x6f, 0x74, 0x6f, 0x22, 0xd4, 0x01, 0x0a, 0x09, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
0x74, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x38, 0x0a, 0x09,
0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32,
@@ -631,59 +592,45 @@ var file_flow_proto_rawDesc = []byte{
0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x52, 0x0a, 0x66, 0x6c,
0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x20, 0x0a, 0x0b, 0x69, 0x73, 0x49, 0x6e,
0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x69,
0x73, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x12, 0x3d, 0x0a, 0x0c, 0x77, 0x69,
0x6e, 0x64, 0x6f, 0x77, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b,
0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0b, 0x77, 0x69,
0x6e, 0x64, 0x6f, 0x77, 0x53, 0x74, 0x61, 0x72, 0x74, 0x12, 0x39, 0x0a, 0x0a, 0x77, 0x69, 0x6e,
0x64, 0x6f, 0x77, 0x5f, 0x65, 0x6e, 0x64, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e,
0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e,
0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x77, 0x69, 0x6e, 0x64, 0x6f,
0x77, 0x45, 0x6e, 0x64, 0x22, 0x4b, 0x0a, 0x0c, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
0x74, 0x41, 0x63, 0x6b, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12,
0x20, 0x0a, 0x0b, 0x69, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x02,
0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x69, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x74, 0x6f,
0x72, 0x22, 0x82, 0x05, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73,
0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28,
0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49, 0x64, 0x12, 0x1e, 0x0a, 0x04, 0x74, 0x79, 0x70,
0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0a, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54,
0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x07, 0x72, 0x75, 0x6c,
0x65, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x72, 0x75, 0x6c, 0x65,
0x49, 0x64, 0x12, 0x2d, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18,
0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x44, 0x69, 0x72,
0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f,
0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20,
0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1b, 0x0a,
0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x17, 0x0a, 0x07, 0x64, 0x65,
0x73, 0x74, 0x5f, 0x69, 0x70, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x64, 0x65, 0x73,
0x74, 0x49, 0x70, 0x12, 0x2d, 0x0a, 0x09, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x69, 0x6e, 0x66, 0x6f,
0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x50, 0x6f,
0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e,
0x66, 0x6f, 0x12, 0x2d, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18,
0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x49, 0x43, 0x4d,
0x50, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x49, 0x6e, 0x66,
0x6f, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x78, 0x5f, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x18,
0x0a, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x72, 0x78, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73,
0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x78, 0x5f, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x18, 0x0b,
0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x74, 0x78, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x12,
0x19, 0x0a, 0x08, 0x72, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x01, 0x28,
0x04, 0x52, 0x07, 0x72, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x74, 0x78,
0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, 0x74, 0x78,
0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x12, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f,
0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0e, 0x20, 0x01, 0x28,
0x0c, 0x52, 0x10, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63,
0x65, 0x49, 0x64, 0x12, 0x28, 0x0a, 0x10, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x72, 0x65, 0x73, 0x6f,
0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0e, 0x64,
0x65, 0x73, 0x74, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x12, 0x22, 0x0a,
0x0d, 0x6e, 0x75, 0x6d, 0x5f, 0x6f, 0x66, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x73, 0x18, 0x10,
0x20, 0x01, 0x28, 0x04, 0x52, 0x0b, 0x6e, 0x75, 0x6d, 0x4f, 0x66, 0x53, 0x74, 0x61, 0x72, 0x74,
0x73, 0x12, 0x1e, 0x0a, 0x0b, 0x6e, 0x75, 0x6d, 0x5f, 0x6f, 0x66, 0x5f, 0x65, 0x6e, 0x64, 0x73,
0x18, 0x11, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x6e, 0x75, 0x6d, 0x4f, 0x66, 0x45, 0x6e, 0x64,
0x73, 0x12, 0x20, 0x0a, 0x0c, 0x6e, 0x75, 0x6d, 0x5f, 0x6f, 0x66, 0x5f, 0x64, 0x72, 0x6f, 0x70,
0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0a, 0x6e, 0x75, 0x6d, 0x4f, 0x66, 0x44, 0x72,
0x6f, 0x70, 0x73, 0x42, 0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f,
0x73, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x22, 0x4b, 0x0a, 0x0c, 0x46, 0x6c,
0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76,
0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76,
0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x20, 0x0a, 0x0b, 0x69, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x69,
0x61, 0x74, 0x6f, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x69, 0x73, 0x49, 0x6e,
0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x22, 0x9c, 0x04, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77,
0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x69,
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49, 0x64, 0x12,
0x1e, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0a, 0x2e,
0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12,
0x17, 0x0a, 0x07, 0x72, 0x75, 0x6c, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x06, 0x72, 0x75, 0x6c, 0x65, 0x49, 0x64, 0x12, 0x2d, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65,
0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0f, 0x2e, 0x66, 0x6c,
0x6f, 0x77, 0x2e, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x64, 0x69,
0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x63, 0x6f, 0x6c, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70,
0x18, 0x06, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70,
0x12, 0x17, 0x0a, 0x07, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x70, 0x18, 0x07, 0x20, 0x01, 0x28,
0x0c, 0x52, 0x06, 0x64, 0x65, 0x73, 0x74, 0x49, 0x70, 0x12, 0x2d, 0x0a, 0x09, 0x70, 0x6f, 0x72,
0x74, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66,
0x6c, 0x6f, 0x77, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08,
0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x2d, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70,
0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66, 0x6c,
0x6f, 0x77, 0x2e, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08, 0x69,
0x63, 0x6d, 0x70, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x78, 0x5f, 0x70, 0x61,
0x63, 0x6b, 0x65, 0x74, 0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x72, 0x78, 0x50,
0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x78, 0x5f, 0x70, 0x61, 0x63,
0x6b, 0x65, 0x74, 0x73, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x74, 0x78, 0x50, 0x61,
0x63, 0x6b, 0x65, 0x74, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x72, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65,
0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, 0x72, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73,
0x12, 0x19, 0x0a, 0x08, 0x74, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0d, 0x20, 0x01,
0x28, 0x04, 0x52, 0x07, 0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x12, 0x73,
0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69,
0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x10, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52,
0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x12, 0x28, 0x0a, 0x10, 0x64, 0x65, 0x73,
0x74, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63,
0x65, 0x49, 0x64, 0x42, 0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f,
0x6e, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e,
0x66, 0x6f, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72,
0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50,
@@ -736,19 +683,17 @@ var file_flow_proto_goTypes = []interface{}{
var file_flow_proto_depIdxs = []int32{
7, // 0: flow.FlowEvent.timestamp:type_name -> google.protobuf.Timestamp
4, // 1: flow.FlowEvent.flow_fields:type_name -> flow.FlowFields
7, // 2: flow.FlowEvent.window_start:type_name -> google.protobuf.Timestamp
7, // 3: flow.FlowEvent.window_end:type_name -> google.protobuf.Timestamp
0, // 4: flow.FlowFields.type:type_name -> flow.Type
1, // 5: flow.FlowFields.direction:type_name -> flow.Direction
5, // 6: flow.FlowFields.port_info:type_name -> flow.PortInfo
6, // 7: flow.FlowFields.icmp_info:type_name -> flow.ICMPInfo
2, // 8: flow.FlowService.Events:input_type -> flow.FlowEvent
3, // 9: flow.FlowService.Events:output_type -> flow.FlowEventAck
9, // [9:10] is the sub-list for method output_type
8, // [8:9] is the sub-list for method input_type
8, // [8:8] is the sub-list for extension type_name
8, // [8:8] is the sub-list for extension extendee
0, // [0:8] is the sub-list for field type_name
0, // 2: flow.FlowFields.type:type_name -> flow.Type
1, // 3: flow.FlowFields.direction:type_name -> flow.Direction
5, // 4: flow.FlowFields.port_info:type_name -> flow.PortInfo
6, // 5: flow.FlowFields.icmp_info:type_name -> flow.ICMPInfo
2, // 6: flow.FlowService.Events:input_type -> flow.FlowEvent
3, // 7: flow.FlowService.Events:output_type -> flow.FlowEventAck
7, // [7:8] is the sub-list for method output_type
6, // [6:7] is the sub-list for method input_type
6, // [6:6] is the sub-list for extension type_name
6, // [6:6] is the sub-list for extension extendee
0, // [0:6] is the sub-list for field type_name
}
func init() { file_flow_proto_init() }

View File

@@ -24,9 +24,6 @@ message FlowEvent {
FlowFields flow_fields = 4;
bool isInitiator = 5;
google.protobuf.Timestamp window_start = 6;
google.protobuf.Timestamp window_end = 7;
}
message FlowEventAck {
@@ -78,9 +75,6 @@ message FlowFields {
bytes source_resource_id = 14;
bytes dest_resource_id = 15;
uint64 num_of_starts = 16;
uint64 num_of_ends = 17;
uint64 num_of_drops = 18;
}
// Flow event types

View File

@@ -1,8 +1,4 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.1
// - protoc v3.21.9
// source: flow.proto
package proto
@@ -15,19 +11,15 @@ import (
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
const (
FlowService_Events_FullMethodName = "/flow.FlowService/Events"
)
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// FlowServiceClient is the client API for FlowService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type FlowServiceClient interface {
// Client to receiver streams of events and acknowledgements
Events(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[FlowEvent, FlowEventAck], error)
Events(ctx context.Context, opts ...grpc.CallOption) (FlowService_EventsClient, error)
}
type flowServiceClient struct {
@@ -38,40 +30,54 @@ func NewFlowServiceClient(cc grpc.ClientConnInterface) FlowServiceClient {
return &flowServiceClient{cc}
}
func (c *flowServiceClient) Events(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[FlowEvent, FlowEventAck], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &FlowService_ServiceDesc.Streams[0], FlowService_Events_FullMethodName, cOpts...)
func (c *flowServiceClient) Events(ctx context.Context, opts ...grpc.CallOption) (FlowService_EventsClient, error) {
stream, err := c.cc.NewStream(ctx, &FlowService_ServiceDesc.Streams[0], "/flow.FlowService/Events", opts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[FlowEvent, FlowEventAck]{ClientStream: stream}
x := &flowServiceEventsClient{stream}
return x, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type FlowService_EventsClient = grpc.BidiStreamingClient[FlowEvent, FlowEventAck]
type FlowService_EventsClient interface {
Send(*FlowEvent) error
Recv() (*FlowEventAck, error)
grpc.ClientStream
}
type flowServiceEventsClient struct {
grpc.ClientStream
}
func (x *flowServiceEventsClient) Send(m *FlowEvent) error {
return x.ClientStream.SendMsg(m)
}
func (x *flowServiceEventsClient) Recv() (*FlowEventAck, error) {
m := new(FlowEventAck)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// FlowServiceServer is the server API for FlowService service.
// All implementations must embed UnimplementedFlowServiceServer
// for forward compatibility.
// for forward compatibility
type FlowServiceServer interface {
// Client to receiver streams of events and acknowledgements
Events(grpc.BidiStreamingServer[FlowEvent, FlowEventAck]) error
Events(FlowService_EventsServer) error
mustEmbedUnimplementedFlowServiceServer()
}
// UnimplementedFlowServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedFlowServiceServer struct{}
// UnimplementedFlowServiceServer must be embedded to have forward compatible implementations.
type UnimplementedFlowServiceServer struct {
}
func (UnimplementedFlowServiceServer) Events(grpc.BidiStreamingServer[FlowEvent, FlowEventAck]) error {
return status.Error(codes.Unimplemented, "method Events not implemented")
func (UnimplementedFlowServiceServer) Events(FlowService_EventsServer) error {
return status.Errorf(codes.Unimplemented, "method Events not implemented")
}
func (UnimplementedFlowServiceServer) mustEmbedUnimplementedFlowServiceServer() {}
func (UnimplementedFlowServiceServer) testEmbeddedByValue() {}
// UnsafeFlowServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to FlowServiceServer will
@@ -81,22 +87,34 @@ type UnsafeFlowServiceServer interface {
}
func RegisterFlowServiceServer(s grpc.ServiceRegistrar, srv FlowServiceServer) {
// If the following call panics, it indicates UnimplementedFlowServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&FlowService_ServiceDesc, srv)
}
func _FlowService_Events_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(FlowServiceServer).Events(&grpc.GenericServerStream[FlowEvent, FlowEventAck]{ServerStream: stream})
return srv.(FlowServiceServer).Events(&flowServiceEventsServer{stream})
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type FlowService_EventsServer = grpc.BidiStreamingServer[FlowEvent, FlowEventAck]
type FlowService_EventsServer interface {
Send(*FlowEventAck) error
Recv() (*FlowEvent, error)
grpc.ServerStream
}
type flowServiceEventsServer struct {
grpc.ServerStream
}
func (x *flowServiceEventsServer) Send(m *FlowEventAck) error {
return x.ServerStream.SendMsg(m)
}
func (x *flowServiceEventsServer) Recv() (*FlowEvent, error) {
m := new(FlowEvent)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// FlowService_ServiceDesc is the grpc.ServiceDesc for FlowService service.
// It's only intended for direct use with grpc.RegisterService,

89
management/cmd/admin.go Normal file
View File

@@ -0,0 +1,89 @@
package cmd
import (
"context"
"fmt"
"path/filepath"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
var adminDatadir string
// newAdminCommands creates the admin command tree with management-specific resource openers.
func newAdminCommands() *cobra.Command {
cmd := admincmd.NewCommands(withAdminResources)
cmd.PersistentFlags().StringVar(&adminDatadir, "datadir", "", "Override the data directory from config (used for store.db and the default idp.db)")
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
return cmd
}
// withAdminResources initializes logging, loads config, opens the management store
// and embedded IdP storage, and calls fn.
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, config *nbconfig.Config) error {
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(config.EmbeddedIdP)
if err != nil {
return err
}
defer func() {
if err := idpStorage.Close(); err != nil {
log.Debugf("close embedded IdP storage: %v", err)
}
}()
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
})
}
// withAdminTokenStore opens only the management store for admin token commands.
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *nbconfig.Config) error {
return fn(ctx, managementStore)
})
}
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, config *nbconfig.Config) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if adminDatadir != "" {
oldDatadir := datadir
datadir = adminDatadir
if config.EmbeddedIdP != nil && config.EmbeddedIdP.Storage.Type == "sqlite3" {
defaultIDPFile := filepath.Join(oldDatadir, "idp.db")
if config.EmbeddedIdP.Storage.Config.File == "" || config.EmbeddedIdP.Storage.Config.File == defaultIDPFile {
config.EmbeddedIdP.Storage.Config.File = filepath.Join(datadir, "idp.db")
}
}
}
managementStore, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := managementStore.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, managementStore, config)
}

View File

@@ -0,0 +1,441 @@
// Package admincmd provides reusable cobra commands for self-hosted administrator helpers.
// Both the management and combined binaries use these commands, each providing
// their own opener to handle config loading and storage initialization.
package admincmd
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"strings"
"github.com/dexidp/dex/storage"
"github.com/spf13/cobra"
"golang.org/x/crypto/bcrypt"
nbdex "github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
const (
localConnectorID = "local"
dashboardClientID = "netbird-dashboard"
cliClientID = "netbird-cli"
defaultTOTPAuthenticatorID = "default-totp"
)
// Resources contains the storages required by the admin commands.
type Resources struct {
Store store.Store
IDPStorage storage.Storage
}
// Opener initializes command resources from the command context and calls fn.
type Opener func(cmd *cobra.Command, fn func(ctx context.Context, resources Resources) error) error
type userSelector struct {
email string
userID string
}
func (s userSelector) normalized() userSelector {
return userSelector{
email: strings.TrimSpace(s.email),
userID: strings.TrimSpace(s.userID),
}
}
func (s userSelector) validate() error {
s = s.normalized()
if (s.email == "") == (s.userID == "") {
return fmt.Errorf("provide exactly one of --email or --user-id")
}
return nil
}
// NewCommands creates the admin command tree with the given resource opener.
func NewCommands(opener Opener) *cobra.Command {
adminCmd := &cobra.Command{
Use: "admin",
Short: "Self-hosted administrator helpers",
Long: "Administrative helpers for self-hosted deployments using the embedded identity provider.",
}
userCmd := &cobra.Command{
Use: "user",
Short: "Manage local embedded IdP users",
}
var passwordSelector userSelector
var password string
var passwordFile string
passwordCmd := &cobra.Command{
Use: "change-password (--email email | --user-id id) (--password password | --password-file path)",
Aliases: []string{"set-password"},
Short: "Change a local user's password",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
newPassword, err := resolvePasswordInput(cmd, password, passwordFile)
if err != nil {
return err
}
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runChangePassword(ctx, resources.IDPStorage, cmd.OutOrStdout(), passwordSelector, newPassword)
})
},
}
addUserSelectorFlags(passwordCmd, &passwordSelector)
passwordCmd.Flags().StringVar(&password, "password", "", "New password for the user")
passwordCmd.Flags().StringVar(&passwordFile, "password-file", "", "Read new password from file ('-' for stdin)")
var resetSelector userSelector
resetMFACmd := &cobra.Command{
Use: "reset-mfa (--email email | --user-id id)",
Short: "Reset a local user's MFA enrollment",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runResetMFA(ctx, resources.IDPStorage, cmd.OutOrStdout(), resetSelector)
})
},
}
addUserSelectorFlags(resetMFACmd, &resetSelector)
userCmd.AddCommand(passwordCmd, resetMFACmd)
mfaCmd := &cobra.Command{
Use: "mfa",
Short: "Manage local MFA for embedded IdP users",
}
enableCmd := &cobra.Command{
Use: "enable",
Short: "Enable MFA for local embedded IdP users",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), true)
})
},
}
disableCmd := &cobra.Command{
Use: "disable",
Short: "Disable MFA for local embedded IdP users",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), false)
})
},
}
statusCmd := &cobra.Command{
Use: "status",
Short: "Show local MFA status",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runMFAStatus(ctx, resources, cmd.OutOrStdout())
})
},
}
mfaCmd.AddCommand(enableCmd, disableCmd, statusCmd)
adminCmd.AddCommand(userCmd, mfaCmd)
return adminCmd
}
// OpenEmbeddedIDPStorage opens the Dex storage configured for the embedded IdP.
func OpenEmbeddedIDPStorage(cfg *idp.EmbeddedIdPConfig) (storage.Storage, error) {
if cfg == nil || !cfg.Enabled {
return nil, fmt.Errorf("admin commands require the embedded IdP to be enabled")
}
yamlConfig, err := cfg.ToYAMLConfig()
if err != nil {
return nil, fmt.Errorf("build embedded IdP config: %w", err)
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
st, err := yamlConfig.Storage.OpenStorage(logger)
if err != nil {
return nil, fmt.Errorf("open embedded IdP storage: %w", err)
}
return st, nil
}
func addUserSelectorFlags(cmd *cobra.Command, selector *userSelector) {
cmd.Flags().StringVar(&selector.email, "email", "", "User email")
cmd.Flags().StringVar(&selector.userID, "user-id", "", "User ID")
}
func resolvePasswordInput(cmd *cobra.Command, password, passwordFile string) (string, error) {
if password != "" && passwordFile != "" {
return "", fmt.Errorf("provide only one of --password or --password-file")
}
if passwordFile == "" {
return password, nil
}
var data []byte
var err error
if passwordFile == "-" {
data, err = io.ReadAll(cmd.InOrStdin())
} else {
data, err = os.ReadFile(passwordFile)
}
if err != nil {
return "", fmt.Errorf("read password: %w", err)
}
return strings.TrimRight(string(data), "\r\n"), nil
}
func runChangePassword(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, password string) error {
if idpStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
selector = selector.normalized()
if err := selector.validate(); err != nil {
return err
}
if password == "" {
return fmt.Errorf("password is required")
}
if err := server.ValidatePassword(password); err != nil {
return fmt.Errorf("invalid password: %w", err)
}
user, err := findLocalUser(ctx, idpStorage, selector)
if err != nil {
return err
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
if err := idpStorage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
old.Hash = hash
return old, nil
}); err != nil {
return fmt.Errorf("update password for %s: %w", user.Email, err)
}
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
_, _ = fmt.Fprintf(w, "Password updated for %s.\n", user.Email)
return nil
}
func runResetMFA(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector) error {
if idpStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
selector = selector.normalized()
if err := selector.validate(); err != nil {
return err
}
user, err := findLocalUser(ctx, idpStorage, selector)
if err != nil {
return err
}
reset := false
err = idpStorage.UpdateUserIdentity(ctx, user.UserID, localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
reset = reset || len(old.MFASecrets) > 0 || len(old.WebAuthnCredentials) > 0
old.MFASecrets = map[string]*storage.MFASecret{}
old.WebAuthnCredentials = map[string][]storage.WebAuthnCredential{}
return old, nil
})
if errors.Is(err, storage.ErrNotFound) {
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
return nil
}
if err != nil {
return fmt.Errorf("reset MFA for %s: %w", user.Email, err)
}
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
if reset {
_, _ = fmt.Fprintf(w, "MFA reset for %s. The user will re-enroll at next login.\n", user.Email)
} else {
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
}
return nil
}
func runSetMFAEnabled(ctx context.Context, resources Resources, w io.Writer, enabled bool) error {
if resources.Store == nil {
return fmt.Errorf("management store is required")
}
if resources.IDPStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
accounts := resources.Store.GetAllAccounts(ctx)
if len(accounts) != 1 {
return fmt.Errorf("expected exactly one account, got %d; local MFA is supported only in single-account embedded IdP deployments", len(accounts))
}
settings := &types.Settings{}
if accounts[0].Settings != nil {
settings = accounts[0].Settings.Copy()
}
settings.LocalMfaEnabled = enabled
if err := resources.Store.SaveAccountSettings(ctx, accounts[0].Id, settings); err != nil {
return fmt.Errorf("save local MFA account setting: %w", err)
}
if err := setIDPClientsMFA(ctx, resources.IDPStorage, enabled); err != nil {
return err
}
state := "disabled"
if enabled {
state = "enabled"
}
_, _ = fmt.Fprintf(w, "Local MFA %s.\n", state)
return nil
}
func runMFAStatus(ctx context.Context, resources Resources, w io.Writer) error {
if resources.Store == nil {
return fmt.Errorf("management store is required")
}
if resources.IDPStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
accounts := resources.Store.GetAllAccounts(ctx)
accountStatus := "unknown"
if len(accounts) == 1 && accounts[0].Settings != nil {
accountStatus = "disabled"
if accounts[0].Settings.LocalMfaEnabled {
accountStatus = "enabled"
}
}
clientStatus, err := idpClientsMFAStatus(ctx, resources.IDPStorage)
if err != nil {
return err
}
_, _ = fmt.Fprintf(w, "Account setting: %s\n", accountStatus)
_, _ = fmt.Fprintf(w, "Embedded IdP clients: %s\n", clientStatus)
return nil
}
func findLocalUser(ctx context.Context, idpStorage storage.Storage, selector userSelector) (storage.Password, error) {
selector = selector.normalized()
if err := selector.validate(); err != nil {
return storage.Password{}, err
}
if selector.email != "" {
user, err := idpStorage.GetPassword(ctx, selector.email)
if errors.Is(err, storage.ErrNotFound) {
return storage.Password{}, fmt.Errorf("local user with email %q not found", selector.email)
}
if err != nil {
return storage.Password{}, fmt.Errorf("get local user by email %q: %w", selector.email, err)
}
return user, nil
}
rawUserID := selector.userID
if decodedUserID, _, err := nbdex.DecodeDexUserID(selector.userID); err == nil && decodedUserID != "" {
rawUserID = decodedUserID
}
users, err := idpStorage.ListPasswords(ctx)
if err != nil {
return storage.Password{}, fmt.Errorf("list local users: %w", err)
}
for _, user := range users {
if user.UserID == rawUserID || user.UserID == selector.userID {
return user, nil
}
}
return storage.Password{}, fmt.Errorf("local user with ID %q not found", selector.userID)
}
func deleteLocalAuthSession(ctx context.Context, idpStorage storage.Storage, userID string) error {
err := idpStorage.DeleteAuthSession(ctx, userID, localConnectorID)
if err == nil || errors.Is(err, storage.ErrNotFound) {
return nil
}
return fmt.Errorf("delete local auth session for user %s: %w", userID, err)
}
func setIDPClientsMFA(ctx context.Context, idpStorage storage.Storage, enabled bool) error {
var mfaChain []string
if enabled {
mfaChain = []string{defaultTOTPAuthenticatorID}
}
for _, clientID := range []string{cliClientID, dashboardClientID} {
if err := idpStorage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = mfaChain
return old, nil
}); err != nil {
if errors.Is(err, storage.ErrNotFound) {
return fmt.Errorf("embedded IdP client %q not found; start the management server once before toggling MFA", clientID)
}
return fmt.Errorf("update MFA chain on embedded IdP client %q: %w", clientID, err)
}
}
return nil
}
func idpClientsMFAStatus(ctx context.Context, idpStorage storage.Storage) (string, error) {
clientIDs := []string{cliClientID, dashboardClientID}
enabledCount := 0
for _, clientID := range clientIDs {
client, err := idpStorage.GetClient(ctx, clientID)
if errors.Is(err, storage.ErrNotFound) {
return "unknown", fmt.Errorf("embedded IdP client %q not found", clientID)
}
if err != nil {
return "unknown", fmt.Errorf("get embedded IdP client %q: %w", clientID, err)
}
if hasAuthenticator(client.MFAChain, defaultTOTPAuthenticatorID) {
enabledCount++
}
}
switch enabledCount {
case 0:
return "disabled", nil
case len(clientIDs):
return "enabled", nil
default:
return "partially enabled", nil
}
}
func hasAuthenticator(chain []string, authenticatorID string) bool {
for _, id := range chain {
if id == authenticatorID {
return true
}
}
return false
}

View File

@@ -0,0 +1,160 @@
package admincmd
import (
"bytes"
"context"
"io"
"log/slog"
"strings"
"testing"
"time"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
nbdex "github.com/netbirdio/netbird/idp/dex"
)
func newTestIDPStorage(t *testing.T) storage.Storage {
t.Helper()
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
hash, err := bcrypt.GenerateFromPassword([]byte("OldPass1!"), bcrypt.DefaultCost)
require.NoError(t, err)
require.NoError(t, st.CreatePassword(context.Background(), storage.Password{
Email: "user@example.com",
Username: "User",
UserID: "user-1",
Hash: hash,
}))
require.NoError(t, st.CreateUserIdentity(context.Background(), storage.UserIdentity{
UserID: "user-1",
ConnectorID: localConnectorID,
MFASecrets: map[string]*storage.MFASecret{
defaultTOTPAuthenticatorID: {
AuthenticatorID: defaultTOTPAuthenticatorID,
Type: "TOTP",
Secret: "otpauth://totp/NetBird:user@example.com?secret=ABC",
Confirmed: true,
CreatedAt: time.Now(),
},
},
WebAuthnCredentials: map[string][]storage.WebAuthnCredential{
"webauthn": {{CredentialID: []byte("credential")}},
},
}))
require.NoError(t, st.CreateAuthSession(context.Background(), storage.AuthSession{
UserID: "user-1",
ConnectorID: localConnectorID,
Nonce: "nonce",
}))
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: cliClientID, Name: "CLI"}))
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: dashboardClientID, Name: "Dashboard"}))
return st
}
func TestRunChangePassword(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
var out bytes.Buffer
err := runChangePassword(ctx, st, &out, userSelector{email: "user@example.com"}, "NewPass1!")
require.NoError(t, err)
require.Contains(t, out.String(), "Password updated")
user, err := st.GetPassword(ctx, "user@example.com")
require.NoError(t, err)
require.NoError(t, bcrypt.CompareHashAndPassword(user.Hash, []byte("NewPass1!")))
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestRunChangePasswordValidatesPassword(t *testing.T) {
st := newTestIDPStorage(t)
err := runChangePassword(context.Background(), st, io.Discard, userSelector{email: "user@example.com"}, "short")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid password")
}
func TestRunResetMFA(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
var out bytes.Buffer
encodedUserID := nbdex.EncodeDexUserID("user-1", localConnectorID)
err := runResetMFA(ctx, st, &out, userSelector{userID: encodedUserID})
require.NoError(t, err)
require.Contains(t, out.String(), "MFA reset")
identity, err := st.GetUserIdentity(ctx, "user-1", localConnectorID)
require.NoError(t, err)
require.Empty(t, identity.MFASecrets)
require.Empty(t, identity.WebAuthnCredentials)
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestRunResetMFAWithoutEnrollment(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
require.NoError(t, st.UpdateUserIdentity(ctx, "user-1", localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
old.MFASecrets = nil
old.WebAuthnCredentials = nil
return old, nil
}))
var out bytes.Buffer
err := runResetMFA(ctx, st, &out, userSelector{email: "user@example.com"})
require.NoError(t, err)
require.Contains(t, out.String(), "No MFA enrollment found")
}
func TestSetIDPClientsMFA(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
require.NoError(t, setIDPClientsMFA(ctx, st, true))
status, err := idpClientsMFAStatus(ctx, st)
require.NoError(t, err)
require.Equal(t, "enabled", status)
require.NoError(t, setIDPClientsMFA(ctx, st, false))
status, err = idpClientsMFAStatus(ctx, st)
require.NoError(t, err)
require.Equal(t, "disabled", status)
}
func TestUserSelectorValidate(t *testing.T) {
require.NoError(t, userSelector{email: " user@example.com "}.validate())
require.NoError(t, userSelector{userID: "user-1"}.validate())
require.Error(t, userSelector{}.validate())
require.Error(t, userSelector{email: "user@example.com", userID: "user-1"}.validate())
}
func TestFindLocalUserNotFound(t *testing.T) {
st := newTestIDPStorage(t)
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"})
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "not found"))
}
func TestResolvePasswordInputFromStdin(t *testing.T) {
cmd := &cobra.Command{}
cmd.SetIn(strings.NewReader("NewPass1!\n"))
password, err := resolvePasswordInput(cmd, "", "-")
require.NoError(t, err)
require.Equal(t, "NewPass1!", password)
}
func TestResolvePasswordInputRejectsMultipleSources(t *testing.T) {
_, err := resolvePasswordInput(&cobra.Command{}, "NewPass1!", "-")
require.Error(t, err)
}

View File

@@ -83,7 +83,7 @@ func init() {
rootCmd.AddCommand(migrationCmd)
tc := newTokenCommands()
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
rootCmd.AddCommand(tc)
ac := newAdminCommands()
ac.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
rootCmd.AddCommand(ac)
}

View File

@@ -1,55 +0,0 @@
package cmd
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
var tokenDatadir string
// newTokenCommands creates the token command tree with management-specific store opener.
func newTokenCommands() *cobra.Command {
cmd := tokencmd.NewCommands(withTokenStore)
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
return cmd
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if tokenDatadir != "" {
datadir = tokenDatadir
}
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}

View File

@@ -1205,7 +1205,7 @@ func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*pr
return nil, msg
}
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()), realIP)
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
if err != nil {
return nil, mapError(ctx, err)
}
@@ -1254,10 +1254,7 @@ func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*prot
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
for _, postureCheck := range postureChecks {
check := toProtocolCheck(postureCheck)
if check != nil {
protoChecks = append(protoChecks, check)
}
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
}
return protoChecks
@@ -1281,9 +1278,5 @@ func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
}
}
if len(protoCheck.Files) == 0 {
return nil
}
return protoCheck
}

View File

@@ -1889,12 +1889,12 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
// concurrent stream that started earlier loses the optimistic-lock race
// in MarkPeerConnected and bails without writing.
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP}, accountID)
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
if err := am.MarkPeerConnected(ctx, peerPubKey, accountID, syncTime.UnixNano(), netMap); err != nil {
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
@@ -1914,13 +1914,13 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
return nil
}
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
if err != nil {
return err
}
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP, UpdateAccountPeers: true}, accountID)
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
if err != nil {
return err
}

View File

@@ -62,7 +62,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -123,7 +123,7 @@ type Manager interface {
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)

View File

@@ -1323,17 +1323,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
}
// MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, accountID, sessionStartedAt, nmap)
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, accountID, sessionStartedAt, nmap)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
}
// MarkPeerDisconnected mocks base method.
@@ -1586,17 +1586,17 @@ func (mr *MockManagerMockRecorder) SyncPeer(ctx, sync, accountID interface{}) *g
}
// SyncPeerMeta mocks base method.
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta, realIP net.IP) error {
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta, realIP)
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta)
ret0, _ := ret[0].(error)
return ret0
}
// SyncPeerMeta indicates an expected call of SyncPeerMeta.
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta, realIP interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta, realIP)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta)
}
// SyncUserJWTGroups mocks base method.

View File

@@ -1836,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1907,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1935,7 +1935,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("disconnect peer when session token matches", func(t *testing.T) {
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1956,7 +1956,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
// Newer stream wins on connect (sets SessionStartedAt = now ns).
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1980,7 +1980,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node2SyncTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1990,7 +1990,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
"SessionStartedAt should equal node2SyncTime token")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node1StaleSyncTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -2052,7 +2052,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
defer done.Done()
ready.Done()
start.Wait()
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, token, nil)
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
}()
}
@@ -2093,7 +2093,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}

View File

@@ -39,7 +39,7 @@ type MockAccountManager struct {
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
@@ -114,7 +114,7 @@ type MockAccountManager struct {
GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorFunc func(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
@@ -345,9 +345,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, accountID, sessionStartedAt, nmap)
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
@@ -975,9 +975,9 @@ func (am *MockAccountManager) GroupValidation(ctx context.Context, accountId str
}
// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
if am.SyncPeerMetaFunc != nil {
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta, realIP)
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta)
}
return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented")
}

View File

@@ -74,7 +74,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
//
// Disconnects use MarkPeerDisconnected and require the session to match
// exactly; see PeerStatus.SessionStartedAt for the protocol.
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
start := time.Now()
defer func() {
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
@@ -102,6 +102,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
if am.geo != nil && realIP != nil {
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
}
if err = am.schedulePeerExpirations(ctx, accountID, peer); err != nil {
return err
}
@@ -191,28 +195,24 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
return nil
}
// resolvePeerLocation looks up the geo location for realIP, returning nil when
// there is nothing to apply: geo disabled, no real IP, the IP is unchanged from
// what the peer already has, or the lookup failed. Geo lookups are skipped on
// same-IP reconnects since they are comparatively expensive. The returned value
// is applied by Peer.UpdateMetaIfNew so the change is persisted by its peer save.
func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *nbpeer.Peer, realIP net.IP) *nbpeer.Location {
if am.geo == nil || realIP == nil {
return nil
}
// updatePeerLocationIfChanged refreshes the geolocation on a separate
// row update, only when the connection IP actually changed. Geo lookups
// are expensive so we skip same-IP reconnects.
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
return nil
return
}
location, err := am.geo.Lookup(realIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
return nil
return
}
return &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: location.Country.ISOCode,
CityName: location.City.Names.En,
GeoNameID: location.City.GeonameID,
peer.Location.ConnectionIP = realIP
peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
}
}
@@ -980,8 +980,7 @@ func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
var peer *nbpeer.Peer
var ipv6CapabilityChanged bool
var metaDiff nbpeer.MetaDiff
var updated, versionChanged, ipv6CapabilityChanged bool
var err error
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
@@ -1011,10 +1010,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
newLocation := am.resolvePeerLocation(ctx, peer, sync.RealIP)
metaDiff = peer.UpdateMetaIfNew(ctx, sync.Meta, newLocation)
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
if metaDiff.Updated() {
if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
@@ -1042,10 +1040,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
metaDiffAffectsPosture := posture.AffectsPosture(&metaDiff, resPostureChecks)
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged || metaDiff.Hostname {
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1062,8 +1059,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
// metadata change that flips a posture result removes this peer from others'
// maps asymmetrically; that case (and an invalid peer, whose map is empty) falls
// back to the resolver.
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaChangeAffectedPosture bool) []string {
if peerNotValid || metaChangeAffectedPosture {
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaUpdated, hasPostureChecks bool) []string {
if peerNotValid || (metaUpdated && hasPostureChecks) {
return am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, []string{peerID})
}
return affectedPeerIDsFromNetworkMap(nmap, peerID)
@@ -1173,7 +1170,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
peer.Meta = login.Meta
peer.UpdateMetaIfNew(ctx, login.Meta)
peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
if err != nil {

View File

@@ -256,18 +256,14 @@ func (p *Peer) Copy() *Peer {
}
}
// UpdateMetaIfNew updates peer's system metadata and connection geo location if
// new information is provided. newLocation is the geo location resolved from the
// peer's current connection IP, or nil when there is nothing to apply (geo
// disabled, no real IP, or the IP is unchanged); the caller owns the expensive
// lookup and the same-IP guard. It returns a MetaDiff describing what changed;
// diff.Updated() reports whether the peer needs to be persisted.
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLocation *Location) MetaDiff {
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
if meta.isEmpty() {
return MetaDiff{}
return updated, versionChanged
}
versionChanged := p.Meta.WtVersion != meta.WtVersion
versionChanged = p.Meta.WtVersion != meta.WtVersion
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
@@ -276,177 +272,97 @@ func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLoca
oldVersion := p.Meta.WtVersion
diff := diffMeta(p.Meta, meta)
if diff.Any() {
diff := metaDiff(p.Meta, meta)
if len(diff) != 0 {
p.Meta = meta
}
diff.VersionChanged = versionChanged
locationInfo := ""
if newLocation != nil {
p.Location = *newLocation
diff.LocationChanged = true
locationInfo = fmt.Sprintf("location changed to %s, ", newLocation.ConnectionIP)
updated = true
}
versionInfo := ""
if diff.VersionChanged {
if versionChanged {
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
}
if diff.Any() || diff.VersionChanged || diff.LocationChanged {
if len(diff) > 0 || versionChanged {
log.WithContext(ctx).
Debugf("peer meta updated, %s%s%d field(s) changed: %s", versionInfo, locationInfo, len(diff.Changed), strings.Join(diff.Changed, ", "))
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
}
return diff
}
// MetaDiff records which PeerSystemMeta fields differ between two metas. Each bool
// maps to a single struct field, except Environment, which is split into Cloud and
// Platform. Changed holds the human-readable `field: <old> -> <new>` entries so the
// existing log line and isEqual can be derived from the same comparison.
//
// VersionChanged and LocationChanged sit outside the per-meta-field set:
// VersionChanged tracks the WireGuard client version specifically (compared before
// the UIVersion fixup, to signal client upgrades) and LocationChanged tracks the
// peer's connection geo location, which lives on Peer rather than PeerSystemMeta.
// Neither contributes an entry to Changed, so the field-coverage accounting stays
// driven purely by the PeerSystemMeta comparison.
type MetaDiff struct {
Hostname bool
GoOS bool
Kernel bool
KernelVersion bool
Core bool
Platform bool
OS bool
OSVersion bool
WtVersion bool
UIVersion bool
SystemSerialNumber bool
SystemProductName bool
SystemManufacturer bool
EnvironmentCloud bool
EnvironmentPlatform bool
Flags bool
Capabilities bool
NetworkAddresses bool
Files bool
VersionChanged bool
LocationChanged bool
Changed []string
}
// Any reports whether any PeerSystemMeta field changed.
func (d MetaDiff) Any() bool {
return len(d.Changed) != 0
}
// Updated reports whether the peer needs to be persisted: any meta field changed
// or the geo location changed. The version flag alone does not imply a write,
// since a version change is also reflected in the WtVersion meta field.
func (d MetaDiff) Updated() bool {
return d.Any() || d.LocationChanged || d.VersionChanged
return updated, versionChanged
}
// metaDiff returns a human-readable list of the fields that differ between the
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
// source of truth for meta comparison: isEqual reports equality as an empty
// diff, so the log line can never disagree with the change decision. Slices are
// cloned before sorting, so callers' meta is not mutated.
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
return diffMeta(oldMeta, newMeta).Changed
}
// diffMeta compares two metas field by field, returning both a per-field flag set
// (for callers that need to know exactly what changed, e.g. matching against
// posture checks) and the human-readable Changed list. It is the single source of
// truth for meta comparison: isEqual reports equality as an empty diff, so the log
// line, the change decision, and the flags can never disagree.
func diffMeta(oldMeta, newMeta PeerSystemMeta) MetaDiff {
var d MetaDiff
var diff []string
add := func(field string, oldVal, newVal any) {
d.Changed = append(d.Changed, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
}
if oldMeta.Hostname != newMeta.Hostname {
d.Hostname = true
add("hostname", oldMeta.Hostname, newMeta.Hostname)
}
if oldMeta.GoOS != newMeta.GoOS {
d.GoOS = true
add("goos", oldMeta.GoOS, newMeta.GoOS)
}
if oldMeta.Kernel != newMeta.Kernel {
d.Kernel = true
add("kernel", oldMeta.Kernel, newMeta.Kernel)
}
if oldMeta.KernelVersion != newMeta.KernelVersion {
d.KernelVersion = true
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
}
if oldMeta.Core != newMeta.Core {
d.Core = true
add("core", oldMeta.Core, newMeta.Core)
}
if oldMeta.Platform != newMeta.Platform {
d.Platform = true
add("platform", oldMeta.Platform, newMeta.Platform)
}
if oldMeta.OS != newMeta.OS {
d.OS = true
add("os", oldMeta.OS, newMeta.OS)
}
if oldMeta.OSVersion != newMeta.OSVersion {
d.OSVersion = true
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
}
if oldMeta.WtVersion != newMeta.WtVersion {
d.WtVersion = true
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
}
if oldMeta.UIVersion != newMeta.UIVersion {
d.UIVersion = true
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
}
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
d.SystemSerialNumber = true
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
}
if oldMeta.SystemProductName != newMeta.SystemProductName {
d.SystemProductName = true
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
}
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
d.SystemManufacturer = true
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
}
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
d.EnvironmentCloud = true
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
}
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
d.EnvironmentPlatform = true
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
}
if !oldMeta.Flags.isEqual(newMeta.Flags) {
d.Flags = true
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
}
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
d.Capabilities = true
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
}
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
d.NetworkAddresses = true
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
}
if !sameMultiset(oldMeta.Files, newMeta.Files) {
d.Files = true
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
}
return d
return diff
}
// sameMultiset reports whether two slices contain the same elements with the

View File

@@ -7,7 +7,6 @@ import (
"regexp"
"github.com/hashicorp/go-version"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
@@ -52,34 +51,6 @@ type Checks struct {
Checks ChecksDefinition `gorm:"serializer:json"`
}
// AffectsPosture reports whether the peer metadata changes described by diff can
// alter the outcome of any of the given posture checks. It maps each check kind to
// the metadata fields it inspects, so an unrelated change (e.g. a hostname update)
// does not force a posture re-evaluation.
func AffectsPosture(diff *nbpeer.MetaDiff, checks []*Checks) bool {
if diff == nil {
return false
}
for _, c := range checks {
if c.Checks.ProcessCheck != nil && diff.Files {
return true
}
if c.Checks.OSVersionCheck != nil && (diff.OSVersion || diff.OS || diff.KernelVersion) {
return true
}
if c.Checks.NBVersionCheck != nil && diff.WtVersion {
return true
}
if c.Checks.GeoLocationCheck != nil && diff.LocationChanged {
return true
}
if c.Checks.PeerNetworkRangeCheck != nil && diff.NetworkAddresses {
return true
}
}
return false
}
// ChecksDefinition contains definition of actual check
type ChecksDefinition struct {
NBVersionCheck *NBVersionCheck `json:",omitempty"`

View File

@@ -581,6 +581,28 @@ func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accoun
return result.RowsAffected > 0, nil
}
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization,
// updating the struct ensures the correct data format is inserted into the database.
peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}).
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
Updates(peerCopy)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
}
return nil
}
// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
result := s.db.Model(&nbpeer.Peer{}).

View File

@@ -618,6 +618,56 @@ func TestSqlStore_SavePeerStatus(t *testing.T) {
assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal")
}
func TestSqlStore_SavePeerLocation(t *testing.T) {
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
peer := &nbpeer.Peer{
AccountID: account.Id,
ID: "testpeer",
Location: nbpeer.Location{
ConnectionIP: net.ParseIP("0.0.0.0"),
CountryCode: "YY",
CityName: "City",
GeoNameID: 1,
},
CreatedAt: time.Now().UTC(),
Meta: nbpeer.PeerSystemMeta{},
}
// error is expected as peer is not in store yet
err = store.SavePeerLocation(context.Background(), account.Id, peer)
assert.Error(t, err)
account.Peers[peer.ID] = peer
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
peer.Location.CountryCode = "DE"
peer.Location.CityName = "Berlin"
peer.Location.GeoNameID = 2950159
err = store.SavePeerLocation(context.Background(), account.Id, account.Peers[peer.ID])
assert.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
require.NoError(t, err)
actual := account.Peers[peer.ID].Location
assert.Equal(t, peer.Location, actual)
peer.ID = "non-existing-peer"
err = store.SavePeerLocation(context.Background(), account.Id, peer)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}
func Test_TestGetAccountByPrivateDomain(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")

View File

@@ -185,6 +185,7 @@ type Store interface {
// recorded by the database. Returns true when the update happened,
// false when a newer session has taken over.
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
DeletePeer(ctx context.Context, accountID string, peerID string) error

View File

@@ -2968,6 +2968,20 @@ func (mr *MockStoreMockRecorder) SavePeer(ctx, accountID, peer interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeer", reflect.TypeOf((*MockStore)(nil).SavePeer), ctx, accountID, peer)
}
// SavePeerLocation mocks base method.
func (m *MockStore) SavePeerLocation(ctx context.Context, accountID string, peer *peer.Peer) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SavePeerLocation", ctx, accountID, peer)
ret0, _ := ret[0].(error)
return ret0
}
// SavePeerLocation indicates an expected call of SavePeerLocation.
func (mr *MockStoreMockRecorder) SavePeerLocation(ctx, accountID, peer interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerLocation", reflect.TypeOf((*MockStore)(nil).SavePeerLocation), ctx, accountID, peer)
}
// SavePeerStatus mocks base method.
func (m *MockStore) SavePeerStatus(ctx context.Context, accountID, peerID string, status peer.PeerStatus) error {
m.ctrl.T.Helper()

View File

@@ -12,9 +12,6 @@ type PeerSync struct {
WireGuardPubKey string
// Meta is the system information passed by peer, must be always present
Meta nbpeer.PeerSystemMeta
// RealIP is the peer's connection IP, used to refresh its geo location.
// May be nil when the request has no associated connection IP.
RealIP net.IP
// UpdateAccountPeers indicate updating account peers,
// which occurs when the peer's metadata is updated
UpdateAccountPeers bool

View File

@@ -1847,12 +1847,17 @@ func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID
const minPasswordLength = 8
// validatePassword checks password strength requirements:
// validatePassword checks password strength requirements.
func validatePassword(password string) error {
return ValidatePassword(password)
}
// ValidatePassword checks password strength requirements:
// - Minimum 8 characters
// - At least 1 digit
// - At least 1 uppercase letter
// - At least 1 special character
func validatePassword(password string) error {
func ValidatePassword(password string) error {
if len(password) < minPasswordLength {
return errors.New("password must be at least 8 characters long")
}

View File

@@ -2765,28 +2765,6 @@ components:
type: integer
description: "Number of packets transmitted."
example: 5
num_of_starts:
type: integer
description: "Number of start events."
example: 3
num_of_ends:
type: integer
description: "Number of end events."
example: 4
num_of_drops:
type: integer
description: "Number of drop events."
example: 5
window_start:
type: string
format: date-time
description: Timestamp of the start of the aggregation window.
example: 2025-03-20T16:23:58.125397Z
window_end:
type: string
format: date-time
description: Timestamp of the end of the aggregation window.
example: 2025-03-20T16:23:58.125397Z
events:
type: array
description: "List of events that are correlated to this flow (e.g., start, end)."
@@ -2808,11 +2786,6 @@ components:
- rx_packets
- tx_bytes
- tx_packets
- num_of_starts
- num_of_ends
- num_of_drops
- window_start
- window_end
- events
NetworkTrafficEventsResponse:
type: object

View File

@@ -2905,18 +2905,9 @@ type NetworkTrafficEvent struct {
Events []NetworkTrafficSubEvent `json:"events"`
// FlowId FlowID is the ID of the connection flow. Not unique because it can be the same for multiple events (e.g., start and end of the connection).
FlowId string `json:"flow_id"`
Icmp NetworkTrafficICMP `json:"icmp"`
// NumOfDrops Number of drop events.
NumOfDrops int `json:"num_of_drops"`
// NumOfEnds Number of end events.
NumOfEnds int `json:"num_of_ends"`
// NumOfStarts Number of start events.
NumOfStarts int `json:"num_of_starts"`
Policy NetworkTrafficPolicy `json:"policy"`
FlowId string `json:"flow_id"`
Icmp NetworkTrafficICMP `json:"icmp"`
Policy NetworkTrafficPolicy `json:"policy"`
// Protocol Protocol is the protocol of the traffic (e.g. 1 = ICMP, 6 = TCP, 17 = UDP, etc.).
Protocol int `json:"protocol"`
@@ -2937,12 +2928,6 @@ type NetworkTrafficEvent struct {
// TxPackets Number of packets transmitted.
TxPackets int `json:"tx_packets"`
User NetworkTrafficUser `json:"user"`
// WindowEnd Timestamp of the end of the aggregation window.
WindowEnd time.Time `json:"window_end"`
// WindowStart Timestamp of the start of the aggregation window.
WindowStart time.Time `json:"window_start"`
}
// NetworkTrafficEventsResponse defines model for NetworkTrafficEventsResponse.