Compare commits

..

1 Commits

Author SHA1 Message Date
mlsmaycon
6a6b25b9af [misc] Enable race detector in Go test runs 2026-06-10 09:27:59 +02:00
26 changed files with 387 additions and 1964 deletions

View File

@@ -158,7 +158,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -race -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
@@ -478,7 +478,7 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=devcert -coverprofile=coverage.txt \
go test -race -tags=devcert -coverprofile=coverage.txt \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/... ./shared/management/...

View File

@@ -1,301 +0,0 @@
package cmd
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"slices"
"strings"
"github.com/goccy/go-yaml"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
)
const (
KubernetesDNSSuffix = "netbird-kubeapi-proxy"
)
var kubernetesCmd = &cobra.Command{
Use: "kubernetes",
Short: "Kubernetes cluster commands.",
Long: "Kubernetes cluster commands.",
}
var kubernetesListCmd = &cobra.Command{
Use: "list",
RunE: kubernetesList,
Short: "List Kubernetes clusters.",
Long: "List Kubernetes clusters by discovering NetBird peers running netbird-kubeapi-proxy.",
}
var kubernetesWriteKubeconfigCmd = &cobra.Command{
Use: "write-kubeconfig",
RunE: kubernetesWriteKubeconfig,
Args: cobra.ExactArgs(1),
Short: "Write kubeconfig for a Kubernetes cluster.",
Long: "Updates kubeconfig in place to allow token-less access to the Kubernetes cluster through NetBird.",
}
func init() {
kubernetesWriteKubeconfigCmd.Flags().String("kubeconfig", "", "path to kubeconfig file")
}
func kubernetesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return err
}
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, "")
if err != nil {
return err
}
if len(kcs) == 0 {
cmd.Println("No Kubernetes clusters available.")
return nil
}
cmd.Println("Available Kubernetes clusters:")
for _, k := range kcs {
cmd.Printf("\n - Name: %s\n FQDN: %s\n Version: %s\n", k.name, k.url.Host, k.version)
}
return nil
}
func kubernetesWriteKubeconfig(cmd *cobra.Command, args []string) error {
kubeconfigPath, err := resolveKubeconfigPath(cmd)
if err != nil {
return err
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return err
}
clusterName := args[0]
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, clusterName)
if err != nil {
return err
}
if len(kcs) == 0 {
return fmt.Errorf("kubernetes cluster named %s not found", clusterName)
}
if len(kcs) > 1 {
return fmt.Errorf("too many Kubernetes clusters returned")
}
err = writeKubeconfig(kubeconfigPath, kcs[0])
if err != nil {
return err
}
return nil
}
type kubernetesCluster struct {
name string
url *url.URL
version string
}
func getKubernetesClusters(ctx context.Context, peers []*proto.PeerState, nameFilter string) ([]kubernetesCluster, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
httpClient := &http.Client{
Transport: transport,
}
resolver := net.Resolver{
// Required so both DNS records are returned.
// https://github.com/golang/go/issues/17093
PreferGo: true,
}
kcs := []kubernetesCluster{}
attempted := map[string]struct{}{}
for _, peer := range peers {
fqdns, err := resolver.LookupAddr(ctx, peer.IP)
if err != nil {
return nil, err
}
for _, fqdn := range fqdns {
if _, ok := attempted[fqdn]; ok {
continue
}
attempted[fqdn] = struct{}{}
comps := strings.Split(fqdn, ".")
if len(comps) < 2 {
continue
}
if comps[1] != KubernetesDNSSuffix {
continue
}
if nameFilter != "" && nameFilter != comps[0] {
continue
}
clusterURL, clusterVersion, err := fingerprintClusters(ctx, httpClient, fqdn)
if err != nil {
log.Debugf("could not fingerprint Kubernetes cluster %s %q", fqdn, err)
continue
}
kc := kubernetesCluster{
name: comps[0],
url: clusterURL,
version: clusterVersion,
}
if nameFilter != "" {
return []kubernetesCluster{kc}, nil
}
kcs = append(kcs, kc)
}
}
return kcs, nil
}
func fingerprintClusters(ctx context.Context, httpClient *http.Client, fqdn string) (*url.URL, string, error) {
clusterURL, err := url.Parse("https://" + fqdn)
if err != nil {
return nil, "", err
}
versionURL, err := clusterURL.Parse("/version")
if err != nil {
return nil, "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL.String(), nil)
if err != nil {
return nil, "", err
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("expected %d response but got %s", http.StatusOK, resp.Status)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, "", err
}
versionData := map[string]string{}
err = json.Unmarshal(b, &versionData)
if err != nil {
return nil, "", err
}
version, ok := versionData["gitVersion"]
if !ok {
return nil, "", errors.New("no version found in response")
}
return clusterURL, version, nil
}
func resolveKubeconfigPath(cmd *cobra.Command) (string, error) {
if cmd.Flags().Changed("kubeconfig") {
path, err := cmd.Flags().GetString("kubeconfig")
if err != nil {
return "", err
}
return path, nil
}
if env := os.Getenv("KUBECONFIG"); env != "" {
return env, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("could not determine home directory: %w", err)
}
return filepath.Join(home, ".kube", "config"), nil
}
func writeKubeconfig(kubeconfigPath string, kc kubernetesCluster) error {
b, err := os.ReadFile(kubeconfigPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
var cfg map[string]any
if err := yaml.Unmarshal(b, &cfg); err != nil {
return err
}
if cfg == nil {
cfg = map[string]any{
"apiVersion": "v1",
"kind": "Config",
}
}
cfg["clusters"] = appendWithName(cfg["clusters"], map[string]any{
"name": kc.name,
"cluster": map[string]any{
"server": kc.url.String(),
"insecure-skip-tls-verify": true,
},
})
cfg["users"] = appendWithName(cfg["users"], map[string]any{
"name": "netbird",
"user": map[string]any{
"token": "none",
},
})
cfg["contexts"] = appendWithName(cfg["contexts"], map[string]any{
"name": kc.name,
"context": map[string]any{
"cluster": kc.name,
"user": "netbird",
"namespace": "default",
},
})
cfg["current-context"] = kc.name
out, err := yaml.Marshal(cfg)
if err != nil {
return err
}
if err := os.WriteFile(kubeconfigPath, out, 0o600); err != nil {
return err
}
return nil
}
func appendWithName(data any, add map[string]any) any {
if data == nil {
return []any{add}
}
v, ok := data.([]any)
if !ok {
return []any{add}
}
i := slices.IndexFunc(v, func(item any) bool {
m, ok := item.(map[string]any)
if !ok {
return false
}
return m["name"] == add["name"]
})
if i == -1 {
return append(v, add)
}
v[i] = add
return v
}

View File

@@ -1,120 +0,0 @@
package cmd
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
)
func TestFingerprintClusters(t *testing.T) {
t.Parallel()
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//nolint: errcheck
w.Write([]byte(`{"gitVersion": "foobar"}`))
}))
defer srv.Close()
clusterURL, clusterVersion, err := fingerprintClusters(t.Context(), srv.Client(), srv.Listener.Addr().String())
require.NoError(t, err)
require.Equal(t, srv.URL, clusterURL.String())
require.Equal(t, "foobar", clusterVersion)
}
func TestResolveKubeconfigPath(t *testing.T) {
home, err := os.UserHomeDir()
if err != nil {
t.Fatalf("could not determine home directory: %v", err)
}
defaultPath := filepath.Join(home, ".kube", "config")
path, err := resolveKubeconfigPath(&cobra.Command{})
require.NoError(t, err)
require.Equal(t, defaultPath, path)
flagPath := "flag-path"
cmd := &cobra.Command{}
cmd.Flags().String("kubeconfig", "", "")
err = cmd.Flags().Set("kubeconfig", flagPath)
require.NoError(t, err)
path, err = resolveKubeconfigPath(cmd)
require.NoError(t, err)
require.Equal(t, flagPath, path)
envPath := "env-path"
t.Setenv("KUBECONFIG", envPath)
path, err = resolveKubeconfigPath(&cobra.Command{})
require.NoError(t, err)
require.Equal(t, envPath, path)
}
func TestWriteKubeconfig(t *testing.T) {
t.Parallel()
tests := []struct {
name string
existing string
}{
{
name: "empty file",
},
{
name: "existing content",
existing: `apiVersion: v1
clusters:
- cluster:
insecure-skip-tls-verify: true
server: https://foobar.com
name: foo
current-context: test
kind: Config
users: []
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
kubeconfigPath := filepath.Join(t.TempDir(), "config")
err := os.WriteFile(kubeconfigPath, []byte(tt.existing), 0o644)
require.NoError(t, err)
kc := kubernetesCluster{
name: "foo",
url: &url.URL{Scheme: "https", Host: "example.com"},
}
err = writeKubeconfig(kubeconfigPath, kc)
require.NoError(t, err)
b, err := os.ReadFile(kubeconfigPath)
require.NoError(t, err)
expected := `apiVersion: v1
clusters:
- cluster:
insecure-skip-tls-verify: true
server: https://example.com
name: foo
contexts:
- context:
cluster: foo
namespace: default
user: netbird
name: foo
current-context: foo
kind: Config
users:
- name: netbird
user:
token: none
`
require.Equal(t, expected, string(b))
})
}
}

View File

@@ -169,11 +169,6 @@ func init() {
debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
// kubernetes commands
rootCmd.AddCommand(kubernetesCmd)
kubernetesCmd.AddCommand(kubernetesListCmd)
kubernetesCmd.AddCommand(kubernetesWriteKubeconfigCmd)
// profile commands
profileCmd.AddCommand(profileListCmd)
profileCmd.AddCommand(profileAddCmd)

View File

@@ -279,10 +279,6 @@ func (c *Client) Start(startCtx context.Context) error {
select {
case <-startCtx.Done():
// Cancel the client context before stopping: Engine.Start blocks on the
// signal stream while holding the engine mutex and only unblocks on
// cancellation. Stopping first would deadlock on that mutex.
cancel()
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}

View File

@@ -1,168 +0,0 @@
package embed
import (
"context"
"net"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
const testSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
// TestClientStartTimeoutRollback reproduces a deadlock between Engine.Start and
// Engine.Stop. The signal endpoint accepts gRPC connections but never serves the
// SignalExchange service, so Engine.Start parks in WaitStreamConnected while
// holding the engine mutex. When the Start context expires, the rollback path
// calls ConnectClient.Stop, which must not block forever acquiring that mutex.
func TestClientStartTimeoutRollback(t *testing.T) {
signalAddr := startBlackholeSignal(t)
mgmAddr := startManagement(t, signalAddr)
wgPort := 0
client, err := New(Options{
DeviceName: "embed-rollback-test",
SetupKey: testSetupKey,
ManagementURL: "http://" + mgmAddr,
WireguardPort: &wgPort,
})
require.NoError(t, err, "embed client creation must succeed")
startCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
startErr := make(chan error, 1)
go func() {
startErr <- client.Start(startCtx)
}()
select {
case err := <-startErr:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(60 * time.Second):
t.Fatal("client.Start did not return after its context expired: Engine.Stop deadlocked against Engine.Start waiting for the signal stream")
}
}
// startBlackholeSignal starts a gRPC server without the SignalExchange service
// registered. Connections succeed, but the signal stream can never be
// established, which keeps Engine.Start parked in WaitStreamConnected.
func startBlackholeSignal(t *testing.T) string {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
s := grpc.NewServer()
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
}
}()
t.Cleanup(s.Stop)
return lis.Addr().String()
}
func startManagement(t *testing.T, signalAddr string) string {
t.Helper()
cfg := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: t.TempDir(),
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
s := grpc.NewServer()
testStore, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", cfg.Datadir)
require.NoError(t, err)
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
permissionsManager := permissions.NewManager(testStore)
peersManager := peers.NewManager(testStore, permissionsManager)
jobManager := job.NewJobManager(nil, testStore, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
iv, err := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
require.NoError(t, err)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(context.Background(), testStore)
networkMapController := controller.NewController(context.Background(), testStore, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(testStore, peersManager), cfg)
accountManager, err := mgmt.BuildManager(context.Background(), cfg, testStore, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
require.NoError(t, err)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, cfg.TURNConfig, cfg.Relay, settingsMockManager, groupsManager)
require.NoError(t, err)
mgmtServer, err := nbgrpc.NewServer(cfg, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
require.NoError(t, err)
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
}
}()
t.Cleanup(s.Stop)
return lis.Addr().String()
}

View File

@@ -880,25 +880,62 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
return err
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
if err != nil {
return fmt.Errorf("update TURNs: %w", err)
}
// Posture checks are bound to the network map presence:
// NetworkMap != nil, checks present -> apply the received checks
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
// leave the previously applied checks untouched
nm := update.GetNetworkMap()
if nm == nil {
return nil
err = e.updateSTUNs(wCfg.GetStuns())
if err != nil {
return fmt.Errorf("update STUNs: %w", err)
}
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn)
err = e.handleRelayUpdate(wCfg.GetRelay())
if err != nil {
return err
}
err = e.handleFlowUpdate(wCfg.GetFlow())
if err != nil {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
e.persistSyncResponse(update)
nm := update.GetNetworkMap()
if nm == nil {
return nil
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// A non-nil syncStore is what marks persistence as enabled. Hold the lock for
// the whole Set so the store cannot be cleared (disabled / engine close)
// mid-call and have this write resurrect a file that was just removed.
e.syncRespMux.RLock()
if e.syncStore != nil {
if err := e.syncStore.Set(update); err != nil {
log.Errorf("failed to persist sync response: %v", err)
} else {
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
}
}
e.syncRespMux.RUnlock()
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
@@ -910,64 +947,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// updateNetbirdConfig applies the management-provided NetBird configuration:
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
// which is the case for sync updates carrying only a network map.
func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
if wCfg == nil {
return nil
}
if err := e.updateTURNs(wCfg.GetTurns()); err != nil {
return fmt.Errorf("update TURNs: %w", err)
}
if err := e.updateSTUNs(wCfg.GetStuns()); err != nil {
return fmt.Errorf("update STUNs: %w", err)
}
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn)
if err := e.handleRelayUpdate(wCfg.GetRelay()); err != nil {
return err
}
if err := e.handleFlowUpdate(wCfg.GetFlow()); err != nil {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal
return nil
}
// persistSyncResponse stores the full sync response so it can be restored on the next
// startup. Persistence is enabled only when syncStore is set. The dedicated syncRespMux
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
// engine close) mid-call and have this write resurrect a file that was just removed.
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
e.syncRespMux.RLock()
defer e.syncRespMux.RUnlock()
if e.syncStore == nil {
return
}
if err := e.syncStore.Set(update); err != nil {
log.Errorf("failed to persist sync response: %v", err)
return
}
log.Debugf("sync response persisted with serial %d", update.GetNetworkMap().GetSerial())
}
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
if update != nil {
// when we receive token we expect valid address list too

View File

@@ -27,7 +27,7 @@ type Logger struct {
wgIfaceNetV6 netip.Prefix
dnsCollection atomic.Bool
exitNodeCollection atomic.Bool
Store types.AggregatingStore
Store types.Store
}
func New(statusRecorder *peer.Status, wgIfaceIPNet, wgIfaceIPNetV6 netip.Prefix) *Logger {
@@ -35,7 +35,7 @@ func New(statusRecorder *peer.Status, wgIfaceIPNet, wgIfaceIPNetV6 netip.Prefix)
statusRecorder: statusRecorder,
wgIfaceNet: wgIfaceIPNet,
wgIfaceNetV6: wgIfaceIPNetV6,
Store: store.NewAggregatingMemoryStore(),
Store: store.NewMemoryStore(),
}
}
@@ -125,10 +125,6 @@ func (l *Logger) stop() {
l.mux.Unlock()
}
func (l *Logger) ResetAggregationWindow() types.FlowEventAggregator {
return l.Store.ResetAggregationWindow()
}
func (l *Logger) GetEvents() []*types.Event {
return l.Store.GetEvents()
}

View File

@@ -9,14 +9,12 @@ import (
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/netflow/conntrack"
"github.com/netbirdio/netbird/client/internal/netflow/logger"
"github.com/netbirdio/netbird/client/internal/netflow/store"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/flow/client"
@@ -25,16 +23,14 @@ import (
// Manager handles netflow tracking and logging
type Manager struct {
mux sync.Mutex
shutdownWg sync.WaitGroup
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
receiverClient *client.GRPCClient
eventsWithoutAcks nftypes.Store
publicKey []byte
cancel context.CancelFunc
retryInterval time.Duration
mux sync.Mutex
shutdownWg sync.WaitGroup
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
receiverClient *client.GRPCClient
publicKey []byte
cancel context.CancelFunc
}
// NewManager creates a new netflow manager
@@ -52,11 +48,9 @@ func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *pee
}
return &Manager{
logger: flowLogger,
conntrack: ct,
publicKey: publicKey,
retryInterval: time.Second,
eventsWithoutAcks: store.NewMemoryStore(),
logger: flowLogger,
conntrack: ct,
publicKey: publicKey,
}
}
@@ -113,7 +107,7 @@ func (m *Manager) resetClient() error {
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
m.shutdownWg.Add(3)
m.shutdownWg.Add(2)
go func() {
defer m.shutdownWg.Done()
m.receiveACKs(ctx, flowClient)
@@ -122,10 +116,6 @@ func (m *Manager) resetClient() error {
defer m.shutdownWg.Done()
m.startSender(ctx)
}()
go func() {
defer m.shutdownWg.Done()
m.startRetries(ctx)
}()
return nil
}
@@ -217,15 +207,13 @@ func (m *Manager) startSender(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
collectedEvents := m.logger.ResetAggregationWindow()
events := collectedEvents.GetAggregatedEvents()
events := m.logger.GetEvents()
for _, event := range events {
if err := m.send(event); err != nil {
log.Errorf("failed to send flow event to server: %v", err)
} else {
log.Tracef("sent flow event: %s", event.ID)
continue
}
m.eventsWithoutAcks.StoreEvent(event)
log.Tracef("sent flow event: %s", event.ID)
}
}
}
@@ -239,7 +227,7 @@ func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
return nil
}
log.Tracef("received flow event ack: %s", id)
m.eventsWithoutAcks.DeleteEvents([]uuid.UUID{id})
m.logger.DeleteEvents([]uuid.UUID{id})
return nil
})
@@ -248,41 +236,6 @@ func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
}
}
func (m *Manager) startRetries(ctx context.Context) {
ticker := time.NewTimer(m.retryInterval)
retryBackoff := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 1 * time.Second,
RandomizationFactor: 0.5,
Multiplier: 1.7,
MaxInterval: m.flowConfig.Interval / 2,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
for _, e := range m.eventsWithoutAcks.GetEvents() {
if e.Timestamp.Add(time.Second).After(time.Now()) {
// grace period on retries to avoid early retries
// do not retry if the event is less than 1 sec old
continue
}
if err := m.send(e); err != nil {
ticker = time.NewTimer(retryBackoff.NextBackOff()) //nolint:staticcheck,wastedassign
break
}
}
retryBackoff.Reset()
ticker = time.NewTimer(time.Second)
}
}
}
func (m *Manager) send(event *nftypes.Event) error {
m.mux.Lock()
client := m.receiverClient

View File

@@ -1,291 +0,0 @@
package netflow
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"testing"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/flow/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
type testServer struct {
proto.UnimplementedFlowServiceServer
events chan *proto.FlowEvent
acks chan *proto.FlowEventAck
grpcSrv *grpc.Server
addr string
handlerDone chan struct{} // signaled each time Events() exits
handlerStarted chan struct{} // signaled each time Events() begins
}
func newTestServer(t *testing.T) *testServer {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
s := &testServer{
events: make(chan *proto.FlowEvent, 100),
acks: make(chan *proto.FlowEventAck, 100),
grpcSrv: grpc.NewServer(),
addr: listener.Addr().String(),
handlerDone: make(chan struct{}, 10),
handlerStarted: make(chan struct{}, 10),
}
proto.RegisterFlowServiceServer(s.grpcSrv, s)
go func() {
if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
t.Logf("server error: %v", err)
}
}()
t.Cleanup(func() {
s.grpcSrv.Stop()
})
return s
}
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
defer func() {
select {
case s.handlerDone <- struct{}{}:
default:
}
}()
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
if err != nil {
return err
}
select {
case s.handlerStarted <- struct{}{}:
default:
}
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()
go func() {
defer cancel()
for {
event, err := stream.Recv()
if err != nil {
return
}
if !event.IsInitiator {
select {
case s.events <- event:
case <-ctx.Done():
return
}
}
}
}()
for {
select {
case ack := <-s.acks:
if err := stream.Send(ack); err != nil {
return err
}
case <-ctx.Done():
return ctx.Err()
}
}
}
func TestSendEventReceiveAck(t *testing.T) {
_, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
server := newTestServer(t)
manager := createManager(t, server.addr, 60*time.Second) // set high to prevent retries in this test
defer manager.Close()
assert.Eventually(t, func() bool {
select {
case <-server.handlerStarted:
return true
default:
return false
}
}, 3*time.Second, 100*time.Millisecond)
event1 := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
DestIP: ipAddr("172.16.1.2"),
DestPort: 2345,
Protocol: 6,
}
manager.logger.StoreEvent(event1)
event2 := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
DestIP: ipAddr("172.16.1.1"),
DestPort: 1234,
Protocol: 6,
}
manager.logger.StoreEvent(event2)
// verify the server received logged events
serverSideEvents := make([]*proto.FlowEvent, 0)
assert.Eventually(t, func() bool {
select {
case event := <-server.events:
serverSideEvents = append(serverSideEvents, event)
if len(serverSideEvents) == 2 {
return true
}
default:
if len(serverSideEvents) == 2 {
return true
}
}
return false
}, 5*time.Second, 100*time.Millisecond)
serverSideFlowIds := make([]uuid.UUID, 0, 2)
slices.Values(serverSideEvents)(func(e *proto.FlowEvent) bool {
id, err := uuid.FromBytes(e.FlowFields.FlowId)
assert.NoError(t, err)
serverSideFlowIds = append(serverSideFlowIds, id)
return true
})
assert.ElementsMatch(t, []uuid.UUID{event1.FlowID, event2.FlowID}, serverSideFlowIds)
// verify the manager tracks un-acked events
unackedEvents := manager.eventsWithoutAcks.GetEvents()
assert.Len(t, unackedEvents, 2)
flowIds := make([]uuid.UUID, 0)
slices.Values(unackedEvents)(func(e *types.Event) bool {
flowIds = append(flowIds, e.FlowID)
return true
})
assert.ElementsMatch(t, flowIds, []uuid.UUID{event1.FlowID, event2.FlowID})
}
// verify handling of retries:
// - unacked events are retried
// - when acks arrive, events are removed from the un-acked event tracker
func TestRetryEvents(t *testing.T) {
_, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
server := newTestServer(t)
manager := createManager(t, server.addr, time.Second) // set low to start retries sooner
defer manager.Close()
assert.Eventually(t, func() bool {
select {
case <-server.handlerStarted:
return true
default:
return false
}
}, 3*time.Second, 100*time.Millisecond)
event1 := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
DestIP: ipAddr("172.16.1.2"),
DestPort: 2345,
Protocol: 6,
}
manager.logger.StoreEvent(event1)
event2 := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
DestIP: ipAddr("172.16.1.1"),
DestPort: 1234,
Protocol: 6,
}
manager.logger.StoreEvent(event2)
// verify the server received retries of logged events
serverSideEvents := make([]*proto.FlowEvent, 0)
func() {
c := time.After(2500 * time.Millisecond)
for {
select {
case event := <-server.events:
serverSideEvents = append(serverSideEvents, event)
case <-c:
return
}
}
}()
assert.True(t, len(serverSideEvents) > 2) // must see retries
uniqueServerSideEvents := make(map[uuid.UUID]*proto.FlowEvent)
slices.Values(serverSideEvents)(func(e *proto.FlowEvent) bool {
id, err := uuid.FromBytes(e.FlowFields.FlowId)
assert.NoError(t, err)
uniqueServerSideEvents[id] = e
return true
})
assert.Contains(t, uniqueServerSideEvents, event1.FlowID)
assert.Contains(t, uniqueServerSideEvents, event2.FlowID)
// ack events
server.acks <- &proto.FlowEventAck{EventId: uniqueServerSideEvents[event1.FlowID].EventId}
server.acks <- &proto.FlowEventAck{EventId: uniqueServerSideEvents[event2.FlowID].EventId}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
unackedEvents := manager.eventsWithoutAcks.GetEvents()
assert.Empty(c, unackedEvents)
}, 3*time.Second, 100*time.Millisecond)
}
func createManager(t *testing.T, serverAddr string, retryInterval time.Duration) *Manager {
t.Helper()
mockIFace := &mockIFaceMapper{
address: wgaddr.Address{
Network: netip.MustParsePrefix("192.168.1.1/32"),
},
isUserspaceBind: true,
}
publicKey := []byte("test-public-key")
manager := NewManager(mockIFace, publicKey, nil)
manager.retryInterval = retryInterval
initialConfig := &types.FlowConfig{
Enabled: true,
URL: fmt.Sprintf("http://%s", serverAddr),
TokenPayload: "initial-payload",
TokenSignature: "initial-signature",
Interval: 500 * time.Millisecond,
}
err := manager.Update(initialConfig)
require.NoError(t, err)
return manager
}
func ipAddr(a string) netip.Addr {
addr, _ := netip.ParseAddr(a)
return addr
}

View File

@@ -1,190 +0,0 @@
package store
import (
"math/rand"
"net/netip"
"testing"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/stretchr/testify/assert"
)
var random = rand.New(rand.NewSource(time.Now().UnixNano()))
func TestFlowAggregation(t *testing.T) {
var protocols = []types.Protocol{types.ICMP, types.ICMPv6, types.TCP, types.UDP}
var tests = []struct {
description string
eventTypes []types.Type
}{
{
description: "start and stop",
eventTypes: []types.Type{types.TypeStart, types.TypeEnd},
},
{
description: "start and drop",
eventTypes: []types.Type{types.TypeStart, types.TypeDrop},
},
{
description: "start only",
eventTypes: []types.Type{types.TypeStart},
},
{
description: "drop only",
eventTypes: []types.Type{types.TypeDrop},
}}
for _, protocol := range protocols {
for _, tt := range tests {
t.Run(tt.description+" "+protocol.String(), func(t *testing.T) {
store := NewAggregatingMemoryStore()
allExpected := make([]*types.Event, 0)
for i := 0; i < 2; i++ {
inEvents, expected := generateEvents(tt.eventTypes, protocol, types.Ingress, 0)
for _, e := range inEvents {
store.StoreEvent(e)
}
allExpected = append(allExpected, expected)
}
events := store.GetAggregatedEvents()
assert.ElementsMatch(t, events, allExpected)
})
}
}
}
func TestIcmpEventAggregation(t *testing.T) {
var protocols = []types.Protocol{types.ICMP, types.ICMPv6}
var icmpTypes = []uint8{1, 2, 3}
var tests = []struct {
description string
eventTypes []types.Type
}{
{
description: "start and stop",
eventTypes: []types.Type{types.TypeStart, types.TypeEnd},
},
{
description: "start and drop",
eventTypes: []types.Type{types.TypeStart, types.TypeDrop},
},
{
description: "start only",
eventTypes: []types.Type{types.TypeStart},
},
{
description: "drop only",
eventTypes: []types.Type{types.TypeDrop},
}}
for _, protocol := range protocols {
for _, tt := range tests {
t.Run(tt.description+" "+protocol.String(), func(t *testing.T) {
store := NewAggregatingMemoryStore()
allExpected := make([]*types.Event, 0)
for _, icmpType := range icmpTypes {
events, expected := generateEvents(tt.eventTypes, protocol, types.Ingress, icmpType)
for _, e := range events {
store.StoreEvent(e)
}
allExpected = append(allExpected, expected)
}
aggregatedEvents := store.GetAggregatedEvents()
assert.Len(t, aggregatedEvents, len(allExpected))
assert.ElementsMatch(t, aggregatedEvents, allExpected)
})
}
}
}
func ipAddr(a string) netip.Addr {
addr, _ := netip.ParseAddr(a)
return addr
}
func generateEvents(eventTypes []types.Type, protocol types.Protocol, direction types.Direction, icmpType uint8) ([]*types.Event, *types.Event) {
var rxPackets, txPackets, rxBytes, txBytes uint64
inEvents := make([]*types.Event, 0)
ts := time.Now()
flowId := uuid.New()
srcIp := ipAddr("1.1.1.1")
srcPort := uint16(random.Uint32() >> 16)
dstIp := ipAddr("2.2.2.2")
dstPort := uint16(random.Uint32() >> 16)
for idx, eventType := range eventTypes {
e := &types.Event{
ID: uuid.New(),
Timestamp: ts.Add(time.Duration(idx) * time.Second),
EventFields: types.EventFields{
FlowID: flowId,
Type: eventType,
Protocol: protocol,
RuleID: []byte("rule-id-1"),
Direction: direction,
SourceIP: srcIp,
SourcePort: srcPort,
DestIP: dstIp,
DestPort: dstPort,
SourceResourceID: []byte("source-resource-id"),
DestResourceID: []byte("dest-resource-id"),
RxPackets: random.Uint64(),
TxPackets: random.Uint64(),
RxBytes: random.Uint64(),
TxBytes: random.Uint64(),
}}
rxBytes += e.RxBytes
txBytes += e.TxBytes
rxPackets += e.RxPackets
txPackets += e.TxPackets
inEvents = append(inEvents, e)
if protocol == types.ICMP || protocol == types.ICMPv6 {
e.ICMPType = icmpType
}
}
var start, end, drop uint64
for _, eventType := range eventTypes {
switch eventType {
case types.TypeStart:
start += 1
case types.TypeDrop:
drop += 1
case types.TypeEnd:
end += 1
}
}
aggregatedEvent := &types.Event{
ID: inEvents[0].ID,
Timestamp: inEvents[0].Timestamp,
EventFields: types.EventFields{
FlowID: flowId,
Type: inEvents[0].Type,
Protocol: inEvents[0].Protocol,
RuleID: []byte("rule-id-1"),
Direction: inEvents[0].Direction,
SourceIP: srcIp,
SourcePort: srcPort,
DestIP: dstIp,
DestPort: dstPort,
SourceResourceID: []byte("source-resource-id"),
DestResourceID: []byte("dest-resource-id"),
RxPackets: rxPackets,
TxPackets: txPackets,
RxBytes: rxBytes,
TxBytes: txBytes,
NumOfStarts: start,
NumOfEnds: end,
NumOfDrops: drop,
}}
if protocol == types.ICMP || protocol == types.ICMPv6 {
aggregatedEvent.ICMPType = icmpType
}
return inEvents, aggregatedEvent
}

View File

@@ -1,13 +1,10 @@
package store
import (
"maps"
"net/netip"
"slices"
"sync"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -22,10 +19,6 @@ type Memory struct {
events map[uuid.UUID]*types.Event
}
type AggregatingMemory struct {
Memory
}
func (m *Memory) StoreEvent(event *types.Event) {
m.mux.Lock()
defer m.mux.Unlock()
@@ -55,78 +48,3 @@ func (m *Memory) DeleteEvents(ids []uuid.UUID) {
delete(m.events, id)
}
}
func NewAggregatingMemoryStore() *AggregatingMemory {
return &AggregatingMemory{Memory{events: make(map[uuid.UUID]*types.Event)}}
}
func (am *AggregatingMemory) ResetAggregationWindow() types.FlowEventAggregator {
am.mux.Lock()
defer am.mux.Unlock()
toret := AggregatingMemory{Memory: Memory{events: am.events}}
am.events = make(map[uuid.UUID]*types.Event)
return &toret
}
type aggregationKey struct {
destAddr netip.Addr
destPort uint16
protocol uint8
icmpType uint8
unique int64 // used to prevent aggregation on non icmp/udp/tcp events
}
func (am *AggregatingMemory) GetAggregatedEvents() []*types.Event {
aggregated := make(map[aggregationKey]*types.Event)
for _, v := range am.events {
lookupKey := aggregationKey{destAddr: v.DestIP, destPort: v.DestPort, protocol: uint8(v.Protocol), icmpType: v.ICMPType}
if _, ok := aggregated[lookupKey]; !ok {
aggregated[lookupKey] = v.Clone()
event := aggregated[lookupKey]
if event.Protocol != types.ICMP && event.Protocol != types.ICMPv6 && event.Protocol != types.UDP && event.Protocol != types.TCP {
lookupKey.unique = time.Now().UnixNano() // to make the lookup key unique so we don't aggregate on it
continue
}
switch event.Type {
case types.TypeStart:
event.NumOfStarts += 1
case types.TypeDrop:
event.NumOfDrops += 1
case types.TypeEnd:
event.NumOfEnds += 1
}
continue
}
aggregatedEvent := aggregated[lookupKey]
if aggregatedEvent.Protocol != types.ICMP && aggregatedEvent.Protocol != types.ICMPv6 && aggregatedEvent.Protocol != types.UDP && aggregatedEvent.Protocol != types.TCP {
continue // we don't aggregate this type of events; shouldn't ever get here
}
// track the number of connections, duration?, open and close events?
aggregatedEvent.RxBytes += v.RxBytes
aggregatedEvent.RxPackets += v.RxPackets
aggregatedEvent.TxBytes += v.TxBytes
aggregatedEvent.TxPackets += v.TxPackets
switch v.Type {
case types.TypeStart:
aggregatedEvent.NumOfStarts += 1
case types.TypeDrop:
aggregatedEvent.NumOfDrops += 1
case types.TypeEnd:
aggregatedEvent.NumOfEnds += 1
}
if aggregatedEvent.Timestamp.Compare(v.Timestamp) > 0 {
aggregatedEvent.Timestamp = v.Timestamp
aggregatedEvent.ID = v.ID
aggregatedEvent.Type = v.Type
}
// do we aggregate icmp by code?
}
return slices.Collect(maps.Values(aggregated)) // could return an iterator instead here
}

View File

@@ -2,7 +2,6 @@ package types
import (
"net/netip"
"slices"
"strconv"
"time"
@@ -93,17 +92,6 @@ type EventFields struct {
TxPackets uint64
RxBytes uint64
TxBytes uint64
NumOfStarts uint64
NumOfEnds uint64
NumOfDrops uint64
}
func (e *Event) Clone() *Event {
toret := *e
toret.RuleID = slices.Clone(e.RuleID)
toret.SourceResourceID = slices.Clone(e.SourceResourceID)
toret.DestResourceID = slices.Clone(e.DestResourceID)
return &toret
}
type FlowConfig struct {
@@ -126,15 +114,13 @@ type FlowManager interface {
GetLogger() FlowLogger
}
type FlowEventAggregator interface {
ResetAggregationWindow() FlowEventAggregator
GetAggregatedEvents() []*Event
}
type FlowLogger interface {
ResetAggregationWindow() FlowEventAggregator
// StoreEvent stores a flow event
StoreEvent(flowEvent EventFields)
// GetEvents returns all stored events
GetEvents() []*Event
// DeleteEvents deletes events from the store
DeleteEvents([]uuid.UUID)
// Close closes the logger
Close()
// Enable enables the flow logger receiver
@@ -154,11 +140,6 @@ type Store interface {
Close()
}
type AggregatingStore interface {
FlowEventAggregator
Store
}
// ConnTracker defines the interface for connection tracking functionality
type ConnTracker interface {
// Start begins tracking connections by listening for conntrack events.

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v7.34.1
// protoc-gen-go v1.26.0
// protoc v3.21.9
// source: flow.proto
package proto
@@ -12,7 +12,6 @@ import (
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
@@ -126,24 +125,27 @@ func (Direction) EnumDescriptor() ([]byte, []int) {
}
type FlowEvent struct {
state protoimpl.MessageState `protogen:"open.v1"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Unique client event identifier
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
// When the event occurred
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
// Public key of the sending peer
PublicKey []byte `protobuf:"bytes,3,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
FlowFields *FlowFields `protobuf:"bytes,4,opt,name=flow_fields,json=flowFields,proto3" json:"flow_fields,omitempty"`
IsInitiator bool `protobuf:"varint,5,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
PublicKey []byte `protobuf:"bytes,3,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
FlowFields *FlowFields `protobuf:"bytes,4,opt,name=flow_fields,json=flowFields,proto3" json:"flow_fields,omitempty"`
IsInitiator bool `protobuf:"varint,5,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
}
func (x *FlowEvent) Reset() {
*x = FlowEvent{}
mi := &file_flow_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_flow_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *FlowEvent) String() string {
@@ -154,7 +156,7 @@ func (*FlowEvent) ProtoMessage() {}
func (x *FlowEvent) ProtoReflect() protoreflect.Message {
mi := &file_flow_proto_msgTypes[0]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -205,19 +207,22 @@ func (x *FlowEvent) GetIsInitiator() bool {
}
type FlowEventAck struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Unique client event identifier that has been ack'ed
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
IsInitiator bool `protobuf:"varint,2,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Unique client event identifier that has been ack'ed
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
IsInitiator bool `protobuf:"varint,2,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
}
func (x *FlowEventAck) Reset() {
*x = FlowEventAck{}
mi := &file_flow_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_flow_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *FlowEventAck) String() string {
@@ -228,7 +233,7 @@ func (*FlowEventAck) ProtoMessage() {}
func (x *FlowEventAck) ProtoReflect() protoreflect.Message {
mi := &file_flow_proto_msgTypes[1]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -258,7 +263,10 @@ func (x *FlowEventAck) GetIsInitiator() bool {
}
type FlowFields struct {
state protoimpl.MessageState `protogen:"open.v1"`
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Unique client flow session identifier
FlowId []byte `protobuf:"bytes,1,opt,name=flow_id,json=flowId,proto3" json:"flow_id,omitempty"`
// Flow type
@@ -275,7 +283,7 @@ type FlowFields struct {
DestIp []byte `protobuf:"bytes,7,opt,name=dest_ip,json=destIp,proto3" json:"dest_ip,omitempty"`
// Layer 4 -specific information
//
// Types that are valid to be assigned to ConnectionInfo:
// Types that are assignable to ConnectionInfo:
//
// *FlowFields_PortInfo
// *FlowFields_IcmpInfo
@@ -289,18 +297,15 @@ type FlowFields struct {
// Resource ID
SourceResourceId []byte `protobuf:"bytes,14,opt,name=source_resource_id,json=sourceResourceId,proto3" json:"source_resource_id,omitempty"`
DestResourceId []byte `protobuf:"bytes,15,opt,name=dest_resource_id,json=destResourceId,proto3" json:"dest_resource_id,omitempty"`
NumOfStarts uint64 `protobuf:"varint,16,opt,name=num_of_starts,json=numOfStarts,proto3" json:"num_of_starts,omitempty"`
NumOfEnds uint64 `protobuf:"varint,17,opt,name=num_of_ends,json=numOfEnds,proto3" json:"num_of_ends,omitempty"`
NumOfDrops uint64 `protobuf:"varint,18,opt,name=num_of_drops,json=numOfDrops,proto3" json:"num_of_drops,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *FlowFields) Reset() {
*x = FlowFields{}
mi := &file_flow_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_flow_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *FlowFields) String() string {
@@ -311,7 +316,7 @@ func (*FlowFields) ProtoMessage() {}
func (x *FlowFields) ProtoReflect() protoreflect.Message {
mi := &file_flow_proto_msgTypes[2]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -375,27 +380,23 @@ func (x *FlowFields) GetDestIp() []byte {
return nil
}
func (x *FlowFields) GetConnectionInfo() isFlowFields_ConnectionInfo {
if x != nil {
return x.ConnectionInfo
func (m *FlowFields) GetConnectionInfo() isFlowFields_ConnectionInfo {
if m != nil {
return m.ConnectionInfo
}
return nil
}
func (x *FlowFields) GetPortInfo() *PortInfo {
if x != nil {
if x, ok := x.ConnectionInfo.(*FlowFields_PortInfo); ok {
return x.PortInfo
}
if x, ok := x.GetConnectionInfo().(*FlowFields_PortInfo); ok {
return x.PortInfo
}
return nil
}
func (x *FlowFields) GetIcmpInfo() *ICMPInfo {
if x != nil {
if x, ok := x.ConnectionInfo.(*FlowFields_IcmpInfo); ok {
return x.IcmpInfo
}
if x, ok := x.GetConnectionInfo().(*FlowFields_IcmpInfo); ok {
return x.IcmpInfo
}
return nil
}
@@ -442,27 +443,6 @@ func (x *FlowFields) GetDestResourceId() []byte {
return nil
}
func (x *FlowFields) GetNumOfStarts() uint64 {
if x != nil {
return x.NumOfStarts
}
return 0
}
func (x *FlowFields) GetNumOfEnds() uint64 {
if x != nil {
return x.NumOfEnds
}
return 0
}
func (x *FlowFields) GetNumOfDrops() uint64 {
if x != nil {
return x.NumOfDrops
}
return 0
}
type isFlowFields_ConnectionInfo interface {
isFlowFields_ConnectionInfo()
}
@@ -483,18 +463,21 @@ func (*FlowFields_IcmpInfo) isFlowFields_ConnectionInfo() {}
// TCP/UDP port information
type PortInfo struct {
state protoimpl.MessageState `protogen:"open.v1"`
SourcePort uint32 `protobuf:"varint,1,opt,name=source_port,json=sourcePort,proto3" json:"source_port,omitempty"`
DestPort uint32 `protobuf:"varint,2,opt,name=dest_port,json=destPort,proto3" json:"dest_port,omitempty"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
SourcePort uint32 `protobuf:"varint,1,opt,name=source_port,json=sourcePort,proto3" json:"source_port,omitempty"`
DestPort uint32 `protobuf:"varint,2,opt,name=dest_port,json=destPort,proto3" json:"dest_port,omitempty"`
}
func (x *PortInfo) Reset() {
*x = PortInfo{}
mi := &file_flow_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_flow_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PortInfo) String() string {
@@ -505,7 +488,7 @@ func (*PortInfo) ProtoMessage() {}
func (x *PortInfo) ProtoReflect() protoreflect.Message {
mi := &file_flow_proto_msgTypes[3]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -536,18 +519,21 @@ func (x *PortInfo) GetDestPort() uint32 {
// ICMP message information
type ICMPInfo struct {
state protoimpl.MessageState `protogen:"open.v1"`
IcmpType uint32 `protobuf:"varint,1,opt,name=icmp_type,json=icmpType,proto3" json:"icmp_type,omitempty"`
IcmpCode uint32 `protobuf:"varint,2,opt,name=icmp_code,json=icmpCode,proto3" json:"icmp_code,omitempty"`
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
IcmpType uint32 `protobuf:"varint,1,opt,name=icmp_type,json=icmpType,proto3" json:"icmp_type,omitempty"`
IcmpCode uint32 `protobuf:"varint,2,opt,name=icmp_code,json=icmpCode,proto3" json:"icmp_code,omitempty"`
}
func (x *ICMPInfo) Reset() {
*x = ICMPInfo{}
mi := &file_flow_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
if protoimpl.UnsafeEnabled {
mi := &file_flow_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ICMPInfo) String() string {
@@ -558,7 +544,7 @@ func (*ICMPInfo) ProtoMessage() {}
func (x *ICMPInfo) ProtoReflect() protoreflect.Message {
mi := &file_flow_proto_msgTypes[4]
if x != nil {
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -589,83 +575,102 @@ func (x *ICMPInfo) GetIcmpCode() uint32 {
var File_flow_proto protoreflect.FileDescriptor
const file_flow_proto_rawDesc = "" +
"\n" +
"\n" +
"flow.proto\x12\x04flow\x1a\x1fgoogle/protobuf/timestamp.proto\"\xd4\x01\n" +
"\tFlowEvent\x12\x19\n" +
"\bevent_id\x18\x01 \x01(\fR\aeventId\x128\n" +
"\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n" +
"\n" +
"public_key\x18\x03 \x01(\fR\tpublicKey\x121\n" +
"\vflow_fields\x18\x04 \x01(\v2\x10.flow.FlowFieldsR\n" +
"flowFields\x12 \n" +
"\visInitiator\x18\x05 \x01(\bR\visInitiator\"K\n" +
"\fFlowEventAck\x12\x19\n" +
"\bevent_id\x18\x01 \x01(\fR\aeventId\x12 \n" +
"\visInitiator\x18\x02 \x01(\bR\visInitiator\"\x82\x05\n" +
"\n" +
"FlowFields\x12\x17\n" +
"\aflow_id\x18\x01 \x01(\fR\x06flowId\x12\x1e\n" +
"\x04type\x18\x02 \x01(\x0e2\n" +
".flow.TypeR\x04type\x12\x17\n" +
"\arule_id\x18\x03 \x01(\fR\x06ruleId\x12-\n" +
"\tdirection\x18\x04 \x01(\x0e2\x0f.flow.DirectionR\tdirection\x12\x1a\n" +
"\bprotocol\x18\x05 \x01(\rR\bprotocol\x12\x1b\n" +
"\tsource_ip\x18\x06 \x01(\fR\bsourceIp\x12\x17\n" +
"\adest_ip\x18\a \x01(\fR\x06destIp\x12-\n" +
"\tport_info\x18\b \x01(\v2\x0e.flow.PortInfoH\x00R\bportInfo\x12-\n" +
"\ticmp_info\x18\t \x01(\v2\x0e.flow.ICMPInfoH\x00R\bicmpInfo\x12\x1d\n" +
"\n" +
"rx_packets\x18\n" +
" \x01(\x04R\trxPackets\x12\x1d\n" +
"\n" +
"tx_packets\x18\v \x01(\x04R\ttxPackets\x12\x19\n" +
"\brx_bytes\x18\f \x01(\x04R\arxBytes\x12\x19\n" +
"\btx_bytes\x18\r \x01(\x04R\atxBytes\x12,\n" +
"\x12source_resource_id\x18\x0e \x01(\fR\x10sourceResourceId\x12(\n" +
"\x10dest_resource_id\x18\x0f \x01(\fR\x0edestResourceId\x12\"\n" +
"\rnum_of_starts\x18\x10 \x01(\x04R\vnumOfStarts\x12\x1e\n" +
"\vnum_of_ends\x18\x11 \x01(\x04R\tnumOfEnds\x12 \n" +
"\fnum_of_drops\x18\x12 \x01(\x04R\n" +
"numOfDropsB\x11\n" +
"\x0fconnection_info\"H\n" +
"\bPortInfo\x12\x1f\n" +
"\vsource_port\x18\x01 \x01(\rR\n" +
"sourcePort\x12\x1b\n" +
"\tdest_port\x18\x02 \x01(\rR\bdestPort\"D\n" +
"\bICMPInfo\x12\x1b\n" +
"\ticmp_type\x18\x01 \x01(\rR\bicmpType\x12\x1b\n" +
"\ticmp_code\x18\x02 \x01(\rR\bicmpCode*E\n" +
"\x04Type\x12\x10\n" +
"\fTYPE_UNKNOWN\x10\x00\x12\x0e\n" +
"\n" +
"TYPE_START\x10\x01\x12\f\n" +
"\bTYPE_END\x10\x02\x12\r\n" +
"\tTYPE_DROP\x10\x03*;\n" +
"\tDirection\x12\x15\n" +
"\x11DIRECTION_UNKNOWN\x10\x00\x12\v\n" +
"\aINGRESS\x10\x01\x12\n" +
"\n" +
"\x06EGRESS\x10\x022B\n" +
"\vFlowService\x123\n" +
"\x06Events\x12\x0f.flow.FlowEvent\x1a\x12.flow.FlowEventAck\"\x00(\x010\x01B\bZ\x06/protob\x06proto3"
var file_flow_proto_rawDesc = []byte{
0x0a, 0x0a, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x66, 0x6c,
0x6f, 0x77, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x22, 0xd4, 0x01, 0x0a, 0x09, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
0x74, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x38, 0x0a, 0x09,
0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32,
0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75,
0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d,
0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63,
0x5f, 0x6b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x70, 0x75, 0x62, 0x6c,
0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x31, 0x0a, 0x0b, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x66, 0x69,
0x65, 0x6c, 0x64, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x66, 0x6c, 0x6f,
0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x52, 0x0a, 0x66, 0x6c,
0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x20, 0x0a, 0x0b, 0x69, 0x73, 0x49, 0x6e,
0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x69,
0x73, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x22, 0x4b, 0x0a, 0x0c, 0x46, 0x6c,
0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76,
0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76,
0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x20, 0x0a, 0x0b, 0x69, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x69,
0x61, 0x74, 0x6f, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x69, 0x73, 0x49, 0x6e,
0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x22, 0x9c, 0x04, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77,
0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x69,
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49, 0x64, 0x12,
0x1e, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0a, 0x2e,
0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12,
0x17, 0x0a, 0x07, 0x72, 0x75, 0x6c, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x06, 0x72, 0x75, 0x6c, 0x65, 0x49, 0x64, 0x12, 0x2d, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65,
0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0f, 0x2e, 0x66, 0x6c,
0x6f, 0x77, 0x2e, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x64, 0x69,
0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x63, 0x6f, 0x6c, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70,
0x18, 0x06, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70,
0x12, 0x17, 0x0a, 0x07, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x70, 0x18, 0x07, 0x20, 0x01, 0x28,
0x0c, 0x52, 0x06, 0x64, 0x65, 0x73, 0x74, 0x49, 0x70, 0x12, 0x2d, 0x0a, 0x09, 0x70, 0x6f, 0x72,
0x74, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66,
0x6c, 0x6f, 0x77, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08,
0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x2d, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70,
0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66, 0x6c,
0x6f, 0x77, 0x2e, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08, 0x69,
0x63, 0x6d, 0x70, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x78, 0x5f, 0x70, 0x61,
0x63, 0x6b, 0x65, 0x74, 0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x72, 0x78, 0x50,
0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x78, 0x5f, 0x70, 0x61, 0x63,
0x6b, 0x65, 0x74, 0x73, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x74, 0x78, 0x50, 0x61,
0x63, 0x6b, 0x65, 0x74, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x72, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65,
0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, 0x72, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73,
0x12, 0x19, 0x0a, 0x08, 0x74, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0d, 0x20, 0x01,
0x28, 0x04, 0x52, 0x07, 0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x12, 0x73,
0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69,
0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x10, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52,
0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x12, 0x28, 0x0a, 0x10, 0x64, 0x65, 0x73,
0x74, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63,
0x65, 0x49, 0x64, 0x42, 0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f,
0x6e, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e,
0x66, 0x6f, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72,
0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50,
0x6f, 0x72, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x6f, 0x72, 0x74,
0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x64, 0x65, 0x73, 0x74, 0x50, 0x6f, 0x72, 0x74,
0x22, 0x44, 0x0a, 0x08, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1b, 0x0a, 0x09,
0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52,
0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d,
0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63,
0x6d, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x2a, 0x45, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10,
0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00,
0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x41, 0x52, 0x54, 0x10, 0x01,
0x12, 0x0c, 0x0a, 0x08, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e, 0x44, 0x10, 0x02, 0x12, 0x0d,
0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x03, 0x2a, 0x3b, 0x0a,
0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x15, 0x0a, 0x11, 0x44, 0x49,
0x52, 0x45, 0x43, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10,
0x00, 0x12, 0x0b, 0x0a, 0x07, 0x49, 0x4e, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x01, 0x12, 0x0a,
0x0a, 0x06, 0x45, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x02, 0x32, 0x42, 0x0a, 0x0b, 0x46, 0x6c,
0x6f, 0x77, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x06, 0x45, 0x76, 0x65,
0x6e, 0x74, 0x73, 0x12, 0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x45,
0x76, 0x65, 0x6e, 0x74, 0x1a, 0x12, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77,
0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08,
0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_flow_proto_rawDescOnce sync.Once
file_flow_proto_rawDescData []byte
file_flow_proto_rawDescData = file_flow_proto_rawDesc
)
func file_flow_proto_rawDescGZIP() []byte {
file_flow_proto_rawDescOnce.Do(func() {
file_flow_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_flow_proto_rawDesc), len(file_flow_proto_rawDesc)))
file_flow_proto_rawDescData = protoimpl.X.CompressGZIP(file_flow_proto_rawDescData)
})
return file_flow_proto_rawDescData
}
var file_flow_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
var file_flow_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_flow_proto_goTypes = []any{
var file_flow_proto_goTypes = []interface{}{
(Type)(0), // 0: flow.Type
(Direction)(0), // 1: flow.Direction
(*FlowEvent)(nil), // 2: flow.FlowEvent
@@ -696,7 +701,69 @@ func file_flow_proto_init() {
if File_flow_proto != nil {
return
}
file_flow_proto_msgTypes[2].OneofWrappers = []any{
if !protoimpl.UnsafeEnabled {
file_flow_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*FlowEvent); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_flow_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*FlowEventAck); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_flow_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*FlowFields); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_flow_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PortInfo); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_flow_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ICMPInfo); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_flow_proto_msgTypes[2].OneofWrappers = []interface{}{
(*FlowFields_PortInfo)(nil),
(*FlowFields_IcmpInfo)(nil),
}
@@ -704,7 +771,7 @@ func file_flow_proto_init() {
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_flow_proto_rawDesc), len(file_flow_proto_rawDesc)),
RawDescriptor: file_flow_proto_rawDesc,
NumEnums: 2,
NumMessages: 5,
NumExtensions: 0,
@@ -716,6 +783,7 @@ func file_flow_proto_init() {
MessageInfos: file_flow_proto_msgTypes,
}.Build()
File_flow_proto = out.File
file_flow_proto_rawDesc = nil
file_flow_proto_goTypes = nil
file_flow_proto_depIdxs = nil
}

View File

@@ -75,9 +75,6 @@ message FlowFields {
bytes source_resource_id = 14;
bytes dest_resource_id = 15;
uint64 num_of_starts = 16;
uint64 num_of_ends = 17;
uint64 num_of_drops = 18;
}
// Flow event types

View File

@@ -1,8 +1,4 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.1
// - protoc v7.34.1
// source: flow.proto
package proto
@@ -15,19 +11,15 @@ import (
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
const (
FlowService_Events_FullMethodName = "/flow.FlowService/Events"
)
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// FlowServiceClient is the client API for FlowService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type FlowServiceClient interface {
// Client to receiver streams of events and acknowledgements
Events(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[FlowEvent, FlowEventAck], error)
Events(ctx context.Context, opts ...grpc.CallOption) (FlowService_EventsClient, error)
}
type flowServiceClient struct {
@@ -38,40 +30,54 @@ func NewFlowServiceClient(cc grpc.ClientConnInterface) FlowServiceClient {
return &flowServiceClient{cc}
}
func (c *flowServiceClient) Events(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[FlowEvent, FlowEventAck], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &FlowService_ServiceDesc.Streams[0], FlowService_Events_FullMethodName, cOpts...)
func (c *flowServiceClient) Events(ctx context.Context, opts ...grpc.CallOption) (FlowService_EventsClient, error) {
stream, err := c.cc.NewStream(ctx, &FlowService_ServiceDesc.Streams[0], "/flow.FlowService/Events", opts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[FlowEvent, FlowEventAck]{ClientStream: stream}
x := &flowServiceEventsClient{stream}
return x, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type FlowService_EventsClient = grpc.BidiStreamingClient[FlowEvent, FlowEventAck]
type FlowService_EventsClient interface {
Send(*FlowEvent) error
Recv() (*FlowEventAck, error)
grpc.ClientStream
}
type flowServiceEventsClient struct {
grpc.ClientStream
}
func (x *flowServiceEventsClient) Send(m *FlowEvent) error {
return x.ClientStream.SendMsg(m)
}
func (x *flowServiceEventsClient) Recv() (*FlowEventAck, error) {
m := new(FlowEventAck)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// FlowServiceServer is the server API for FlowService service.
// All implementations must embed UnimplementedFlowServiceServer
// for forward compatibility.
// for forward compatibility
type FlowServiceServer interface {
// Client to receiver streams of events and acknowledgements
Events(grpc.BidiStreamingServer[FlowEvent, FlowEventAck]) error
Events(FlowService_EventsServer) error
mustEmbedUnimplementedFlowServiceServer()
}
// UnimplementedFlowServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedFlowServiceServer struct{}
// UnimplementedFlowServiceServer must be embedded to have forward compatible implementations.
type UnimplementedFlowServiceServer struct {
}
func (UnimplementedFlowServiceServer) Events(grpc.BidiStreamingServer[FlowEvent, FlowEventAck]) error {
return status.Error(codes.Unimplemented, "method Events not implemented")
func (UnimplementedFlowServiceServer) Events(FlowService_EventsServer) error {
return status.Errorf(codes.Unimplemented, "method Events not implemented")
}
func (UnimplementedFlowServiceServer) mustEmbedUnimplementedFlowServiceServer() {}
func (UnimplementedFlowServiceServer) testEmbeddedByValue() {}
// UnsafeFlowServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to FlowServiceServer will
@@ -81,22 +87,34 @@ type UnsafeFlowServiceServer interface {
}
func RegisterFlowServiceServer(s grpc.ServiceRegistrar, srv FlowServiceServer) {
// If the following call panics, it indicates UnimplementedFlowServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&FlowService_ServiceDesc, srv)
}
func _FlowService_Events_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(FlowServiceServer).Events(&grpc.GenericServerStream[FlowEvent, FlowEventAck]{ServerStream: stream})
return srv.(FlowServiceServer).Events(&flowServiceEventsServer{stream})
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type FlowService_EventsServer = grpc.BidiStreamingServer[FlowEvent, FlowEventAck]
type FlowService_EventsServer interface {
Send(*FlowEventAck) error
Recv() (*FlowEvent, error)
grpc.ServerStream
}
type flowServiceEventsServer struct {
grpc.ServerStream
}
func (x *flowServiceEventsServer) Send(m *FlowEventAck) error {
return x.ServerStream.SendMsg(m)
}
func (x *flowServiceEventsServer) Recv() (*FlowEvent, error) {
m := new(FlowEvent)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// FlowService_ServiceDesc is the grpc.ServiceDesc for FlowService service.
// It's only intended for direct use with grpc.RegisterService,

View File

@@ -10,9 +10,8 @@ fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
echo "$script_path"
cd "$script_path"
#go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
#go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./flow.proto --go_out=../ --go-grpc_out=../
cd "$old_pwd"

6
go.mod
View File

@@ -2,8 +2,6 @@ module github.com/netbirdio/netbird
go 1.25.5
toolchain go1.25.11
require (
cunicu.li/go-rosenpass v0.5.42
github.com/cenkalti/backoff/v4 v4.3.0
@@ -56,7 +54,6 @@ require (
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.4
github.com/goccy/go-yaml v1.18.0
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang/mock v1.6.0
@@ -214,9 +211,10 @@ require (
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/go-webauthn/webauthn v0.16.4 // indirect
github.com/go-webauthn/x v0.2.3 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/go-tpm v0.9.8 // indirect
github.com/google/s2a-go v0.1.9 // indirect

4
go.sum
View File

@@ -275,8 +275,8 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=

View File

@@ -488,195 +488,6 @@ func TestUpdate_AllowsPortChange(t *testing.T) {
assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied")
}
func TestUpdate_PreservesPortWhenCustomPortsNotSupported(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc-renamed",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err, "update must not be rejected by the custom-port capability check")
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved on unsupported cluster")
}
func TestUpdate_PreservesPortWhenCustomPortsUnknown(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, nil)
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc-renamed",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err, "update must not be rejected when cluster capability is unknown")
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when capability is unknown")
}
func TestUpdate_RejectsPortChangeWhenCustomPortsNotSupported(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 54321,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.Error(t, err, "explicit port change on update must be rejected on unsupported clusters")
assert.Contains(t, err.Error(), "custom ports not supported on target cluster")
}
func TestUpdate_TLSPortChangeAllowedWhenNotSupported(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tls-svc", "tls", "app.example.com", testCluster, 443)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tls-svc",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 9999,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err, "TLS port change uses SNI routing and is exempt from the custom-port check")
assert.Equal(t, uint16(9999), updated.ListenPort, "TLS port change should be applied")
}
func TestValidateL4PortDiffOnClusterDiff(t *testing.T) {
tests := []struct {
name string
mode string
customPorts *bool
newPort uint16
oldPort uint16
wantErr bool
}{
{"tcp port change unsupported", "tcp", boolPtr(false), 54321, 12345, true},
{"tcp port change unknown capability", "tcp", nil, 54321, 12345, true},
{"udp port change unsupported", "udp", boolPtr(false), 54321, 12345, true},
{"tcp first port assignment unsupported", "tcp", boolPtr(false), 54321, 0, true},
{"tcp port change supported", "tcp", boolPtr(true), 54321, 12345, false},
{"tcp port unchanged unsupported", "tcp", boolPtr(false), 12345, 12345, false},
{"tcp zero port unsupported", "tcp", boolPtr(false), 0, 12345, false},
{"tls port change unsupported", "tls", boolPtr(false), 9999, 443, false},
{"http mode ignored", "http", boolPtr(false), 54321, 12345, false},
{"empty mode ignored", "", boolPtr(false), 54321, 12345, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
newSvc := &rpservice.Service{Mode: tc.mode, ListenPort: tc.newPort, ProxyCluster: testCluster}
oldSvc := &rpservice.Service{Mode: tc.mode, ListenPort: tc.oldPort, ProxyCluster: testCluster}
err := validateL4PortDiffOnClusterDiff(tc.customPorts, newSvc, oldSvc)
if tc.wantErr {
assert.Error(t, err, "port diff should be rejected for %s", tc.name)
} else {
assert.NoError(t, err, "port diff should be allowed for %s", tc.name)
}
})
}
}
func TestUpdate_PortConflictRejected(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "tcp-a", "tcp", "tcp-a."+testCluster, testCluster, 5432)
svcB := seedService(t, testStore, "tcp-b", "tcp", "tcp-b."+testCluster, testCluster, 6543)
updated := &rpservice.Service{
ID: svcB.ID,
AccountID: testAccountID,
Name: "tcp-b",
Mode: "tcp",
Domain: "tcp-b." + testCluster,
ProxyCluster: testCluster,
ListenPort: 5432,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.Error(t, err, "updating to a port held by another service should be rejected")
assert.Contains(t, err.Error(), "already in use")
}
func TestUpdate_AutoAssignsWhenNoPort(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 0)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err)
assert.True(t, updated.ListenPort >= autoAssignPortMin && updated.ListenPort <= autoAssignPortMax,
"auto-assigned port %d should be in range [%d, %d]", updated.ListenPort, autoAssignPortMin, autoAssignPortMax)
assert.True(t, updated.PortAutoAssigned, "PortAutoAssigned should be set when update triggers auto-assignment")
}
func TestCreateServiceFromPeer_TCP(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()

View File

@@ -338,7 +338,7 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
}
}
if err := m.ensureL4Port(ctx, transaction, svc, customPorts, false); err != nil {
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
return err
}
@@ -367,11 +367,11 @@ func (m *Manager) clusterCustomPorts(ctx context.Context, svc *service.Service)
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
// customPorts must be pre-computed via clusterCustomPorts before entering a transaction.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool, serviceUpdate bool) error {
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool) error {
if !service.IsL4Protocol(svc.Mode) {
return nil
}
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && !serviceUpdate && (customPorts == nil || !*customPorts) {
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
}
@@ -465,7 +465,7 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
return err
}
if err := m.ensureL4Port(ctx, transaction, svc, customPorts, false); err != nil {
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
return err
}
@@ -651,22 +651,12 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
// if the service is being updated, and we decide in the future to allow mode update,
// we should reconsider the currently assigned port if not 0 for clusters that don't support custom ports
if err := validateL4PortDiffOnClusterDiff(customPorts, service, existingService); err != nil {
if err := m.ensureL4Port(ctx, transaction, service, customPorts); err != nil {
return err
}
if err := m.ensureL4Port(ctx, transaction, service, customPorts, true); err != nil {
return err
}
// we can try carrying the previous service port into a new cluster, if this becomes a problem for multiple users,
// we should reconsider adding another check
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
@@ -674,21 +664,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
return nil
}
// validateL4PortDiffOnClusterDiff checks if custom L4 ports are configured and validates port changes across clusters.
// It ensures no port changes if custom ports are unsupported for a given cluster and protocol mode.
// Returns an error if validation fails, otherwise returns nil.
func validateL4PortDiffOnClusterDiff(customPorts *bool, newSVC, oldSVC *service.Service) error {
if !service.IsPortBasedProtocol(newSVC.Mode) || (customPorts != nil && *customPorts) {
return nil
}
if newSVC.ListenPort != 0 && newSVC.ListenPort != oldSVC.ListenPort {
return status.Errorf(status.InvalidArgument, "custom ports not supported on target cluster %s", newSVC.ProxyCluster)
}
return nil
}
// handleDomainChange validates the new domain is free inside the transaction
// and applies the pre-resolved cluster (computed outside the tx by
// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks

View File

@@ -8,8 +8,6 @@ import (
"strings"
"time"
"github.com/hashicorp/go-version"
nbversion "github.com/netbirdio/netbird/version"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -30,23 +28,6 @@ import (
"github.com/netbirdio/netbird/shared/sshauth"
)
const (
// deprecatedRemotePeersVersion is the version of Netbird that introduced the NetworkMap.RemotePeers field, deprecated in favor of RemotePeers.
deprecatedRemotePeersVersion = "0.29.3"
)
// precomputedDeprecatedRemotePeersConstraint is the parsed ">= 0.29.3" constraint,
// built once at init since the bound is a compile-time constant.
var precomputedDeprecatedRemotePeersConstraint version.Constraints
func init() {
constraint, err := version.NewConstraint(">= " + deprecatedRemotePeersVersion)
if err != nil {
panic("parse deprecated remote peers version constraint: " + err.Error())
}
precomputedDeprecatedRemotePeersConstraint = constraint
}
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil {
return nil
@@ -174,11 +155,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
if !shouldSkipSendingDeprecatedRemotePeers(peer.Meta.WtVersion) {
response.RemotePeers = remotePeers
}
response.RemotePeers = remotePeers
response.NetworkMap.RemotePeers = remotePeers
response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
@@ -269,19 +246,6 @@ func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]m
return hashedUsers, machineUsers
}
func shouldSkipSendingDeprecatedRemotePeers(peerVersion string) bool {
if nbversion.IsDevelopmentVersion(peerVersion) {
return true
}
peerNBVersion, err := version.NewVersion(peerVersion)
if err != nil {
return false
}
return precomputedDeprecatedRemotePeersConstraint.Check(peerNBVersion)
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
allowedIPs := []string{rPeer.IP.String() + "/32"}
@@ -399,6 +363,7 @@ func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSource
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {

View File

@@ -202,42 +202,6 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
}
}
// TestShouldSkipSendingDeprecatedRemotePeers covers the version gate that
// stops populating the deprecated top-level SyncResponse.RemotePeers field for
// peers new enough to read RemotePeers off the NetworkMap. Development builds
// are treated as latest and skip the field. The gate otherwise fails safe: a
// release version older than the boundary, or one that can't be parsed (empty,
// garbage, prereleases of the boundary) still receives the deprecated field so
// older/unknown clients keep working.
func TestShouldSkipSendingDeprecatedRemotePeers(t *testing.T) {
tests := []struct {
name string
peerVersion string
wantSkip bool
}{
{"exact boundary skips", "0.29.3", true},
{"newer patch skips", "0.29.4", true},
{"newer minor skips", "0.30.0", true},
{"newer major skips", "1.0.0", true},
{"v-prefixed newer skips", "v0.30.0", true},
{"development build skips", "development", true},
{"development build with commit skips", "development-abc123def456-dirty", true},
{"older patch keeps field", "0.29.2", false},
{"older minor keeps field", "0.28.0", false},
{"prerelease of boundary keeps field", "0.29.3-SNAPSHOT", false},
{"tagged dev prerelease keeps field", "v0.31.1-dev", false},
{"empty version keeps field", "", false},
{"garbage version keeps field", "not-a-version", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := shouldSkipSendingDeprecatedRemotePeers(tc.peerVersion)
assert.Equal(t, tc.wantSkip, got, "skip decision for peer version %q", tc.peerVersion)
})
}
}
// TestEncodeSessionExpiresAt pins the wire encoding the client's
// applySessionDeadline depends on:
//

View File

@@ -24,7 +24,6 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"github.com/pires/go-proxyproto"
prometheus2 "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -76,30 +75,29 @@ type portRouter struct {
}
type Server struct {
ctx context.Context
mgmtClient proto.ProxyServiceClient
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
staticCertWatcher *certwatch.Watcher
auth *auth.Middleware
http *http.Server
https *http.Server
debug *http.Server
healthServer *health.Server
healthChecker *health.Checker
meter *proxymetrics.Metrics
accessLog *accesslog.Logger
mainRouter *nbtcp.Router
mainPort uint16
udpMu sync.Mutex
udpRelays map[types.ServiceID]*udprelay.Relay
udpRelayWg sync.WaitGroup
portMu sync.RWMutex
portRouters map[uint16]*portRouter
svcPorts map[types.ServiceID][]uint16
lastMappings map[types.ServiceID]*proto.ProxyMapping
portRouterWg sync.WaitGroup
ctx context.Context
mgmtClient proto.ProxyServiceClient
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
auth *auth.Middleware
http *http.Server
https *http.Server
debug *http.Server
healthServer *health.Server
healthChecker *health.Checker
meter *proxymetrics.Metrics
accessLog *accesslog.Logger
mainRouter *nbtcp.Router
mainPort uint16
udpMu sync.Mutex
udpRelays map[types.ServiceID]*udprelay.Relay
udpRelayWg sync.WaitGroup
portMu sync.RWMutex
portRouters map[uint16]*portRouter
svcPorts map[types.ServiceID][]uint16
lastMappings map[types.ServiceID]*proto.ProxyMapping
portRouterWg sync.WaitGroup
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
// so they can be closed during graceful shutdown, since http.Server.Shutdown
@@ -616,7 +614,7 @@ func (s *Server) initDefaults() {
// If no ID is set then one can be generated.
if s.ID == "" {
s.ID = fmt.Sprintf("netbird-proxy-%s", uuid.NewString())
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
}
// Fallback version option in case it is not set.
if s.Version == "" {
@@ -794,7 +792,6 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
return nil, fmt.Errorf("initialize certificate watcher: %w", err)
}
go certWatcher.Watch(ctx)
s.staticCertWatcher = certWatcher
tlsConfig.GetCertificate = certWatcher.GetCertificate
return tlsConfig, nil
}
@@ -1626,8 +1623,6 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
var wildcardHit bool
if s.acme != nil {
wildcardHit = s.acme.AddDomain(d, accountID, svcID)
} else {
wildcardHit = s.staticCertCovers(d)
}
httpRoute := nbtcp.Route{
Type: nbtcp.RouteHTTP,
@@ -1652,26 +1647,6 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
return nil
}
// staticCertCovers reports whether the static certificate loaded when ACME is
// disabled covers the given domain, making it certificate-ready immediately —
// the equivalent of a wildcard hit in the ACME path. Domains the certificate
// does not cover are logged: clients connecting to them will get TLS errors.
func (s *Server) staticCertCovers(d domain.Domain) bool {
if s.staticCertWatcher == nil {
return false
}
leaf := s.staticCertWatcher.Leaf()
if leaf == nil {
return false
}
name := d.PunycodeString()
if err := leaf.VerifyHostname(name); err != nil {
s.Logger.Warnf("static certificate (SANs %v) does not cover domain %q: %v", leaf.DNSNames, name, err)
return false
}
return true
}
// setupTCPMapping sets up a TCP port-forwarding fallback route on the listen port.
func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
svcID := types.ServiceID(mapping.GetId())

View File

@@ -1,89 +0,0 @@
package proxy
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/shared/management/domain"
)
func generateCertWithSANs(t *testing.T, dnsNames []string) (certPEM, keyPEM []byte) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: dnsNames[0]},
DNSNames: dnsNames,
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
keyDER, err := x509.MarshalECPrivateKey(key)
require.NoError(t, err)
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return certPEM, keyPEM
}
func newStaticWatcher(t *testing.T, dnsNames []string) *certwatch.Watcher {
t.Helper()
dir := t.TempDir()
certPEM, keyPEM := generateCertWithSANs(t, dnsNames)
certPath := filepath.Join(dir, "tls.crt")
keyPath := filepath.Join(dir, "tls.key")
require.NoError(t, os.WriteFile(certPath, certPEM, 0o600))
require.NoError(t, os.WriteFile(keyPath, keyPEM, 0o600))
w, err := certwatch.NewWatcher(certPath, keyPath, quietLifecycleLogger())
require.NoError(t, err)
return w
}
func TestStaticCertCovers(t *testing.T) {
s := &Server{
Logger: quietLifecycleLogger(),
staticCertWatcher: newStaticWatcher(t, []string{"*.p.example.com", "exact.example.com"}),
}
cases := []struct {
domain string
covered bool
}{
{"svc.p.example.com", true},
{"exact.example.com", true},
{"a.b.p.example.com", false}, // wildcard does not span labels
{"p.example.com", false},
{"other.example.com", false},
}
for _, tc := range cases {
t.Run(tc.domain, func(t *testing.T) {
assert.Equal(t, tc.covered, s.staticCertCovers(domain.Domain(tc.domain)))
})
}
}
func TestStaticCertCoversNoWatcher(t *testing.T) {
s := &Server{Logger: quietLifecycleLogger()}
assert.False(t, s.staticCertCovers(domain.Domain("svc.p.example.com")))
}

View File

@@ -322,21 +322,15 @@ func TestClient_Sync(t *testing.T) {
if resp.GetNetbirdConfig() == nil {
t.Error("expecting non nil NetbirdConfig got nil")
}
// we test network map peers from 0.29.3 and dev builds
if len(resp.GetRemotePeers()) != 0 {
t.Error("expecting top-level RemotePeers to be empty for v0.29.3+ clients")
}
networkMap := resp.GetNetworkMap()
if len(networkMap.GetRemotePeers()) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(networkMap.GetRemotePeers()))
if len(resp.GetRemotePeers()) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
return
}
if networkMap.GetRemotePeersIsEmpty() {
if resp.GetRemotePeersIsEmpty() == true {
t.Error("expecting RemotePeers property to be false, got true")
}
if networkMap.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), networkMap.GetRemotePeers()[0].GetWgPubKey())
if resp.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), resp.GetRemotePeers()[0].GetWgPubKey())
}
case <-time.After(3 * time.Second):
t.Error("timeout waiting for test to finish")