mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-19 15:19:55 +00:00
Merge remote-tracking branch 'origin/main' into feat/byod-proxy
# Conflicts: # management/internals/modules/reverseproxy/proxy/manager.go # management/internals/modules/reverseproxy/proxy/manager/manager.go # management/internals/modules/reverseproxy/proxy/manager_mock.go # management/internals/shared/grpc/proxy.go # management/server/store/sql_store.go # management/server/store/store.go # management/server/store/store_mock.go # proxy/management_integration_test.go
This commit is contained in:
@@ -2,7 +2,12 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
@@ -99,6 +104,27 @@ var debugStopCmd = &cobra.Command{
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
var debugCaptureCmd = &cobra.Command{
|
||||
Use: "capture <account-id> [filter expression]",
|
||||
Short: "Capture packets on a client's WireGuard interface",
|
||||
Long: `Captures decrypted packets flowing through a client's WireGuard interface.
|
||||
|
||||
Default output is human-readable text. Use --pcap or --output for pcap binary.
|
||||
Filter arguments after the account ID use BPF-like syntax.
|
||||
|
||||
Examples:
|
||||
netbird-proxy debug capture <account-id>
|
||||
netbird-proxy debug capture <account-id> --duration 1m host 10.0.0.1
|
||||
netbird-proxy debug capture <account-id> host 10.0.0.1 and tcp port 443
|
||||
netbird-proxy debug capture <account-id> not port 22
|
||||
netbird-proxy debug capture <account-id> -o capture.pcap
|
||||
netbird-proxy debug capture <account-id> --pcap | tcpdump -r - -n
|
||||
netbird-proxy debug capture <account-id> --pcap | tshark -r -`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: runDebugCapture,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
func init() {
|
||||
debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address")
|
||||
debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format")
|
||||
@@ -110,6 +136,12 @@ func init() {
|
||||
|
||||
debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)")
|
||||
|
||||
debugCaptureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = server default)")
|
||||
debugCaptureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)")
|
||||
debugCaptureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length (text mode)")
|
||||
debugCaptureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (text mode)")
|
||||
debugCaptureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout")
|
||||
|
||||
debugCmd.AddCommand(debugHealthCmd)
|
||||
debugCmd.AddCommand(debugClientsCmd)
|
||||
debugCmd.AddCommand(debugStatusCmd)
|
||||
@@ -119,6 +151,7 @@ func init() {
|
||||
debugCmd.AddCommand(debugLogCmd)
|
||||
debugCmd.AddCommand(debugStartCmd)
|
||||
debugCmd.AddCommand(debugStopCmd)
|
||||
debugCmd.AddCommand(debugCaptureCmd)
|
||||
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
}
|
||||
@@ -171,3 +204,84 @@ func runDebugStart(cmd *cobra.Command, args []string) error {
|
||||
func runDebugStop(cmd *cobra.Command, args []string) error {
|
||||
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
|
||||
}
|
||||
|
||||
func runDebugCapture(cmd *cobra.Command, args []string) error {
|
||||
duration, _ := cmd.Flags().GetDuration("duration")
|
||||
forcePcap, _ := cmd.Flags().GetBool("pcap")
|
||||
verbose, _ := cmd.Flags().GetBool("verbose")
|
||||
ascii, _ := cmd.Flags().GetBool("ascii")
|
||||
outPath, _ := cmd.Flags().GetString("output")
|
||||
|
||||
// Default to text. Use pcap when --pcap is set or --output is given.
|
||||
wantText := !forcePcap && outPath == ""
|
||||
|
||||
var filterExpr string
|
||||
if len(args) > 1 {
|
||||
filterExpr = strings.Join(args[1:], " ")
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
out, cleanup, err := captureOutputWriter(cmd, outPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
if wantText {
|
||||
cmd.PrintErrln("Capturing packets... Press Ctrl+C to stop.")
|
||||
} else {
|
||||
cmd.PrintErrln("Capturing packets (pcap)... Press Ctrl+C to stop.")
|
||||
}
|
||||
|
||||
var durationStr string
|
||||
if duration > 0 {
|
||||
durationStr = duration.String()
|
||||
}
|
||||
|
||||
err = getDebugClient(cmd).Capture(ctx, debug.CaptureOptions{
|
||||
AccountID: args[0],
|
||||
Duration: durationStr,
|
||||
FilterExpr: filterExpr,
|
||||
Text: wantText,
|
||||
Verbose: verbose,
|
||||
ASCII: ascii,
|
||||
Output: out,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.PrintErrln("\nCapture finished.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// captureOutputWriter returns the writer and cleanup function for capture output.
|
||||
func captureOutputWriter(cmd *cobra.Command, outPath string) (out *os.File, cleanup func(), err error) {
|
||||
if outPath != "" {
|
||||
f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create output file: %w", err)
|
||||
}
|
||||
tmpPath := f.Name()
|
||||
return f, func() {
|
||||
if err := f.Close(); err != nil {
|
||||
cmd.PrintErrf("close output file: %v\n", err)
|
||||
}
|
||||
if fi, err := os.Stat(tmpPath); err == nil && fi.Size() > 0 {
|
||||
if err := os.Rename(tmpPath, outPath); err != nil {
|
||||
cmd.PrintErrf("rename output file: %v\n", err)
|
||||
} else {
|
||||
cmd.PrintErrf("Wrote %s\n", outPath)
|
||||
}
|
||||
} else {
|
||||
os.Remove(tmpPath)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
return os.Stdout, func() {
|
||||
// no cleanup needed for stdout
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -310,6 +310,76 @@ func (c *Client) printError(data map[string]any) {
|
||||
}
|
||||
}
|
||||
|
||||
// CaptureOptions configures a capture request.
|
||||
type CaptureOptions struct {
|
||||
AccountID string
|
||||
Duration string
|
||||
FilterExpr string
|
||||
Text bool
|
||||
Verbose bool
|
||||
ASCII bool
|
||||
Output io.Writer
|
||||
}
|
||||
|
||||
// Capture streams a packet capture from the debug endpoint. The response body
|
||||
// (pcap or text) is written directly to opts.Output until the server closes the
|
||||
// connection or the context is cancelled.
|
||||
func (c *Client) Capture(ctx context.Context, opts CaptureOptions) error {
|
||||
if opts.AccountID == "" {
|
||||
return fmt.Errorf("account ID is required")
|
||||
}
|
||||
if opts.Output == nil {
|
||||
return fmt.Errorf("output writer is required")
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
if opts.Duration != "" {
|
||||
params.Set("duration", opts.Duration)
|
||||
}
|
||||
if opts.FilterExpr != "" {
|
||||
params.Set("filter", opts.FilterExpr)
|
||||
}
|
||||
if opts.Text {
|
||||
params.Set("format", "text")
|
||||
}
|
||||
if opts.Verbose {
|
||||
params.Set("verbose", "true")
|
||||
}
|
||||
if opts.ASCII {
|
||||
params.Set("ascii", "true")
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("/debug/clients/%s/capture", url.PathEscape(opts.AccountID))
|
||||
if len(params) > 0 {
|
||||
path += "?" + params.Encode()
|
||||
}
|
||||
|
||||
fullURL := c.baseURL + path
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
// Use a separate client without timeout since captures stream for their full duration.
|
||||
httpClient := &http.Client{}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
_, err = io.Copy(opts.Output, resp.Body)
|
||||
if err != nil && ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error {
|
||||
data, raw, err := c.fetch(ctx, path)
|
||||
if err != nil {
|
||||
|
||||
@@ -174,6 +174,8 @@ func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, pat
|
||||
h.handleClientStart(w, r, accountID)
|
||||
case "stop":
|
||||
h.handleClientStop(w, r, accountID)
|
||||
case "capture":
|
||||
h.handleCapture(w, r, accountID)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@@ -632,6 +634,81 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou
|
||||
})
|
||||
}
|
||||
|
||||
const maxCaptureDuration = 30 * time.Minute
|
||||
|
||||
// handleCapture streams a pcap or text packet capture for the given client.
|
||||
//
|
||||
// Query params:
|
||||
//
|
||||
// duration: capture duration (0 or absent = max, capped at 30m)
|
||||
// format: "text" for human-readable output (default: pcap)
|
||||
// filter: BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443")
|
||||
func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
http.Error(w, "client not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
duration := maxCaptureDuration
|
||||
if durationStr := r.URL.Query().Get("duration"); durationStr != "" {
|
||||
d, err := time.ParseDuration(durationStr)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid duration: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if d < 0 {
|
||||
http.Error(w, "duration must not be negative", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if d > 0 {
|
||||
duration = min(d, maxCaptureDuration)
|
||||
}
|
||||
}
|
||||
|
||||
filter := r.URL.Query().Get("filter")
|
||||
wantText := r.URL.Query().Get("format") == "text"
|
||||
verbose := r.URL.Query().Get("verbose") == "true"
|
||||
ascii := r.URL.Query().Get("ascii") == "true"
|
||||
|
||||
opts := nbembed.CaptureOptions{Filter: filter, Verbose: verbose, ASCII: ascii}
|
||||
if wantText {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
opts.TextOutput = w
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/vnd.tcpdump.pcap")
|
||||
w.Header().Set("Content-Disposition",
|
||||
fmt.Sprintf("attachment; filename=capture-%s.pcap", accountID))
|
||||
opts.Output = w
|
||||
}
|
||||
|
||||
cs, err := client.StartCapture(opts)
|
||||
if err != nil {
|
||||
http.Error(w, "start capture: "+err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
defer cs.Stop()
|
||||
|
||||
// Flush headers after setup succeeds so errors above can still set status codes.
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
timer := time.NewTimer(duration)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
cs.Stop()
|
||||
|
||||
stats := cs.Stats()
|
||||
h.logger.Infof("capture for %s finished: %d packets, %d bytes, %d dropped",
|
||||
accountID, stats.Packets, stats.Bytes, stats.Dropped)
|
||||
}
|
||||
|
||||
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) {
|
||||
if !wantJSON {
|
||||
http.Redirect(w, r, "/debug", http.StatusSeeOther)
|
||||
|
||||
@@ -203,15 +203,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
|
||||
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
||||
type testProxyManager struct{}
|
||||
|
||||
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *nbproxy.Capabilities) error {
|
||||
func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) {
|
||||
return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: nbproxy.StatusConnected}, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Disconnect(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
|
||||
func (m *testProxyManager) Heartbeat(_ context.Context, _ *nbproxy.Proxy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -394,14 +394,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings from the snapshot - server sends each mapping individually
|
||||
mappingsByID := make(map[string]*proto.ProxyMapping)
|
||||
for i := 0; i < 2; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
for _, m := range msg.GetMapping() {
|
||||
mappingsByID[m.GetId()] = m
|
||||
}
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Should receive 2 mappings total
|
||||
@@ -441,12 +443,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings - server sends each mapping individually
|
||||
mappings := make([]*proto.ProxyMapping, 0)
|
||||
for i := 0; i < 2; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Should receive the 2 mappings matching the cluster
|
||||
@@ -470,13 +474,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
clusterAddress := "test.proxy.io"
|
||||
proxyID := "test-proxy-reconnect"
|
||||
|
||||
// Helper to receive all mappings from a stream
|
||||
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
|
||||
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
|
||||
var mappings []*proto.ProxyMapping
|
||||
for i := 0; i < count; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return mappings
|
||||
}
|
||||
@@ -490,7 +496,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstMappings := receiveMappings(stream1, 2)
|
||||
firstMappings := receiveMappings(stream1)
|
||||
cancel1()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -506,7 +512,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secondMappings := receiveMappings(stream2, 2)
|
||||
secondMappings := receiveMappings(stream2)
|
||||
|
||||
// Should receive the same mappings
|
||||
assert.Equal(t, len(firstMappings), len(secondMappings),
|
||||
@@ -572,12 +578,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to receive and apply all mappings
|
||||
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
|
||||
for i := 0; i < 2; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
applyMappings(msg.GetMapping())
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -666,12 +674,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings - server sends each mapping individually
|
||||
count := 0
|
||||
for i := 0; i < 2; i++ {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
count += len(msg.GetMapping())
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
@@ -686,3 +696,78 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
|
||||
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState verifies that
|
||||
// when a proxy reconnects before the old stream's cleanup runs, the new
|
||||
// connection is NOT removed by the stale defer.
|
||||
func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) {
|
||||
setup := setupIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
clusterAddress := "test.proxy.io"
|
||||
proxyID := "test-proxy-race"
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for {
|
||||
msg, err := stream1.Recv()
|
||||
require.NoError(t, err)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||
"proxy should be registered after first connection")
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for {
|
||||
msg, err := stream2.Recv()
|
||||
require.NoError(t, err)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
cancel1()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
assert.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
|
||||
"proxy should still be registered after old connection cleanup — old defer must not remove new connection")
|
||||
|
||||
setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||
Id: "rp-1",
|
||||
AccountId: "test-account-1",
|
||||
Domain: "app1.test.proxy.io",
|
||||
}},
|
||||
})
|
||||
|
||||
msg, err := stream2.Recv()
|
||||
require.NoError(t, err, "new stream should still receive updates")
|
||||
require.NotEmpty(t, msg.GetMapping(), "update should contain the mapping")
|
||||
assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId())
|
||||
}
|
||||
|
||||
@@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
operation := func() error {
|
||||
s.Logger.Debug("connecting to management mapping stream")
|
||||
|
||||
initialSyncDone = false
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
@@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var snapshotIDs map[types.ServiceID]struct{}
|
||||
if !*initialSyncDone {
|
||||
snapshotIDs = make(map[types.ServiceID]struct{})
|
||||
}
|
||||
|
||||
for {
|
||||
// Check for context completion to gracefully shutdown.
|
||||
select {
|
||||
@@ -1020,17 +1027,45 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
|
||||
if !*initialSyncDone && msg.GetInitialSyncComplete() {
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetInitialSyncComplete()
|
||||
if !*initialSyncDone {
|
||||
for _, m := range msg.GetMapping() {
|
||||
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
|
||||
}
|
||||
if msg.GetInitialSyncComplete() {
|
||||
s.reconcileSnapshot(ctx, snapshotIDs)
|
||||
snapshotIDs = nil
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetInitialSyncComplete()
|
||||
}
|
||||
*initialSyncDone = true
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
}
|
||||
*initialSyncDone = true
|
||||
s.Logger.Info("Initial mapping sync complete")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reconcileSnapshot removes local mappings that are absent from the snapshot.
|
||||
// This ensures services deleted while the proxy was disconnected get cleaned up.
|
||||
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
|
||||
s.portMu.RLock()
|
||||
var stale []*proto.ProxyMapping
|
||||
for svcID, mapping := range s.lastMappings {
|
||||
if _, ok := snapshotIDs[svcID]; !ok {
|
||||
stale = append(stale, mapping)
|
||||
}
|
||||
}
|
||||
s.portMu.RUnlock()
|
||||
|
||||
for _, mapping := range stale {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"service_id": mapping.GetId(),
|
||||
"domain": mapping.GetDomain(),
|
||||
}).Info("Removing stale mapping absent from snapshot")
|
||||
s.removeMapping(ctx, mapping)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
for _, mapping := range mappings {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
|
||||
227
proxy/snapshot_reconcile_test.go
Normal file
227
proxy/snapshot_reconcile_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot
|
||||
// so we can verify it without triggering removeMapping (which requires full
|
||||
// server wiring). This keeps the test focused on the detection algorithm.
|
||||
func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID {
|
||||
var stale []types.ServiceID
|
||||
for svcID := range lastMappings {
|
||||
if _, ok := snapshotIDs[svcID]; !ok {
|
||||
stale = append(stale, svcID)
|
||||
}
|
||||
}
|
||||
return stale
|
||||
}
|
||||
|
||||
// TestStaleDetection_PartialOverlap verifies that only services absent from
|
||||
// the snapshot are flagged as stale.
|
||||
func TestStaleDetection_PartialOverlap(t *testing.T) {
|
||||
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||
"svc-1": {Id: "svc-1"},
|
||||
"svc-2": {Id: "svc-2"},
|
||||
"svc-stale-a": {Id: "svc-stale-a"},
|
||||
"svc-stale-b": {Id: "svc-stale-b"},
|
||||
}
|
||||
snapshot := map[types.ServiceID]struct{}{
|
||||
"svc-1": {},
|
||||
"svc-2": {},
|
||||
"svc-3": {}, // new service, not in local
|
||||
}
|
||||
|
||||
stale := collectStaleIDs(local, snapshot)
|
||||
assert.Len(t, stale, 2)
|
||||
staleSet := make(map[types.ServiceID]struct{})
|
||||
for _, id := range stale {
|
||||
staleSet[id] = struct{}{}
|
||||
}
|
||||
assert.Contains(t, staleSet, types.ServiceID("svc-stale-a"))
|
||||
assert.Contains(t, staleSet, types.ServiceID("svc-stale-b"))
|
||||
}
|
||||
|
||||
// TestStaleDetection_AllStale verifies an empty snapshot flags everything.
|
||||
func TestStaleDetection_AllStale(t *testing.T) {
|
||||
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||
"svc-1": {Id: "svc-1"},
|
||||
"svc-2": {Id: "svc-2"},
|
||||
}
|
||||
stale := collectStaleIDs(local, map[types.ServiceID]struct{}{})
|
||||
assert.Len(t, stale, 2)
|
||||
}
|
||||
|
||||
// TestStaleDetection_NoneStale verifies full overlap produces no stale entries.
|
||||
func TestStaleDetection_NoneStale(t *testing.T) {
|
||||
local := map[types.ServiceID]*proto.ProxyMapping{
|
||||
"svc-1": {Id: "svc-1"},
|
||||
"svc-2": {Id: "svc-2"},
|
||||
}
|
||||
snapshot := map[types.ServiceID]struct{}{
|
||||
"svc-1": {},
|
||||
"svc-2": {},
|
||||
}
|
||||
stale := collectStaleIDs(local, snapshot)
|
||||
assert.Empty(t, stale)
|
||||
}
|
||||
|
||||
// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty.
|
||||
func TestStaleDetection_EmptyLocal(t *testing.T) {
|
||||
stale := collectStaleIDs(
|
||||
map[types.ServiceID]*proto.ProxyMapping{},
|
||||
map[types.ServiceID]struct{}{"svc-1": {}},
|
||||
)
|
||||
assert.Empty(t, stale)
|
||||
}
|
||||
|
||||
// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all
|
||||
// local mappings are present in the snapshot (removeMapping is never called).
|
||||
func TestReconcileSnapshot_NoStale(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"}
|
||||
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"}
|
||||
|
||||
snapshotIDs := map[types.ServiceID]struct{}{
|
||||
"svc-1": {},
|
||||
"svc-2": {},
|
||||
}
|
||||
// This should not panic — no stale entries means removeMapping is never called.
|
||||
s.reconcileSnapshot(context.Background(), snapshotIDs)
|
||||
|
||||
assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot")
|
||||
}
|
||||
|
||||
// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with
|
||||
// no local mappings.
|
||||
func TestReconcileSnapshot_EmptyLocal(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}})
|
||||
assert.Empty(t, s.lastMappings)
|
||||
}
|
||||
|
||||
// --- handleMappingStream tests for batched snapshot ID accumulation ---
|
||||
|
||||
// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is
|
||||
// marked done only after the final InitialSyncComplete message, even when
|
||||
// the snapshot arrives in multiple batches.
|
||||
func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) {
|
||||
checker := health.NewChecker(nil, nil)
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
healthChecker: checker,
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{}, // batch 1: no sync-complete
|
||||
{}, // batch 2: no sync-complete
|
||||
{InitialSyncComplete: true}, // batch 3: sync done
|
||||
},
|
||||
}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, syncDone, "sync should be marked done after final batch")
|
||||
}
|
||||
|
||||
// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages
|
||||
// arriving after InitialSyncComplete do not trigger a second reconciliation.
|
||||
func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
// Simulate state left over from a previous sync.
|
||||
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"}
|
||||
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{}, // post-sync empty message — must not reconcile
|
||||
},
|
||||
}
|
||||
|
||||
syncDone := true // sync already completed in a previous stream
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, s.lastMappings, 2,
|
||||
"post-sync messages must not trigger reconciliation — all entries should survive")
|
||||
}
|
||||
|
||||
// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the
|
||||
// stream closes before sync completes, no reconciliation occurs.
|
||||
func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
|
||||
|
||||
stream := &mockMappingStream{} // no messages → immediate EOF
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, syncDone, "sync should not be marked done on immediate EOF")
|
||||
|
||||
_, hasStale := s.lastMappings["svc-stale"]
|
||||
assert.True(t, hasStale, "stale mapping should remain when sync never completed")
|
||||
}
|
||||
|
||||
// mockErrRecvStream returns an error on the second Recv to verify
|
||||
// handleMappingStream returns without completing sync.
|
||||
type mockErrRecvStream struct {
|
||||
mockMappingStream
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) {
|
||||
m.calls++
|
||||
if m.calls == 1 {
|
||||
return &proto.GetMappingUpdateResponse{}, nil
|
||||
}
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone)
|
||||
assert.Error(t, err)
|
||||
assert.False(t, syncDone)
|
||||
|
||||
_, hasStale := s.lastMappings["svc-stale"]
|
||||
assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error")
|
||||
}
|
||||
Reference in New Issue
Block a user