From 3c28d297252e59ebbbdc00eaa6cb0f880807ba68 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 5 May 2026 18:12:18 +0300 Subject: [PATCH 01/27] [management] Map Entra oid claim as Dex user ID (#6067) --- idp/dex/connector.go | 62 ++++++++---- idp/dex/connector_test.go | 205 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 246 insertions(+), 21 deletions(-) create mode 100644 idp/dex/connector_test.go diff --git a/idp/dex/connector.go b/idp/dex/connector.go index 8aba92999..fb20fdcc3 100644 --- a/idp/dex/connector.go +++ b/idp/dex/connector.go @@ -89,21 +89,33 @@ func (p *Provider) ListConnectors(ctx context.Context) ([]*ConnectorConfig, erro } // UpdateConnector updates an existing connector in Dex storage. -// It merges incoming updates with existing values to prevent data loss on partial updates. +// It overlays user-mutable config fields (issuer, clientID, clientSecret, +// redirectURI) onto the stored connector config, and updates the connector name +// when cfg.Name is set. Empty fields on cfg leave stored values unchanged, so +// partial updates preserve create-time defaults such as scopes, claimMapping, +// and userIDKey. 
func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) error { if err := p.storage.UpdateConnector(ctx, cfg.ID, func(old storage.Connector) (storage.Connector, error) { - oldCfg, err := p.parseStorageConnector(old) - if err != nil { - return storage.Connector{}, fmt.Errorf("failed to parse existing connector: %w", err) + if cfg.Type != "" && cfg.Type != inferIdentityProviderType(old.Type, cfg.ID, nil) { + return storage.Connector{}, errors.New("connector type change not allowed") } - mergeConnectorConfig(cfg, oldCfg) - - storageConn, err := p.buildStorageConnector(cfg) + configData, err := overlayConnectorConfig(old.Config, cfg) if err != nil { - return storage.Connector{}, fmt.Errorf("failed to build connector: %w", err) + return storage.Connector{}, fmt.Errorf("failed to overlay connector config: %w", err) } - return storageConn, nil + + name := cfg.Name + if name == "" { + name = old.Name + } + + return storage.Connector{ + ID: cfg.ID, + Type: old.Type, + Name: name, + Config: configData, + }, nil }); err != nil { return fmt.Errorf("failed to update connector: %w", err) } @@ -112,23 +124,27 @@ func (p *Provider) UpdateConnector(ctx context.Context, cfg *ConnectorConfig) er return nil } -// mergeConnectorConfig preserves existing values for empty fields in the update. -func mergeConnectorConfig(cfg, oldCfg *ConnectorConfig) { - if cfg.ClientSecret == "" { - cfg.ClientSecret = oldCfg.ClientSecret +// overlayConnectorConfig writes only the user-mutable fields onto the existing +// stored config, preserving every other field (scopes, claimMapping, userIDKey, +// insecure flags, etc.). Empty fields on cfg leave the existing value alone. 
+func overlayConnectorConfig(oldConfig []byte, cfg *ConnectorConfig) ([]byte, error) { + var m map[string]any + if err := decodeConnectorConfig(oldConfig, &m); err != nil { + return nil, err } - if cfg.RedirectURI == "" { - cfg.RedirectURI = oldCfg.RedirectURI + if cfg.Issuer != "" { + m["issuer"] = cfg.Issuer } - if cfg.Issuer == "" && cfg.Type == oldCfg.Type { - cfg.Issuer = oldCfg.Issuer + if cfg.ClientID != "" { + m["clientID"] = cfg.ClientID } - if cfg.ClientID == "" { - cfg.ClientID = oldCfg.ClientID + if cfg.ClientSecret != "" { + m["clientSecret"] = cfg.ClientSecret } - if cfg.Name == "" { - cfg.Name = oldCfg.Name + if cfg.RedirectURI != "" { + m["redirectURI"] = cfg.RedirectURI } + return encodeConnectorConfig(m) } // DeleteConnector removes a connector from Dex storage. @@ -216,6 +232,10 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, oidcConfig["getUserInfo"] = true case "entra": oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"} + // Use the Entra Object ID (oid) instead of the default OIDC sub claim. + // Entra issues sub as a per-app pairwise identifier that does not match + // the stable Object ID. 
+ oidcConfig["userIDKey"] = "oid" case "okta": oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} case "pocketid": diff --git a/idp/dex/connector_test.go b/idp/dex/connector_test.go new file mode 100644 index 000000000..4253e02b7 --- /dev/null +++ b/idp/dex/connector_test.go @@ -0,0 +1,205 @@ +package dex + +import ( + "context" + "encoding/json" + "log/slog" + "os" + "path/filepath" + "testing" + + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestProvider(t *testing.T) (*Provider, func()) { + t.Helper() + tmpDir, err := os.MkdirTemp("", "dex-connector-test-*") + require.NoError(t, err) + + logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + s, err := (&sql.SQLite3{File: filepath.Join(tmpDir, "dex.db")}).Open(logger) + require.NoError(t, err) + + return &Provider{storage: s, logger: logger}, func() { + _ = s.Close() + _ = os.RemoveAll(tmpDir) + } +} + +func TestBuildOIDCConnectorConfig_EntraSetsUserIDKey(t *testing.T) { + cfg := &ConnectorConfig{ + ID: "entra-test", + Name: "Entra", + Type: "entra", + Issuer: "https://login.microsoftonline.com/tid/v2.0", + ClientID: "client-id", + ClientSecret: "client-secret", + } + data, err := buildOIDCConnectorConfig(cfg, "https://example.com/oauth2/callback") + require.NoError(t, err) + + var m map[string]any + require.NoError(t, json.Unmarshal(data, &m)) + + assert.Equal(t, "oid", m["userIDKey"], "entra connectors must default userIDKey to oid") + assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"]) +} + +func TestBuildOIDCConnectorConfig_NonEntraDoesNotSetUserIDKey(t *testing.T) { + // ensures the Entra userIDKey override does not leak into other OIDC providers, + // which already use a stable sub claim. 
+ for _, typ := range []string{"oidc", "zitadel", "okta", "pocketid", "authentik", "keycloak", "adfs"} { + t.Run(typ, func(t *testing.T) { + data, err := buildOIDCConnectorConfig(&ConnectorConfig{Type: typ}, "https://example.com/oauth2/callback") + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(data, &m)) + _, ok := m["userIDKey"] + assert.False(t, ok, "%s connectors must not have userIDKey set", typ) + }) + } +} + +func TestUpdateConnector_PreservesCreateTimeDefaults(t *testing.T) { + ctx := context.Background() + p, cleanup := newTestProvider(t) + defer cleanup() + + created, err := p.CreateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Name: "Entra", + Type: "entra", + Issuer: "https://login.microsoftonline.com/tid/v2.0", + ClientID: "client-id", + ClientSecret: "old-secret", + RedirectURI: "https://example.com/oauth2/callback", + }) + require.NoError(t, err) + require.Equal(t, "entra-test", created.ID) + + // Rotate only the client secret. 
+ err = p.UpdateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Type: "entra", + ClientSecret: "new-secret", + }) + require.NoError(t, err) + + conn, err := p.storage.GetConnector(ctx, "entra-test") + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(conn.Config, &m)) + + assert.Equal(t, "new-secret", m["clientSecret"], "clientSecret should be rotated") + assert.Equal(t, "client-id", m["clientID"], "clientID must survive (overlay should leave it alone)") + assert.Equal(t, "https://login.microsoftonline.com/tid/v2.0", m["issuer"]) + assert.Equal(t, "oid", m["userIDKey"], "userIDKey must survive update") + assert.Equal(t, map[string]any{"email": "preferred_username"}, m["claimMapping"], "claimMapping must survive update") +} + +func TestUpdateConnector_DoesNotAddUserIDKeyToExistingConnector(t *testing.T) { + ctx := context.Background() + p, cleanup := newTestProvider(t) + defer cleanup() + + // Seed a connector directly into storage without userIDKey + preFixConfig, err := json.Marshal(map[string]any{ + "issuer": "https://login.microsoftonline.com/tid/v2.0", + "clientID": "client-id", + "clientSecret": "old-secret", + "redirectURI": "https://example.com/oauth2/callback", + "scopes": []string{"openid", "profile", "email"}, + "claimMapping": map[string]string{"email": "preferred_username"}, + }) + require.NoError(t, err) + + require.NoError(t, p.storage.CreateConnector(ctx, storage.Connector{ + ID: "entra-prefix", + Type: "oidc", + Name: "Entra", + Config: preFixConfig, + })) + + // Rotate client secret via UpdateConnector. 
+ err = p.UpdateConnector(ctx, &ConnectorConfig{ + ID: "entra-prefix", + Type: "entra", + ClientSecret: "new-secret", + }) + require.NoError(t, err) + + conn, err := p.storage.GetConnector(ctx, "entra-prefix") + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(conn.Config, &m)) + + assert.Equal(t, "new-secret", m["clientSecret"]) + _, has := m["userIDKey"] + assert.False(t, has, "userIDKey must not be auto-added to a connector that did not have it before") +} + +func TestUpdateConnector_RejectsTypeChange(t *testing.T) { + ctx := context.Background() + p, cleanup := newTestProvider(t) + defer cleanup() + + _, err := p.CreateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Name: "Entra", + Type: "entra", + Issuer: "https://login.microsoftonline.com/tid/v2.0", + ClientID: "client-id", + ClientSecret: "secret", + RedirectURI: "https://example.com/oauth2/callback", + }) + require.NoError(t, err) + + // Attempt to switch the connector to okta. + err = p.UpdateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Type: "okta", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "connector type change not allowed") + + // stored connector type/config unchanged after the rejected update. 
+ conn, err := p.storage.GetConnector(ctx, "entra-test") + require.NoError(t, err) + assert.Equal(t, "oidc", conn.Type) + var m map[string]any + require.NoError(t, json.Unmarshal(conn.Config, &m)) + assert.Equal(t, "oid", m["userIDKey"]) +} + +func TestUpdateConnector_AllowsSameTypeUpdate(t *testing.T) { + ctx := context.Background() + p, cleanup := newTestProvider(t) + defer cleanup() + + _, err := p.CreateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Name: "Entra", + Type: "entra", + Issuer: "https://login.microsoftonline.com/old/v2.0", + ClientID: "client-id", + ClientSecret: "secret", + RedirectURI: "https://example.com/oauth2/callback", + }) + require.NoError(t, err) + + err = p.UpdateConnector(ctx, &ConnectorConfig{ + ID: "entra-test", + Type: "entra", + Issuer: "https://login.microsoftonline.com/new/v2.0", + }) + require.NoError(t, err) + + conn, err := p.storage.GetConnector(ctx, "entra-test") + require.NoError(t, err) + var m map[string]any + require.NoError(t, json.Unmarshal(conn.Config, &m)) + assert.Equal(t, "https://login.microsoftonline.com/new/v2.0", m["issuer"]) +} From cfb1b3fe31c37db79a67434ab620ddb0eca41faf Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 5 May 2026 18:40:42 +0200 Subject: [PATCH 02/27] [proxy] consolidate mapping update (#6072) --- management/internals/shared/grpc/proxy.go | 118 ++++++--- .../shared/grpc/proxy_snapshot_test.go | 174 ++++++++++++++ .../internals/shared/grpc/proxy_test.go | 3 + proxy/management_integration_test.go | 50 ++-- proxy/server.go | 45 +++- proxy/snapshot_reconcile_test.go | 227 ++++++++++++++++++ 6 files changed, 559 insertions(+), 58 deletions(-) create mode 100644 management/internals/shared/grpc/proxy_snapshot_test.go create mode 100644 proxy/snapshot_reconcile_test.go diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index d811a0f69..6763a3ba3 100644 --- 
a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -11,6 +11,8 @@ import ( "fmt" "net/http" "net/url" + "os" + "strconv" "strings" "sync" "time" @@ -82,11 +84,40 @@ type ProxyServiceServer struct { // Store for PKCE verifiers pkceVerifierStore *PKCEVerifierStore + // tokenTTL is the lifetime of one-time tokens generated for proxy + // authentication. Defaults to defaultProxyTokenTTL when zero. + tokenTTL time.Duration + + // snapshotBatchSize is the number of mappings per gRPC message during + // initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE. + snapshotBatchSize int + cancel context.CancelFunc } const pkceVerifierTTL = 10 * time.Minute +const defaultProxyTokenTTL = 5 * time.Minute + +const defaultSnapshotBatchSize = 500 + +func snapshotBatchSizeFromEnv() int { + if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + return n + } + } + return defaultSnapshotBatchSize +} + +// proxyTokenTTL returns the configured token TTL or the default when unset. 
+func (s *ProxyServiceServer) proxyTokenTTL() time.Duration { + if s.tokenTTL > 0 { + return s.tokenTTL + } + return defaultProxyTokenTTL +} + // proxyConnection represents a connected proxy type proxyConnection struct { proxyID string @@ -110,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, + snapshotBatchSize: snapshotBatchSizeFromEnv(), cancel: cancel, } go s.cleanupStaleProxies(ctx) @@ -192,11 +224,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest cancel: cancel, } - s.connectedProxies.Store(proxyID, conn) - if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { - log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) - } - // Register proxy in database with capabilities var caps *proxy.Capabilities if c := req.GetCapabilities(); c != nil { @@ -209,13 +236,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps) if err != nil { log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) - s.connectedProxies.CompareAndDelete(proxyID, conn) - if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { - log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr) - } + cancel() return status.Errorf(codes.Internal, "register proxy in database: %v", err) } + s.connectedProxies.Store(proxyID, conn) + if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { + log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) + } + + if err := s.sendSnapshot(ctx, conn); err != nil { + if s.connectedProxies.CompareAndDelete(proxyID, conn) { + if unregErr 
:= s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil { + log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr) + } + } + cancel() + if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil { + log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr) + } + return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err) + } + + errChan := make(chan error, 2) + go s.sender(conn, errChan) + log.WithFields(log.Fields{ "proxy_id": proxyID, "session_id": sessionID, @@ -241,13 +286,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest log.Infof("Proxy %s session %s disconnected", proxyID, sessionID) }() - if err := s.sendSnapshot(ctx, conn); err != nil { - return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err) - } - - errChan := make(chan error, 2) - go s.sender(conn, errChan) - go s.heartbeat(connCtx, proxyRecord) select { @@ -290,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec return err } + // Send mappings in batches to reduce per-message gRPC overhead while + // staying well within the default 4 MB message size limit. 
+ for i := 0; i < len(mappings); i += s.snapshotBatchSize { + end := i + s.snapshotBatchSize + if end > len(mappings) { + end = len(mappings) + } + if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ + Mapping: mappings[i:end], + InitialSyncComplete: end == len(mappings), + }); err != nil { + return fmt.Errorf("send snapshot batch: %w", err) + } + } + if len(mappings) == 0 { if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ InitialSyncComplete: true, }); err != nil { return fmt.Errorf("send snapshot completion: %w", err) } - return nil - } - - for i, m := range mappings { - if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ - Mapping: []*proto.ProxyMapping{m}, - InitialSyncComplete: i == len(mappings)-1, - }); err != nil { - return fmt.Errorf("send proxy mapping: %w", err) - } } return nil @@ -323,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * continue } - token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute) + token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL()) if err != nil { - log.WithFields(log.Fields{ - "service": service.Name, - "account": service.AccountID, - }).WithError(err).Error("failed to generate auth token for snapshot") - continue + return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err) } m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) @@ -409,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes conn := value.(*proxyConnection) resp := s.perProxyMessage(update, conn.proxyID) if resp == nil { + log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID) + conn.cancel() return true } select { case conn.sendChan <- resp: log.Debugf("Sent service update to proxy server %s", conn.proxyID) default: - log.Warnf("Failed to send service update to proxy server %s (channel full)", 
conn.proxyID) + log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID) + conn.cancel() } return true }) @@ -495,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd } msg := s.perProxyMessage(updateResponse, proxyID) if msg == nil { + log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr) + conn.cancel() continue } select { case conn.sendChan <- msg: log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) default: - log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) + log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr) + conn.cancel() } } } @@ -527,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo // perProxyMessage returns a copy of update with a fresh one-time token for // create/update operations. For delete operations the original mapping is // used unchanged because proxies do not need to authenticate for removal. -// Returns nil if token generation fails (the proxy should be skipped). +// Returns nil if token generation fails; the caller must disconnect the +// proxy so it can resync via a fresh snapshot on reconnect. 
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse { resp := make([]*proto.ProxyMapping, 0, len(update.Mapping)) for _, mapping := range update.Mapping { @@ -536,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo continue } - token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute) + token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL()) if err != nil { log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) return nil diff --git a/management/internals/shared/grpc/proxy_snapshot_test.go b/management/internals/shared/grpc/proxy_snapshot_test.go new file mode 100644 index 000000000..e0c7425c5 --- /dev/null +++ b/management/internals/shared/grpc/proxy_snapshot_test.go @@ -0,0 +1,174 @@ +package grpc + +import ( + "context" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// recordingStream captures all messages sent via Send so tests can inspect +// batching behaviour without a real gRPC transport. 
+type recordingStream struct { + grpc.ServerStream + messages []*proto.GetMappingUpdateResponse +} + +func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error { + s.messages = append(s.messages, m) + return nil +} + +func (s *recordingStream) Context() context.Context { return context.Background() } +func (s *recordingStream) SetHeader(metadata.MD) error { return nil } +func (s *recordingStream) SendHeader(metadata.MD) error { return nil } +func (s *recordingStream) SetTrailer(metadata.MD) {} +func (s *recordingStream) SendMsg(any) error { return nil } +func (s *recordingStream) RecvMsg(any) error { return nil } + +// makeServices creates n enabled services assigned to the given cluster. +func makeServices(n int, cluster string) []*rpservice.Service { + services := make([]*rpservice.Service, n) + for i := range n { + services[i] = &rpservice.Service{ + ID: fmt.Sprintf("svc-%d", i), + AccountID: "acct-1", + Name: fmt.Sprintf("svc-%d", i), + Domain: fmt.Sprintf("svc-%d.example.com", i), + ProxyCluster: cluster, + Enabled: true, + Targets: []*rpservice.Target{ + {TargetType: rpservice.TargetTypeHost, TargetId: "host-1"}, + }, + } + } + return services +} + +func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer { + t.Helper() + s := &ProxyServiceServer{ + tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)), + snapshotBatchSize: batchSize, + } + s.SetProxyController(newTestProxyController()) + return s +} + +func TestSendSnapshot_BatchesMappings(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 3 + const totalServices = 7 // 3 + 3 + 1 + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + stream: 
stream, + } + + err := s.sendSnapshot(context.Background(), conn) + require.NoError(t, err) + + // Expect ceil(7/3) = 3 messages + require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages") + + assert.Len(t, stream.messages[0].Mapping, 3) + assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete") + + assert.Len(t, stream.messages[1].Mapping, 3) + assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete") + + assert.Len(t, stream.messages[2].Mapping, 1) + assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete") + + // Verify all service IDs are present exactly once + seen := make(map[string]bool) + for _, msg := range stream.messages { + for _, m := range msg.Mapping { + assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id) + seen[m.Id] = true + } + } + assert.Len(t, seen, totalServices) +} + +func TestSendSnapshot_ExactBatchMultiple(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 3 + const totalServices = 6 // exactly 2 batches + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Len(t, stream.messages, 2) + + assert.Len(t, stream.messages[0].Mapping, 3) + assert.False(t, stream.messages[0].InitialSyncComplete) + + assert.Len(t, stream.messages[1].Mapping, 3) + assert.True(t, stream.messages[1].InitialSyncComplete) +} + +func TestSendSnapshot_SingleBatch(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 100 + const totalServices = 5 + + ctrl := gomock.NewController(t) + mgr := 
rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Len(t, stream.messages, 1, "all mappings should fit in one batch") + assert.Len(t, stream.messages[0].Mapping, totalServices) + assert.True(t, stream.messages[0].InitialSyncComplete) +} + +func TestSendSnapshot_EmptySnapshot(t *testing.T) { + const cluster = "cluster.example.com" + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil) + + s := newSnapshotTestServer(t, 500) + s.serviceManager = mgr + + stream := &recordingStream{} + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete") + assert.Empty(t, stream.messages[0].Mapping) + assert.True(t, stream.messages[0].InitialSyncComplete) +} diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index de4e96d93..5a7a457df 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -85,11 +85,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan // registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities. 
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse { ch := make(chan *proto.GetMappingUpdateResponse, 10) + ctx, cancel := context.WithCancel(context.Background()) conn := &proxyConnection{ proxyID: proxyID, address: clusterAddr, capabilities: caps, sendChan: ch, + ctx: ctx, + cancel: cancel, } s.connectedProxies.Store(proxyID, conn) diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index e9eae3210..99bbdad0c 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -364,14 +364,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) { }) require.NoError(t, err) - // Receive all mappings from the snapshot - server sends each mapping individually mappingsByID := make(map[string]*proto.ProxyMapping) - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) for _, m := range msg.GetMapping() { mappingsByID[m.GetId()] = m } + if msg.GetInitialSyncComplete() { + break + } } // Should receive 2 mappings total @@ -411,12 +413,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) { }) require.NoError(t, err) - // Receive all mappings - server sends each mapping individually mappings := make([]*proto.ProxyMapping, 0) - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) mappings = append(mappings, msg.GetMapping()...) 
+ if msg.GetInitialSyncComplete() { + break + } } // Should receive the 2 mappings matching the cluster @@ -440,13 +444,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) clusterAddress := "test.proxy.io" proxyID := "test-proxy-reconnect" - // Helper to receive all mappings from a stream - receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping { + receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping { var mappings []*proto.ProxyMapping - for i := 0; i < count; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) mappings = append(mappings, msg.GetMapping()...) + if msg.GetInitialSyncComplete() { + break + } } return mappings } @@ -460,7 +466,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) }) require.NoError(t, err) - firstMappings := receiveMappings(stream1, 2) + firstMappings := receiveMappings(stream1) cancel1() time.Sleep(100 * time.Millisecond) @@ -476,7 +482,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) }) require.NoError(t, err) - secondMappings := receiveMappings(stream2, 2) + secondMappings := receiveMappings(stream2) // Should receive the same mappings assert.Equal(t, len(firstMappings), len(secondMappings), @@ -542,12 +548,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T } } - // Helper to receive and apply all mappings receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) { - for i := 0; i < 2; i++ { + for { msg, err := stream.Recv() require.NoError(t, err) applyMappings(msg.GetMapping()) + if msg.GetInitialSyncComplete() { + break + } } } @@ -636,12 +644,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T) }) require.NoError(t, err) - // Receive all mappings - server sends each mapping individually count := 0 - for i := 0; i < 2; i++ { + for { msg, 
err := stream.Recv() require.NoError(t, err) count += len(msg.GetMapping()) + if msg.GetInitialSyncComplete() { + break + } } mu.Lock() @@ -681,9 +691,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) }) require.NoError(t, err) - for i := 0; i < 2; i++ { - _, err := stream1.Recv() + for { + msg, err := stream1.Recv() require.NoError(t, err) + if msg.GetInitialSyncComplete() { + break + } } require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID, @@ -699,9 +712,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) }) require.NoError(t, err) - for i := 0; i < 2; i++ { - _, err := stream2.Recv() + for { + msg, err := stream2.Recv() require.NoError(t, err) + if msg.GetInitialSyncComplete() { + break + } } cancel1() diff --git a/proxy/server.go b/proxy/server.go index fbd0d058e..6980e1df1 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr operation := func() error { s.Logger.Debug("connecting to management mapping stream") + initialSyncDone = false + if s.healthChecker != nil { s.healthChecker.SetManagementConnected(false) } @@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr return ctx.Err() } + var snapshotIDs map[types.ServiceID]struct{} + if !*initialSyncDone { + snapshotIDs = make(map[types.ServiceID]struct{}) + } + for { // Check for context completion to gracefully shutdown. 
select { @@ -1020,17 +1027,45 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr s.processMappings(ctx, msg.GetMapping()) s.Logger.Debug("Processing mapping update completed") - if !*initialSyncDone && msg.GetInitialSyncComplete() { - if s.healthChecker != nil { - s.healthChecker.SetInitialSyncComplete() + if !*initialSyncDone { + for _, m := range msg.GetMapping() { + snapshotIDs[types.ServiceID(m.GetId())] = struct{}{} + } + if msg.GetInitialSyncComplete() { + s.reconcileSnapshot(ctx, snapshotIDs) + snapshotIDs = nil + if s.healthChecker != nil { + s.healthChecker.SetInitialSyncComplete() + } + *initialSyncDone = true + s.Logger.Info("Initial mapping sync complete") } - *initialSyncDone = true - s.Logger.Info("Initial mapping sync complete") } } } } +// reconcileSnapshot removes local mappings that are absent from the snapshot. +// This ensures services deleted while the proxy was disconnected get cleaned up. +func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) { + s.portMu.RLock() + var stale []*proto.ProxyMapping + for svcID, mapping := range s.lastMappings { + if _, ok := snapshotIDs[svcID]; !ok { + stale = append(stale, mapping) + } + } + s.portMu.RUnlock() + + for _, mapping := range stale { + s.Logger.WithFields(log.Fields{ + "service_id": mapping.GetId(), + "domain": mapping.GetDomain(), + }).Info("Removing stale mapping absent from snapshot") + s.removeMapping(ctx, mapping) + } +} + func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) { for _, mapping := range mappings { s.Logger.WithFields(log.Fields{ diff --git a/proxy/snapshot_reconcile_test.go b/proxy/snapshot_reconcile_test.go new file mode 100644 index 000000000..042d8df77 --- /dev/null +++ b/proxy/snapshot_reconcile_test.go @@ -0,0 +1,227 @@ +package proxy + +import ( + "context" + "io" + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/health" + "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot +// so we can verify it without triggering removeMapping (which requires full +// server wiring). This keeps the test focused on the detection algorithm. +func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID { + var stale []types.ServiceID + for svcID := range lastMappings { + if _, ok := snapshotIDs[svcID]; !ok { + stale = append(stale, svcID) + } + } + return stale +} + +// TestStaleDetection_PartialOverlap verifies that only services absent from +// the snapshot are flagged as stale. +func TestStaleDetection_PartialOverlap(t *testing.T) { + local := map[types.ServiceID]*proto.ProxyMapping{ + "svc-1": {Id: "svc-1"}, + "svc-2": {Id: "svc-2"}, + "svc-stale-a": {Id: "svc-stale-a"}, + "svc-stale-b": {Id: "svc-stale-b"}, + } + snapshot := map[types.ServiceID]struct{}{ + "svc-1": {}, + "svc-2": {}, + "svc-3": {}, // new service, not in local + } + + stale := collectStaleIDs(local, snapshot) + assert.Len(t, stale, 2) + staleSet := make(map[types.ServiceID]struct{}) + for _, id := range stale { + staleSet[id] = struct{}{} + } + assert.Contains(t, staleSet, types.ServiceID("svc-stale-a")) + assert.Contains(t, staleSet, types.ServiceID("svc-stale-b")) +} + +// TestStaleDetection_AllStale verifies an empty snapshot flags everything. +func TestStaleDetection_AllStale(t *testing.T) { + local := map[types.ServiceID]*proto.ProxyMapping{ + "svc-1": {Id: "svc-1"}, + "svc-2": {Id: "svc-2"}, + } + stale := collectStaleIDs(local, map[types.ServiceID]struct{}{}) + assert.Len(t, stale, 2) +} + +// TestStaleDetection_NoneStale verifies full overlap produces no stale entries. 
+func TestStaleDetection_NoneStale(t *testing.T) { + local := map[types.ServiceID]*proto.ProxyMapping{ + "svc-1": {Id: "svc-1"}, + "svc-2": {Id: "svc-2"}, + } + snapshot := map[types.ServiceID]struct{}{ + "svc-1": {}, + "svc-2": {}, + } + stale := collectStaleIDs(local, snapshot) + assert.Empty(t, stale) +} + +// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty. +func TestStaleDetection_EmptyLocal(t *testing.T) { + stale := collectStaleIDs( + map[types.ServiceID]*proto.ProxyMapping{}, + map[types.ServiceID]struct{}{"svc-1": {}}, + ) + assert.Empty(t, stale) +} + +// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all +// local mappings are present in the snapshot (removeMapping is never called). +func TestReconcileSnapshot_NoStale(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"} + s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"} + + snapshotIDs := map[types.ServiceID]struct{}{ + "svc-1": {}, + "svc-2": {}, + } + // This should not panic — no stale entries means removeMapping is never called. + s.reconcileSnapshot(context.Background(), snapshotIDs) + + assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot") +} + +// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with +// no local mappings. 
+func TestReconcileSnapshot_EmptyLocal(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}}) + assert.Empty(t, s.lastMappings) +} + +// --- handleMappingStream tests for batched snapshot ID accumulation --- + +// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is +// marked done only after the final InitialSyncComplete message, even when +// the snapshot arrives in multiple batches. +func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) { + checker := health.NewChecker(nil, nil) + s := &Server{ + Logger: log.StandardLogger(), + healthChecker: checker, + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + stream := &mockMappingStream{ + messages: []*proto.GetMappingUpdateResponse{ + {}, // batch 1: no sync-complete + {}, // batch 2: no sync-complete + {InitialSyncComplete: true}, // batch 3: sync done + }, + } + + syncDone := false + err := s.handleMappingStream(context.Background(), stream, &syncDone) + assert.NoError(t, err) + assert.True(t, syncDone, "sync should be marked done after final batch") +} + +// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages +// arriving after InitialSyncComplete do not trigger a second reconciliation. +func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + // Simulate state left over from a previous sync. 
+ s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"} + s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"} + + stream := &mockMappingStream{ + messages: []*proto.GetMappingUpdateResponse{ + {}, // post-sync empty message — must not reconcile + }, + } + + syncDone := true // sync already completed in a previous stream + err := s.handleMappingStream(context.Background(), stream, &syncDone) + require.NoError(t, err) + + assert.Len(t, s.lastMappings, 2, + "post-sync messages must not trigger reconciliation — all entries should survive") +} + +// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the +// stream closes before sync completes, no reconciliation occurs. +func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"} + + stream := &mockMappingStream{} // no messages → immediate EOF + + syncDone := false + err := s.handleMappingStream(context.Background(), stream, &syncDone) + assert.NoError(t, err) + assert.False(t, syncDone, "sync should not be marked done on immediate EOF") + + _, hasStale := s.lastMappings["svc-stale"] + assert.True(t, hasStale, "stale mapping should remain when sync never completed") +} + +// mockErrRecvStream returns an error on the second Recv to verify +// handleMappingStream returns without completing sync. 
+type mockErrRecvStream struct { + mockMappingStream + calls int +} + +func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) { + m.calls++ + if m.calls == 1 { + return &proto.GetMappingUpdateResponse{}, nil + } + return nil, io.ErrUnexpectedEOF +} + +func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) { + s := &Server{ + Logger: log.StandardLogger(), + routerReady: closedChan(), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + } + + s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"} + + syncDone := false + err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone) + assert.Error(t, err) + assert.False(t, syncDone) + + _, hasStale := s.lastMappings["svc-stale"] + assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error") +} From b19b7464eac5c58bb6a6780a033398a27f3d772f Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 5 May 2026 18:48:51 +0200 Subject: [PATCH 03/27] [management] fix flaky invite token test (#6077) --- management/server/types/user_invite_test.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/management/server/types/user_invite_test.go b/management/server/types/user_invite_test.go index 09dae3800..c77fb89e2 100644 --- a/management/server/types/user_invite_test.go +++ b/management/server/types/user_invite_test.go @@ -144,8 +144,11 @@ func TestValidateInviteToken_ModifiedToken(t *testing.T) { _, plainToken, err := GenerateInviteToken() require.NoError(t, err) - // Modify one character in the secret part - modifiedToken := plainToken[:5] + "X" + plainToken[6:] + replacement := "X" + if plainToken[5] == 'X' { + replacement = "Y" + } + modifiedToken := plainToken[:5] + replacement + plainToken[6:] err = ValidateInviteToken(modifiedToken) require.Error(t, err) } From bfeb9b19ecbe03cf1d2b3f44258a84b3dfe02868 Mon Sep 17 
00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 6 May 2026 13:07:01 +0200 Subject: [PATCH 04/27] [management] remove permissions from geolocations api (#6091) --- .../handlers/policies/geolocations_handler.go | 34 ------------------- 1 file changed, 34 deletions(-) diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index a2d656a47..eea31ebc6 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -7,11 +7,8 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -45,11 +42,6 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa // getAllCountries retrieves a list of all countries func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) { - if err := l.authenticateUser(r); err != nil { - util.WriteError(r.Context(), err, w) - return - } - if l.geolocationManager == nil { // TODO: update error message to include geo db self hosted doc link when ready util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) @@ -71,11 +63,6 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req // getCitiesByCountry retrieves a list of cities based on the given 
country code func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) { - if err := l.authenticateUser(r); err != nil { - util.WriteError(r.Context(), err, w) - return - } - vars := mux.Vars(r) countryCode := vars["country"] if !countryCodeRegex.MatchString(countryCode) { @@ -102,27 +89,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http. util.WriteJSONObject(r.Context(), w, cities) } -func (l *geolocationsHandler) authenticateUser(r *http.Request) error { - ctx := r.Context() - - userAuth, err := nbcontext.GetUserAuthFromContext(ctx) - if err != nil { - return err - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read) - if err != nil { - return status.NewPermissionValidationError(err) - } - - if !allowed { - return status.NewPermissionDeniedError() - } - return nil -} - func toCountryResponse(country geolocation.Country) api.Country { return api.Country{ CountryName: country.CountryName, From 71a400f90fc522389739e70acdc6f800ed1e76c6 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 6 May 2026 20:23:43 +0900 Subject: [PATCH 05/27] [client] Include MTU and SSH auth/JWT cache config in debug bundle (#6071) --- client/internal/debug/debug.go | 7 ++ client/internal/debug/debug_test.go | 139 +++++++++++++++++++++++++++- 2 files changed, 141 insertions(+), 5 deletions(-) diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 90560d028..0ad1401e7 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -607,6 +607,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) if g.internalConfig.EnableSSHRemotePortForwarding != nil { configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", 
*g.internalConfig.EnableSSHRemotePortForwarding)) } + if g.internalConfig.DisableSSHAuth != nil { + configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth)) + } + if g.internalConfig.SSHJWTCacheTTL != nil { + configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL)) + } configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes)) @@ -633,6 +639,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) } configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled)) + configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU)) } func (g *BundleGenerator) addProf() (err error) { diff --git a/client/internal/debug/debug_test.go b/client/internal/debug/debug_test.go index 6b5bb911c..05d51e593 100644 --- a/client/internal/debug/debug_test.go +++ b/client/internal/debug/debug_test.go @@ -5,16 +5,21 @@ import ( "bytes" "encoding/json" "net" + "net/url" "os" "path/filepath" + "reflect" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/configs" + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" ) @@ -471,8 +476,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) { anonymize: false, input: map[string]any{ jsonKeyServiceEnv: map[string]any{ - "HOME": "/root", - "PATH": "/usr/bin", + "HOME": "/root", + "PATH": "/usr/bin", "NB_LOG_LEVEL": "debug", }, }, @@ -489,9 +494,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) { anonymize: false, input: map[string]any{ jsonKeyServiceEnv: map[string]any{ - 
"NB_SETUP_KEY": "abc123", - "NB_API_TOKEN": "tok_xyz", - "NB_LOG_LEVEL": "info", + "NB_SETUP_KEY": "abc123", + "NB_API_TOKEN": "tok_xyz", + "NB_LOG_LEVEL": "info", }, }, check: func(t *testing.T, params map[string]any) { @@ -766,3 +771,127 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) assert.Contains(t, anonNftables, "chain input {") assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") } + +// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in +// profilemanager.Config is either rendered in the debug bundle or explicitly +// excluded. When a new field is added to Config, this test fails until the +// developer either dumps it in addConfig/addCommonConfigFields or adds it to +// the excluded set with a justification. +func TestAddConfig_AllFieldsCovered(t *testing.T) { + excluded := map[string]string{ + "PrivateKey": "sensitive: WireGuard private key", + "PreSharedKey": "sensitive: WireGuard pre-shared key", + "SSHKey": "sensitive: SSH private key", + "ClientCertKeyPair": "non-config: parsed cert pair, not serialized", + } + + mURL, _ := url.Parse("https://api.example.com:443") + aURL, _ := url.Parse("https://admin.example.com:443") + bTrue := true + iVal := 42 + cfg := &profilemanager.Config{ + PrivateKey: "priv", + PreSharedKey: "psk", + ManagementURL: mURL, + AdminURL: aURL, + WgIface: "wt0", + WgPort: 51820, + NetworkMonitor: &bTrue, + IFaceBlackList: []string{"eth0"}, + DisableIPv6Discovery: true, + RosenpassEnabled: true, + RosenpassPermissive: true, + ServerSSHAllowed: &bTrue, + EnableSSHRoot: &bTrue, + EnableSSHSFTP: &bTrue, + EnableSSHLocalPortForwarding: &bTrue, + EnableSSHRemotePortForwarding: &bTrue, + DisableSSHAuth: &bTrue, + SSHJWTCacheTTL: &iVal, + DisableClientRoutes: true, + DisableServerRoutes: true, + DisableDNS: true, + DisableFirewall: true, + BlockLANAccess: true, + BlockInbound: true, + DisableNotifications: &bTrue, + DNSLabels: domain.List{}, + SSHKey: "sshkey", + 
NATExternalIPs: []string{"1.2.3.4"}, + CustomDNSAddress: "1.1.1.1:53", + DisableAutoConnect: true, + DNSRouteInterval: 5 * time.Second, + ClientCertPath: "/tmp/cert", + ClientCertKeyPath: "/tmp/key", + LazyConnectionEnabled: true, + MTU: 1280, + } + + for _, anonymize := range []bool{false, true} { + t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) { + g := &BundleGenerator{ + anonymizer: newAnonymizerForTest(), + internalConfig: cfg, + anonymize: anonymize, + } + + var sb strings.Builder + g.addCommonConfigFields(&sb) + rendered := sb.String() + renderAddConfigSpecific(g) + + val := reflect.ValueOf(cfg).Elem() + typ := val.Type() + var missing []string + for i := 0; i < typ.NumField(); i++ { + name := typ.Field(i).Name + if _, ok := excluded[name]; ok { + continue + } + if !strings.Contains(rendered, name+":") { + missing = append(missing, name) + } + } + if len(missing) > 0 { + t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+ + "Either render the field in addCommonConfigFields/addConfig, "+ + "or add it to the excluded map with a justification.", missing) + } + }) + } +} + +// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize +// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress). +// addCommonConfigFields covers the rest. Keeping this in the test mirrors the +// production shape without needing to write an actual zip. 
+func renderAddConfigSpecific(g *BundleGenerator) string { + var sb strings.Builder + if g.anonymize { + if g.internalConfig.ManagementURL != nil { + sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n") + } + if g.internalConfig.AdminURL != nil { + sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n") + } + sb.WriteString("NATExternalIPs: x\n") + if g.internalConfig.CustomDNSAddress != "" { + sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n") + } + } else { + if g.internalConfig.ManagementURL != nil { + sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n") + } + if g.internalConfig.AdminURL != nil { + sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n") + } + sb.WriteString("NATExternalIPs: x\n") + if g.internalConfig.CustomDNSAddress != "" { + sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n") + } + } + return sb.String() +} + +func newAnonymizerForTest() *anonymize.Anonymizer { + return anonymize.NewAnonymizer(anonymize.DefaultAddresses()) +} From f532976e05879f5a3cb56e016449e1a404457ac5 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 6 May 2026 20:42:47 +0900 Subject: [PATCH 06/27] [client] Add public key to debug bundle config.txt (#6092) --- client/internal/debug/debug.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 0ad1401e7..0a12a5326 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -21,6 +21,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/protobuf/encoding/protojson" "github.com/netbirdio/netbird/client/anonymize" @@ -583,6 +584,9 @@ func isSensitiveEnvVar(key string) bool { func (g 
*BundleGenerator) addCommonConfigFields(configContent *strings.Builder) { configContent.WriteString("NetBird Client Configuration:\n\n") + if key, err := wgtypes.ParseKey(g.internalConfig.PrivateKey); err == nil { + configContent.WriteString(fmt.Sprintf("PublicKey: %s\n", key.PublicKey().String())) + } configContent.WriteString(fmt.Sprintf("WgIface: %s\n", g.internalConfig.WgIface)) configContent.WriteString(fmt.Sprintf("WgPort: %d\n", g.internalConfig.WgPort)) if g.internalConfig.NetworkMonitor != nil { From f23aaa9ae7097c3f47e50efe0f418f40a90fd4d7 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 6 May 2026 17:14:11 +0200 Subject: [PATCH 07/27] [client] iOS: structured ResolvedIPs collection for domain routes (#6090) * [client] iOS: structured ResolvedIPs collection for domain routes Replace comma-joined ResolvedIPs string with a gomobile-friendly ResolvedIPs collection (Add/Get/Size), mirroring the Android bridge in client/android/network_domains.go. This allows the iOS app to match domain-route resolved IPs against connected peer routes without parsing CSV strings, fixing the route status indicator for dynamic (DNS) routes. * [client] iOS: align dynamic route exposure with Android bridge For dynamic (DNS) routes the Swift side previously received "invalid Prefix" as the Network value, forcing UI code to special-case that sentinel. The Android bridge uses Domains.SafeString() instead so peer.routes entries (which also derive from Domains.SafeString()) match directly. Mirror that here. Also fix the resolved IP lookup: resolvedDomains is keyed by the resolved domain (e.g. api.ipify.org), not the configured pattern (e.g. *.ipify.org). Group entries by ParentDomain like the daemon does in client/server/network.go, so wildcard route patterns get their resolved IPs populated. 
--- client/ios/NetBirdSDK/client.go | 43 ++++++++++++++++++++++----------- client/ios/NetBirdSDK/routes.go | 29 +++++++++++++++++++++- 2 files changed, 57 insertions(+), 15 deletions(-) diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 043673904..a616f9533 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -413,25 +413,40 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails { var routeSelection []RoutesSelectionInfo for _, r := range routes { - domainList := make([]DomainInfo, 0) + // resolvedDomains is keyed by the resolved domain (e.g. api.ipify.org), + // not the configured pattern (e.g. *.ipify.org). Group entries whose + // ParentDomain belongs to this route, mirroring the daemon logic in + // client/server/network.go. + domainList := make([]DomainInfo, 0, len(r.Domains)) + domainIndex := make(map[domain.Domain]int, len(r.Domains)) for _, d := range r.Domains { - domainResp := DomainInfo{ - Domain: d.SafeString(), - } - - if info, exists := resolvedDomains[d]; exists { - var ipStrings []string - for _, prefix := range info.Prefixes { - ipStrings = append(ipStrings, prefix.Addr().String()) - } - domainResp.ResolvedIPs = strings.Join(ipStrings, ", ") - } - domainList = append(domainList, domainResp) + domainIndex[d] = len(domainList) + domainList = append(domainList, DomainInfo{Domain: d.SafeString()}) } + + for _, info := range resolvedDomains { + idx, ok := domainIndex[info.ParentDomain] + if !ok { + continue + } + for _, prefix := range info.Prefixes { + domainList[idx].AddResolvedIP(prefix.Addr().String()) + } + } + domainDetails := DomainDetails{items: domainList} + + // For dynamic (DNS) routes, expose the joined domain pattern as the + // Network value so it matches the peer.routes entries on the Swift + // side 
(mirroring the Android bridge in client/android/client.go). + netStr := r.Network.String() + if len(r.Domains) > 0 { + netStr = r.Domains.SafeString() + } + routeSelection = append(routeSelection, RoutesSelectionInfo{ ID: r.NetID, - Network: r.Network.String(), + Network: netStr, Domains: &domainDetails, Selected: r.Selected, }) diff --git a/client/ios/NetBirdSDK/routes.go b/client/ios/NetBirdSDK/routes.go index 7b84d6e1c..025313bfa 100644 --- a/client/ios/NetBirdSDK/routes.go +++ b/client/ios/NetBirdSDK/routes.go @@ -34,7 +34,34 @@ type DomainDetails struct { type DomainInfo struct { Domain string - ResolvedIPs string + resolvedIPs ResolvedIPs +} + +func (d *DomainInfo) AddResolvedIP(ipAddress string) { + d.resolvedIPs.Add(ipAddress) +} + +func (d *DomainInfo) GetResolvedIPs() *ResolvedIPs { + return &d.resolvedIPs +} + +type ResolvedIPs struct { + items []string +} + +func (r *ResolvedIPs) Add(ipAddress string) { + r.items = append(r.items, ipAddress) +} + +func (r *ResolvedIPs) Get(i int) string { + if i < 0 || i >= len(r.items) { + return "" + } + return r.items[i] +} + +func (r *ResolvedIPs) Size() int { + return len(r.items) } // Add new PeerInfo to the collection From 205ebcfda28d08ba692d21fb3dbbc0788310c736 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 7 May 2026 18:33:37 +0900 Subject: [PATCH 08/27] [management, client] Add IPv6 overlay support (#5631) --- .gitignore | 1 + client/android/client.go | 96 +- client/android/peer_notifier.go | 1 + client/android/preferences.go | 18 + client/android/route_command.go | 7 +- client/anonymize/anonymize.go | 17 +- client/anonymize/anonymize_test.go | 24 +- client/cmd/ssh.go | 6 +- client/cmd/ssh_test.go | 4 +- client/cmd/status.go | 14 +- client/cmd/system.go | 5 + client/cmd/up.go | 12 + client/embed/embed.go | 3 + client/firewall/iptables/acl_linux.go | 33 +- client/firewall/iptables/manager_linux.go | 275 ++- client/firewall/iptables/router_linux.go | 104 +- 
client/firewall/iptables/rule.go | 1 + client/firewall/iptables/state_linux.go | 30 + client/firewall/manager/firewall.go | 15 +- client/firewall/manager/routerpair.go | 18 + client/firewall/nftables/acl_linux.go | 44 +- client/firewall/nftables/addr_family_linux.go | 81 + ...al_chain_monitor_integration_linux_test.go | 76 + .../nftables/external_chain_monitor_linux.go | 199 +++ .../external_chain_monitor_linux_test.go | 137 ++ client/firewall/nftables/manager_linux.go | 394 ++++- .../firewall/nftables/manager_linux_test.go | 128 ++ client/firewall/nftables/router_linux.go | 325 ++-- client/firewall/nftables/router_linux_test.go | 189 +- .../uspfilter/allow_netbird_windows.go | 53 +- client/firewall/uspfilter/conntrack/common.go | 7 +- .../uspfilter/conntrack/common_test.go | 48 + client/firewall/uspfilter/conntrack/icmp.go | 92 +- .../firewall/uspfilter/conntrack/icmp_test.go | 36 + client/firewall/uspfilter/filter.go | 325 +++- .../firewall/uspfilter/filter_bench_test.go | 9 +- .../firewall/uspfilter/filter_filter_test.go | 345 +++- .../uspfilter/filter_routeacl_test.go | 18 +- client/firewall/uspfilter/filter_test.go | 75 +- .../firewall/uspfilter/forwarder/endpoint.go | 22 +- .../firewall/uspfilter/forwarder/forwarder.go | 293 +++- .../uspfilter/forwarder/forwarder_test.go | 162 ++ client/firewall/uspfilter/forwarder/icmp.go | 218 ++- client/firewall/uspfilter/forwarder/tcp.go | 18 +- client/firewall/uspfilter/forwarder/udp.go | 17 +- client/firewall/uspfilter/hooks_filter.go | 1 - client/firewall/uspfilter/localip.go | 135 +- .../firewall/uspfilter/localip_bench_test.go | 72 + client/firewall/uspfilter/localip_test.go | 124 +- client/firewall/uspfilter/nat.go | 185 +- client/firewall/uspfilter/nat_bench_test.go | 22 +- client/firewall/uspfilter/nat_test.go | 11 +- client/firewall/uspfilter/tracer.go | 106 +- client/iface/configurer/usp.go | 2 +- client/iface/device/adapter.go | 2 +- client/iface/device/device_android.go | 2 +- 
client/iface/device/device_darwin.go | 33 +- client/iface/device/device_ios.go | 7 +- client/iface/device/device_kernel_unix.go | 2 +- client/iface/device/device_netstack.go | 9 +- client/iface/device/device_usp_unix.go | 32 +- client/iface/device/device_windows.go | 33 +- client/iface/device/kernel_module.go | 8 - client/iface/device/kernel_module_freebsd.go | 18 - client/iface/device/kernel_module_nonlinux.go | 13 + client/iface/device/wg_link_freebsd.go | 27 +- client/iface/device/wg_link_linux.go | 41 +- client/iface/iface.go | 11 +- .../{iface_new_windows.go => iface_new.go} | 19 +- client/iface/iface_new_android.go | 12 +- client/iface/iface_new_darwin.go | 35 - client/iface/iface_new_freebsd.go | 41 - client/iface/iface_new_ios.go | 10 +- client/iface/iface_new_js.go | 8 +- client/iface/iface_new_linux.go | 46 +- client/iface/iface_test.go | 21 +- client/iface/netstack/tun.go | 8 +- client/iface/wgaddr/address.go | 64 +- client/iface/wgaddr/address_test_helpers.go | 10 + client/iface/wgproxy/bind/proxy.go | 25 +- client/internal/acl/manager.go | 57 +- client/internal/auth/auth.go | 1 + client/internal/connect.go | 18 +- client/internal/connect_android_default.go | 4 + client/internal/debug/debug.go | 31 +- client/internal/debug/debug_test.go | 74 +- client/internal/dns.go | 79 +- client/internal/dns/host_darwin.go | 1 + client/internal/dns/local/local.go | 13 +- client/internal/dns/network_manager_unix.go | 25 +- client/internal/dns/server.go | 2 +- client/internal/dns/server_test.go | 6 +- client/internal/dns/service.go | 4 +- client/internal/dns/service_listener.go | 11 +- client/internal/dns/systemd_linux.go | 6 +- client/internal/dns/upstream.go | 7 + client/internal/dns/upstream_android.go | 2 +- client/internal/dns/upstream_general.go | 2 +- client/internal/dns/upstream_ios.go | 44 +- client/internal/dns_test.go | 138 ++ client/internal/dnsfwd/manager.go | 1 + client/internal/ebpf/ebpf/dns_fwd_linux.go | 15 +- client/internal/ebpf/manager/manager.go | 
4 +- client/internal/engine.go | 142 +- client/internal/engine_ssh.go | 40 +- client/internal/engine_test.go | 242 ++- client/internal/iface_common.go | 2 +- .../lazyconn/activity/listener_bind.go | 23 +- client/internal/listener/network_change.go | 1 + .../internal/netflow/conntrack/conntrack.go | 23 +- client/internal/netflow/logger/logger.go | 12 +- client/internal/netflow/logger/logger_test.go | 2 +- client/internal/netflow/manager.go | 7 +- client/internal/netflow/types/types.go | 3 + client/internal/peer/status.go | 10 +- client/internal/peer/status_test.go | 11 +- client/internal/profilemanager/config.go | 12 +- client/internal/relay/relay.go | 3 +- client/internal/rosenpass/manager.go | 9 +- client/internal/rosenpass/manager_test.go | 14 + client/internal/routemanager/client/client.go | 5 +- .../routemanager/client/client_bench_test.go | 2 +- .../routemanager/dnsinterceptor/handler.go | 2 +- client/internal/routemanager/dynamic/route.go | 4 +- .../routemanager/dynamic/route_ios.go | 46 +- client/internal/routemanager/fakeip/fakeip.go | 144 +- .../routemanager/fakeip/fakeip_test.go | 169 +- .../routemanager/ipfwdstate/ipfwdstate.go | 6 +- client/internal/routemanager/manager.go | 18 +- client/internal/routemanager/manager_test.go | 3 +- .../routemanager/notifier/notifier_android.go | 25 +- .../routemanager/notifier/notifier_ios.go | 13 +- .../routemanager/notifier/notifier_other.go | 2 +- client/internal/routemanager/server/server.go | 17 +- .../routemanager/systemops/systemops.go | 10 +- .../systemops/systemops_generic.go | 70 +- .../systemops/systemops_generic_test.go | 3 +- .../routemanager/systemops/systemops_linux.go | 23 +- client/ios/NetBirdSDK/client.go | 73 +- client/ios/NetBirdSDK/peer_notifier.go | 12 + client/ios/NetBirdSDK/preferences.go | 18 + client/proto/daemon.pb.go | 67 +- client/proto/daemon.proto | 6 + client/server/network.go | 41 +- client/server/server.go | 3 + client/server/setconfig_test.go | 5 + client/server/trace.go | 74 +- 
client/ssh/config/manager.go | 11 +- client/ssh/config/manager_test.go | 13 +- client/ssh/proxy/proxy.go | 2 +- client/ssh/server/port_forwarding.go | 29 +- client/ssh/server/server.go | 68 +- client/status/status.go | 25 + client/status/status_test.go | 11 + client/system/info.go | 4 +- client/ui/client_ui.go | 32 +- client/ui/event/event.go | 10 +- client/ui/network.go | 8 +- client/wasm/cmd/main.go | 122 +- client/wasm/internal/rdp/rdcleanpath.go | 2 +- client/wasm/internal/ssh/client.go | 18 +- combined/cmd/config.go | 2 +- .../service/manager/l4_port_test.go | 5 +- .../reverseproxy/service/manager/manager.go | 5 +- .../service/manager/manager_test.go | 11 +- .../internals/shared/grpc/conversion.go | 91 +- management/internals/shared/grpc/server.go | 12 +- management/server/account.go | 479 ++++- management/server/account/manager.go | 1 + management/server/account/manager_mock.go | 12 + management/server/account_test.go | 217 ++- management/server/activity/codes.go | 7 + management/server/group.go | 71 +- management/server/group_ipv6_test.go | 125 ++ management/server/group_test.go | 9 +- .../handlers/accounts/accounts_handler.go | 82 +- .../accounts/accounts_handler_test.go | 30 + .../http/handlers/dns/nameservers_handler.go | 21 +- .../handlers/dns/nameservers_handler_test.go | 34 + .../handlers/groups/groups_handler_test.go | 6 +- .../http/handlers/peers/peers_handler.go | 38 + .../http/handlers/peers/peers_handler_test.go | 17 +- .../http/testing/testing_tools/tools.go | 4 +- management/server/mock_server/account_mock.go | 8 + management/server/peer.go | 76 +- management/server/peer/peer.go | 56 +- management/server/peer/peer_test.go | 23 + management/server/peer_test.go | 115 +- management/server/policy_test.go | 46 +- management/server/route_test.go | 125 +- management/server/settings/manager.go | 29 + management/server/settings/manager_mock.go | 17 + management/server/store/sql_store.go | 89 +- .../store/sql_store_get_account_test.go | 11 +- 
management/server/store/sql_store_test.go | 84 +- .../server/store/sqlstore_bench_test.go | 3 +- management/server/store/store.go | 3 +- management/server/store/store_mock.go | 19 +- management/server/types/account.go | 119 +- management/server/types/account_components.go | 17 +- management/server/types/account_test.go | 59 +- management/server/types/firewall_rule.go | 57 +- management/server/types/firewall_rule_test.go | 197 +++ management/server/types/ipv6_endtoend_test.go | 156 ++ management/server/types/ipv6_groups_test.go | 234 +++ management/server/types/network.go | 151 +- management/server/types/network_test.go | 151 +- .../server/types/networkmap_components.go | 166 +- .../networkmap_components_correctness_test.go | 4 +- .../types/networkmap_components_test.go | 8 +- management/server/types/settings.go | 10 + management/server/user.go | 6 + proxy/cmd/proxy/cmd/debug.go | 21 +- proxy/internal/debug/client.go | 22 +- proxy/internal/debug/handler.go | 18 +- relay/test/benchmark_test.go | 2 +- relay/testec2/turn_allocator.go | 2 +- route/route.go | 61 + route/route_test.go | 108 ++ shared/management/client/grpc.go | 14 + shared/management/http/api/openapi.yml | 26 +- shared/management/http/api/types.gen.go | 20 +- shared/management/proto/management.pb.go | 1545 +++++++++-------- shared/management/proto/management.proto | 27 +- shared/netiputil/compact.go | 78 + shared/netiputil/compact_test.go | 175 ++ shared/relay/client/dialer/quic/quic.go | 2 +- upload-server/server/s3_test.go | 3 +- util/capture/text.go | 22 +- 229 files changed, 10155 insertions(+), 2816 deletions(-) create mode 100644 client/firewall/nftables/addr_family_linux.go create mode 100644 client/firewall/nftables/external_chain_monitor_integration_linux_test.go create mode 100644 client/firewall/nftables/external_chain_monitor_linux.go create mode 100644 client/firewall/nftables/external_chain_monitor_linux_test.go create mode 100644 client/firewall/uspfilter/forwarder/forwarder_test.go create 
mode 100644 client/firewall/uspfilter/localip_bench_test.go delete mode 100644 client/iface/device/kernel_module.go delete mode 100644 client/iface/device/kernel_module_freebsd.go create mode 100644 client/iface/device/kernel_module_nonlinux.go rename client/iface/{iface_new_windows.go => iface_new.go} (50%) delete mode 100644 client/iface/iface_new_darwin.go delete mode 100644 client/iface/iface_new_freebsd.go create mode 100644 client/iface/wgaddr/address_test_helpers.go create mode 100644 client/internal/dns_test.go create mode 100644 client/internal/rosenpass/manager_test.go create mode 100644 management/server/group_ipv6_test.go create mode 100644 management/server/types/firewall_rule_test.go create mode 100644 management/server/types/ipv6_endtoend_test.go create mode 100644 management/server/types/ipv6_groups_test.go create mode 100644 route/route_test.go create mode 100644 shared/netiputil/compact.go create mode 100644 shared/netiputil/compact_test.go diff --git a/.gitignore b/.gitignore index a0f128933..783fe77f3 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ infrastructure_files/setup-*.env vendor/ /netbird client/netbird-electron/ +management/server/types/testdata/ diff --git a/client/android/client.go b/client/android/client.go index 37e17a363..99ccdf393 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -301,10 +301,11 @@ func (c *Client) PeersList() *PeerInfoArray { peerInfos := make([]PeerInfo, len(fullStatus.Peers)) for n, p := range fullStatus.Peers { pi := PeerInfo{ - p.IP, - p.FQDN, - int(p.ConnStatus), - PeerRoutes{routes: maps.Keys(p.GetRoutes())}, + IP: p.IP, + IPv6: p.IPv6, + FQDN: p.FQDN, + ConnStatus: int(p.ConnStatus), + Routes: PeerRoutes{routes: maps.Keys(p.GetRoutes())}, } peerInfos[n] = pi } @@ -336,43 +337,84 @@ func (c *Client) Networks() *NetworkArray { return nil } + routesMap := routeManager.GetClientRoutesWithNetID() + v6Merged := route.V6ExitMergeSet(routesMap) + resolvedDomains := 
c.recorder.GetResolvedDomainsStates() + networkArray := &NetworkArray{ items: make([]Network, 0), } - resolvedDomains := c.recorder.GetResolvedDomainsStates() - - for id, routes := range routeManager.GetClientRoutesWithNetID() { + for id, routes := range routesMap { if len(routes) == 0 { continue } - - r := routes[0] - domains := c.getNetworkDomainsFromRoute(r, resolvedDomains) - netStr := r.Network.String() - - if r.IsDynamic() { - netStr = r.Domains.SafeString() - } - - routePeer, err := c.recorder.GetPeer(routes[0].Peer) - if err != nil { - log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err) + if _, skip := v6Merged[id]; skip { continue } - network := Network{ - Name: string(id), - Network: netStr, - Peer: routePeer.FQDN, - Status: routePeer.ConnStatus.String(), - IsSelected: routeSelector.IsSelected(id), - Domains: domains, + + network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged) + if network == nil { + continue } - networkArray.Add(network) + networkArray.Add(*network) } return networkArray } +func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network { + r := routes[0] + netStr := r.Network.String() + if r.IsDynamic() { + netStr = r.Domains.SafeString() + } + + routePeer, err := c.findBestRoutePeer(routes) + if err != nil { + log.Errorf("could not get peer info for route %s: %v", id, err) + return nil + } + + network := &Network{ + Name: string(id), + Network: netStr, + Peer: routePeer.FQDN, + Status: routePeer.ConnStatus.String(), + IsSelected: selected, + Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains), + } + + if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) { + network.Network = "0.0.0.0/0, ::/0" + } + + return network +} + +// findBestRoutePeer returns the peer actively routing traffic for the given +// HA route group. 
Falls back to the first connected peer, then the first peer. +func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) { + netStr := routes[0].Network.String() + + fullStatus := c.recorder.GetFullStatus() + for _, p := range fullStatus.Peers { + if _, ok := p.GetRoutes()[netStr]; ok { + return p, nil + } + } + + for _, r := range routes { + p, err := c.recorder.GetPeer(r.Peer) + if err != nil { + continue + } + if p.ConnStatus == peer.StatusConnected { + return p, nil + } + } + return c.recorder.GetPeer(routes[0].Peer) +} + // OnUpdatedHostDNS update the DNS servers addresses for root zones func (c *Client) OnUpdatedHostDNS(list *DNSList) error { dnsServer, err := dns.GetServerDns() diff --git a/client/android/peer_notifier.go b/client/android/peer_notifier.go index 4ec22f3ab..c2595e574 100644 --- a/client/android/peer_notifier.go +++ b/client/android/peer_notifier.go @@ -14,6 +14,7 @@ const ( // PeerInfo describe information about the peers. It designed for the UI usage type PeerInfo struct { IP string + IPv6 string FQDN string ConnStatus int Routes PeerRoutes diff --git a/client/android/preferences.go b/client/android/preferences.go index c3c8eb3fb..066477293 100644 --- a/client/android/preferences.go +++ b/client/android/preferences.go @@ -307,6 +307,24 @@ func (p *Preferences) SetBlockInbound(block bool) { p.configInput.BlockInbound = &block } +// GetDisableIPv6 reads disable IPv6 setting from config file +func (p *Preferences) GetDisableIPv6() (bool, error) { + if p.configInput.DisableIPv6 != nil { + return *p.configInput.DisableIPv6, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.DisableIPv6, err +} + +// SetDisableIPv6 stores the given value and waits for commit +func (p *Preferences) SetDisableIPv6(disable bool) { + p.configInput.DisableIPv6 = &disable +} + // Commit writes out the changes to the config file func (p *Preferences) Commit() error { _, err 
:= profilemanager.UpdateOrCreateConfig(p.configInput) diff --git a/client/android/route_command.go b/client/android/route_command.go index b47d5ca6c..5e7357335 100644 --- a/client/android/route_command.go +++ b/client/android/route_command.go @@ -18,9 +18,12 @@ func executeRouteToggle(id string, manager routemanager.Manager, netID := route.NetID(id) routes := []route.NetID{netID} - log.Debugf("%s with id: %s", operationName, id) + routesMap := manager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) - if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil { + log.Debugf("%s with ids: %v", operationName, routes) + + if err := routeOperation(routes, maps.Keys(routesMap)); err != nil { log.Debugf("error when %s: %s", operationName, err) return fmt.Errorf("error %s: %w", operationName, err) } diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 89e653300..c140cef89 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -9,6 +9,7 @@ import ( "net/url" "regexp" "slices" + "strconv" "strings" ) @@ -26,8 +27,9 @@ type Anonymizer struct { } func DefaultAddresses() (netip.Addr, netip.Addr) { - // 198.51.100.0, 100:: - return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01}) + // 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48) + // The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android. 
+ return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::") } func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer { @@ -48,7 +50,7 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr { ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() || - ip.IsPrivate() || + (ip.Is4() && ip.IsPrivate()) || ip.IsUnspecified() || ip.IsMulticast() || isWellKnown(ip) || @@ -96,6 +98,11 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool { } func (a *Anonymizer) AnonymizeIPString(ip string) string { + // Handle CIDR notation (e.g. "2001:db8::/32") + if prefix, err := netip.ParsePrefix(ip); err == nil { + return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits()) + } + addr, err := netip.ParseAddr(ip) if err != nil { return ip @@ -150,7 +157,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string { if u.Opaque != "" { host, port, err := net.SplitHostPort(u.Opaque) if err == nil { - anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port) + anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port) } else { anonymizedHost = a.AnonymizeDomain(u.Opaque) } @@ -158,7 +165,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string { } else if u.Host != "" { host, port, err := net.SplitHostPort(u.Host) if err == nil { - anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port) + anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port) } else { anonymizedHost = a.AnonymizeDomain(u.Host) } diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go index ff2e48869..852315fa1 100644 --- a/client/anonymize/anonymize_test.go +++ b/client/anonymize/anonymize_test.go @@ -13,7 +13,7 @@ import ( func TestAnonymizeIP(t *testing.T) { startIPv4 := netip.MustParseAddr("198.51.100.0") - startIPv6 := netip.MustParseAddr("100::") + startIPv6 := netip.MustParseAddr("2001:db8:ffff::") anonymizer := 
anonymize.NewAnonymizer(startIPv4, startIPv6) tests := []struct { @@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) { {"Second Public IPv4", "4.3.2.1", "198.51.100.1"}, {"Repeated IPv4", "1.2.3.4", "198.51.100.0"}, {"Private IPv4", "192.168.1.1", "192.168.1.1"}, - {"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"}, - {"Second Public IPv6", "a::b", "100::1"}, - {"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"}, + {"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"}, + {"Second Public IPv6", "a::b", "2001:db8:ffff::1"}, + {"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"}, {"Private IPv6", "fe80::1", "fe80::1"}, {"In Range IPv4", "198.51.100.2", "198.51.100.2"}, } @@ -274,17 +274,27 @@ func TestAnonymizeString_IPAddresses(t *testing.T) { { name: "IPv6 Address", input: "Access attempted from 2001:db8::ff00:42", - expect: "Access attempted from 100::", + expect: "Access attempted from 2001:db8:ffff::", }, { name: "IPv6 Address with Port", input: "Access attempted from [2001:db8::ff00:42]:8080", - expect: "Access attempted from [100::]:8080", + expect: "Access attempted from [2001:db8:ffff::]:8080", }, { name: "Both IPv4 and IPv6", input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43", - expect: "IPv4: 198.51.100.1 and IPv6: 100::1", + expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1", + }, + { + name: "STUN URI with IPv6", + input: "Connecting to stun:[2001:db8::ff00:42]:3478", + expect: "Connecting to stun:[2001:db8:ffff::]:3478", + }, + { + name: "HTTPS URI with IPv6", + input: "Visit https://[2001:db8::ff00:42]:443/path", + expect: "Visit https://[2001:db8:ffff::]:443/path", }, } diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 0acf0b133..d6e052e08 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -523,7 +523,7 @@ func parseHostnameAndCommand(args []string) error { } func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error { - target := fmt.Sprintf("%s:%d", addr, port) + target := 
net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(port)) c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{ KnownHostsFile: knownHostsFile, IdentityFile: identityFile, @@ -787,10 +787,10 @@ func isUnixSocket(path string) bool { return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") } -// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces. +// normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack). func normalizeLocalHost(host string) string { if host == "*" { - return "0.0.0.0" + return "" } return host } diff --git a/client/cmd/ssh_test.go b/client/cmd/ssh_test.go index 43291fa87..16ffadb90 100644 --- a/client/cmd/ssh_test.go +++ b/client/cmd/ssh_test.go @@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) { { name: "wildcard bind all interfaces", spec: "*:8080:localhost:80", - expectedLocal: "0.0.0.0:8080", + expectedLocal: ":8080", expectedRemote: "localhost:80", expectError: false, - description: "Wildcard * should bind to all interfaces (0.0.0.0)", + description: "Wildcard * should bind to all interfaces (dual-stack)", }, { name: "wildcard for port only", diff --git a/client/cmd/status.go b/client/cmd/status.go index c35a06eb3..dae30e854 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -20,6 +20,7 @@ import ( var ( detailFlag bool ipv4Flag bool + ipv6Flag bool jsonFlag bool yamlFlag bool ipsFilter []string @@ -45,8 +46,9 @@ func init() { statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format") statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format") statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33") - statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4") - statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", 
[]string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") + statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer") + statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6") + statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") @@ -101,6 +103,14 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } + if ipv6Flag { + ipv6 := resp.GetFullStatus().GetLocalPeerState().GetIpv6() + if ipv6 != "" { + cmd.Print(parseInterfaceIP(ipv6)) + } + return nil + } + pm := profilemanager.NewProfileManager() var profName string if activeProf, err := pm.GetActiveProfile(); err == nil { diff --git a/client/cmd/system.go b/client/cmd/system.go index f63432401..b386fe4ae 100644 --- a/client/cmd/system.go +++ b/client/cmd/system.go @@ -8,6 +8,7 @@ const ( disableFirewallFlag = "disable-firewall" blockLANAccessFlag = "block-lan-access" blockInboundFlag = "block-inbound" + disableIPv6Flag = "disable-ipv6" ) var ( @@ -17,6 +18,7 @@ var ( disableFirewall bool blockLANAccess bool blockInbound bool + disableIPv6 bool ) func init() { @@ -39,4 +41,7 @@ func init() { 
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false, "Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+ "This overrides any policies received from the management service.") + + upCmd.PersistentFlags().BoolVar(&disableIPv6, disableIPv6Flag, false, + "Disable IPv6 overlay. If enabled, the client won't request or use an IPv6 overlay address.") } diff --git a/client/cmd/up.go b/client/cmd/up.go index f4136cb23..cabd0aacf 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -435,6 +435,10 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro req.BlockInbound = &blockInbound } + if cmd.Flag(disableIPv6Flag).Changed { + req.DisableIpv6 = &disableIPv6 + } + if cmd.Flag(enableLazyConnectionFlag).Changed { req.LazyConnectionEnabled = &lazyConnEnabled } @@ -552,6 +556,10 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil ic.BlockInbound = &blockInbound } + if cmd.Flag(disableIPv6Flag).Changed { + ic.DisableIPv6 = &disableIPv6 + } + if cmd.Flag(enableLazyConnectionFlag).Changed { ic.LazyConnectionEnabled = &lazyConnEnabled } @@ -666,6 +674,10 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte loginRequest.BlockInbound = &blockInbound } + if cmd.Flag(disableIPv6Flag).Changed { + loginRequest.DisableIpv6 = &disableIPv6 + } + if cmd.Flag(enableLazyConnectionFlag).Changed { loginRequest.LazyConnectionEnabled = &lazyConnEnabled } diff --git a/client/embed/embed.go b/client/embed/embed.go index baa1d94d6..4b9445b97 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -80,6 +80,8 @@ type Options struct { StatePath string // DisableClientRoutes disables the client routes DisableClientRoutes bool + // DisableIPv6 disables IPv6 overlay addressing + DisableIPv6 bool // BlockInbound blocks all inbound connections from peers BlockInbound bool // WireguardPort is the port for 
the tunnel interface. Use 0 for a random port. @@ -171,6 +173,7 @@ func New(opts Options) (*Client, error) { PreSharedKey: &opts.PreSharedKey, DisableServerRoutes: &t, DisableClientRoutes: &opts.DisableClientRoutes, + DisableIPv6: &opts.DisableIPv6, BlockInbound: &opts.BlockInbound, WireguardPort: opts.WireguardPort, MTU: opts.MTU, diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index e629f7881..e5e19cec9 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -40,6 +40,7 @@ type aclManager struct { entries aclEntries optionalEntries map[string][]entry ipsetStore *ipsetStore + v6 bool stateManager *statemanager.Manager } @@ -51,6 +52,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl entries: make(map[string][][]string), optionalEntries: make(map[string][]entry), ipsetStore: newIpsetStore(), + v6: iptablesClient.Proto() == iptables.ProtocolIPv6, }, nil } @@ -85,7 +87,11 @@ func (m *aclManager) AddPeerFiltering( chain := chainNameInputRules ipsetName = transformIPsetName(ipsetName, sPort, dPort, action) - specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName) + if m.v6 && ipsetName != "" { + ipsetName += "-v6" + } + proto := protoForFamily(protocol, m.v6) + specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName) mangleSpecs := slices.Clone(specs) mangleSpecs = append(mangleSpecs, @@ -109,6 +115,7 @@ func (m *aclManager) AddPeerFiltering( ip: ip.String(), chain: chain, specs: specs, + v6: m.v6, }}, nil } @@ -161,6 +168,7 @@ func (m *aclManager) AddPeerFiltering( ipsetName: ipsetName, ip: ip.String(), chain: chain, + v6: m.v6, } m.updateState() @@ -413,8 +421,13 @@ func (m *aclManager) updateState() { currentState.Lock() defer currentState.Unlock() - currentState.ACLEntries = m.entries - currentState.ACLIPsetStore = m.ipsetStore + if m.v6 { + currentState.ACLEntries6 = m.entries + currentState.ACLIPsetStore6 = 
m.ipsetStore + } else { + currentState.ACLEntries = m.entries + currentState.ACLIPsetStore = m.ipsetStore + } if err := m.stateManager.UpdateState(currentState); err != nil { log.Errorf("failed to update state: %v", err) @@ -422,13 +435,22 @@ func (m *aclManager) updateState() { } // filterRuleSpecs returns the specs of a filtering rule +// protoForFamily translates ICMP to ICMPv6 for ip6tables. +// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp". +func protoForFamily(protocol firewall.Protocol, v6 bool) string { + if v6 && protocol == firewall.ProtocolICMP { + return "ipv6-icmp" + } + return string(protocol) +} + func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { // don't use IP matching if IP is 0.0.0.0 matchByIP := !ip.IsUnspecified() if matchByIP { if ipsetName != "" { - specs = append(specs, "-m", "set", "--set", ipsetName, "src") + specs = append(specs, "-m", "set", "--match-set", ipsetName, "src") } else { specs = append(specs, "-s", ip.String()) } @@ -474,6 +496,9 @@ func (m *aclManager) createIPSet(name string) error { opts := ipset.CreateOptions{ Replace: true, } + if m.v6 { + opts.Family = ipset.FamilyIPV6 + } if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { return fmt.Errorf("create ipset %s: %w", name, err) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 7d8cd7f8c..696537dd8 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -18,6 +18,10 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) +type resetter interface { + Reset() error +} + // Manager of iptables firewall type Manager struct { mutex sync.Mutex @@ -28,6 +32,11 @@ type Manager struct { aclMgr *aclManager router *router rawSupported bool + + // IPv6 counterparts, nil when no v6 overlay + ipv6Client *iptables.IPTables + aclMgr6 *aclManager + 
router6 *router } // iFaceMapper defines subset methods of interface required for manager @@ -58,9 +67,43 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { return nil, fmt.Errorf("create acl manager: %w", err) } + if wgIface.Address().HasIPv6() { + if err := m.createIPv6Components(wgIface, mtu); err != nil { + return nil, fmt.Errorf("create IPv6 firewall: %w", err) + } + } + return m, nil } +func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error { + ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + return fmt.Errorf("init ip6tables: %w", err) + } + m.ipv6Client = ip6Client + + m.router6, err = newRouter(ip6Client, wgIface, mtu) + if err != nil { + return fmt.Errorf("create v6 router: %w", err) + } + + // Share the same IP forwarding state with the v4 router, since + // EnableIPForwarding controls both v4 and v6 sysctls. + m.router6.ipFwdState = m.router.ipFwdState + + m.aclMgr6, err = newAclManager(ip6Client, wgIface) + if err != nil { + return fmt.Errorf("create v6 acl manager: %w", err) + } + + return nil +} + +func (m *Manager) hasIPv6() bool { + return m.ipv6Client != nil +} + func (m *Manager) Init(stateManager *statemanager.Manager) error { state := &ShutdownState{ InterfaceState: &InterfaceState{ @@ -74,13 +117,8 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Errorf("failed to update state: %v", err) } - if err := m.router.init(stateManager); err != nil { - return fmt.Errorf("router init: %w", err) - } - - if err := m.aclMgr.init(stateManager); err != nil { - // TODO: cleanup router - return fmt.Errorf("acl manager init: %w", err) + if err := m.initChains(stateManager); err != nil { + return err } if err := m.initNoTrackChain(); err != nil { @@ -103,6 +141,41 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { return nil } +// initChains initializes router and ACL chains for both address families, +// rolling back on failure. 
+func (m *Manager) initChains(stateManager *statemanager.Manager) error { + type initStep struct { + name string + init func(*statemanager.Manager) error + mgr resetter + } + + steps := []initStep{ + {"router", m.router.init, m.router}, + {"acl manager", m.aclMgr.init, m.aclMgr}, + } + if m.hasIPv6() { + steps = append(steps, + initStep{"v6 router", m.router6.init, m.router6}, + initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6}, + ) + } + + var initialized []initStep + for _, s := range steps { + if err := s.init(stateManager); err != nil { + for i := len(initialized) - 1; i >= 0; i-- { + if rerr := initialized[i].mgr.Reset(); rerr != nil { + log.Warnf("rollback %s: %v", initialized[i].name, rerr) + } + } + return fmt.Errorf("%s init: %w", s.name, err) + } + initialized = append(initialized, s) + } + return nil +} + // AddPeerFiltering adds a rule to the firewall // // Comment will be ignored because some system this feature is not supported @@ -118,7 +191,13 @@ func (m *Manager) AddPeerFiltering( m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + if ip.To4() != nil { + return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + } + if !m.hasIPv6() { + return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized) + } + return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } func (m *Manager) AddRouteFiltering( @@ -132,25 +211,48 @@ func (m *Manager) AddRouteFiltering( m.mutex.Lock() defer m.mutex.Unlock() - if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) + if isIPv6RouteRule(sources, destination) { + if !m.hasIPv6() { + return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } return 
m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } +func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool { + if destination.IsPrefix() { + return destination.Prefix.Addr().Is6() + } + return len(sources) > 0 && sources[0].Addr().Is6() +} + // DeletePeerRule from the firewall by rule definition func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && isIPv6IptRule(rule) { + return m.aclMgr6.DeletePeerRule(rule) + } return m.aclMgr.DeletePeerRule(rule) } +func isIPv6IptRule(rule firewall.Rule) bool { + r, ok := rule.(*Rule) + return ok && r.v6 +} + +// DeleteRouteRule deletes a routing rule. +// Route rules are keyed by content hash. Check v4 first, try v6 if not found. func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()) { + return m.router6.DeleteRouteRule(rule) + } return m.router.DeleteRouteRule(rule) } @@ -166,18 +268,65 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddNatRule(pair) + } + + if err := m.router.AddNatRule(pair); err != nil { + return err + } + + // Dynamic routes need NAT in both tables since resolved IPs can be + // either v4 or v6. This covers both DomainSet (modern) and the legacy + // wildcard 0.0.0.0/0 destination where the client resolves DNS. 
+ if m.hasIPv6() && pair.Dynamic { + v6Pair := firewall.ToV6NatPair(pair) + if err := m.router6.AddNatRule(v6Pair); err != nil { + return fmt.Errorf("add v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return nil + } + return m.router6.RemoveNatRule(pair) + } + + var merr *multierror.Error + + if err := m.router.RemoveNatRule(pair); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err)) + } + + if m.hasIPv6() && pair.Dynamic { + v6Pair := firewall.ToV6NatPair(pair) + if err := m.router6.RemoveNatRule(v6Pair); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) } func (m *Manager) SetLegacyManagement(isLegacy bool) error { - return firewall.SetLegacyManagement(m.router, isLegacy) + if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil { + return err + } + if m.hasIPv6() { + return firewall.SetLegacyManagement(m.router6, isLegacy) + } + return nil } // Reset firewall to the default state @@ -191,6 +340,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err)) } + if m.hasIPv6() { + if err := m.aclMgr6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err)) + } + if err := m.router6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err)) + } + } + if err := m.aclMgr.Reset(); err != nil { merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) } @@ -218,24 +376,21 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { // This is called when USPFilter wraps the native firewall, adding blanket 
accept // rules so that packet filtering is handled in userspace instead of by netfilter. func (m *Manager) AllowNetbird() error { - _, err := m.AddPeerFiltering( - nil, - net.IP{0, 0, 0, 0}, - firewall.ProtocolALL, - nil, - nil, - firewall.ActionAccept, - "", - ) - if err != nil { - return fmt.Errorf("allow netbird interface traffic: %w", err) + var merr *multierror.Error + if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil { + merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err)) + } + if m.hasIPv6() { + if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil { + merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err)) + } } if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { log.Warnf("failed to trust interface in firewalld: %v", err) } - return nil + return nberrors.FormatErrorOrNil(merr) } // Flush doesn't need to be implemented for this manager @@ -265,6 +420,12 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) m.mutex.Lock() defer m.mutex.Unlock() + if rule.TranslatedAddress.Is6() { + if !m.hasIPv6() { + return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddDNATRule(rule) + } return m.router.AddDNATRule(rule) } @@ -273,6 +434,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) { + return m.router6.DeleteDNATRule(rule) + } return m.router.DeleteDNATRule(rule) } @@ -281,39 +445,82 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.UpdateSet(set, prefixes) + var v4Prefixes, v6Prefixes []netip.Prefix + for _, p := range prefixes { + if p.Addr().Is6() { + v6Prefixes 
= append(v6Prefixes, p) + } else { + v4Prefixes = append(v4Prefixes, p) + } + } + + if err := m.router.UpdateSet(set, v4Prefixes); err != nil { + return err + } + + if m.hasIPv6() && len(v6Prefixes) > 0 { + if err := m.router6.UpdateSet(set, v6Prefixes); err != nil { + return fmt.Errorf("update v6 set: %w", err) + } + } + + return nil } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. -func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort) } // RemoveInboundDNAT removes an inbound DNAT rule. -func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort) } // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. 
-func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort) } // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. -func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort) } const ( diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index a7c4f67dd..290e5da1e 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -54,8 +54,10 @@ const ( snatSuffix = "_snat" fwdSuffix = "_fwd" - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation. 
+ ipv4TCPHeaderSize = 40 + // ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation. + ipv6TCPHeaderSize = 60 ) type ruleInfo struct { @@ -86,6 +88,7 @@ type router struct { wgIface iFaceMapper legacyManagement bool mtu uint16 + v6 bool stateManager *statemanager.Manager ipFwdState *ipfwdstate.IPForwardingState @@ -97,6 +100,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1 rules: make(map[string][]string), wgIface: wgIface, mtu: mtu, + v6: iptablesClient.Proto() == iptables.ProtocolIPv6, ipFwdState: ipfwdstate.NewIPForwardingState(), } @@ -186,6 +190,11 @@ func (r *router) AddRouteFiltering( return ruleKey, nil } +func (r *router) hasRule(id string) bool { + _, ok := r.rules[id] + return ok +} + func (r *router) DeleteRouteRule(rule firewall.Rule) error { ruleKey := rule.ID() @@ -392,9 +401,13 @@ func (r *router) cleanUpDefaultForwardRules() error { // Remove jump rules from built-in chains before deleting custom chains, // otherwise the chain deletion fails with "device or resource busy". - jumpRule := []string{"-j", chainNATOutput} - if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil { - log.Debugf("clean OUTPUT jump rule: %v", err) + if ok, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput); err != nil { + return fmt.Errorf("check chain %s: %w", chainNATOutput, err) + } else if ok { + jumpRule := []string{"-j", chainNATOutput} + if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil { + log.Debugf("clean OUTPUT jump rule: %v", err) + } } for _, chainInfo := range []struct { @@ -434,6 +447,12 @@ func (r *router) createContainers() error { {chainRTRDR, tableNat}, {chainRTMSSCLAMP, tableMangle}, } { + // Fallback: clear chains that survived an unclean shutdown. 
+ if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok { + if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil { + log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err) + } + } if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) } @@ -540,9 +559,12 @@ func (r *router) addPostroutingRules() error { } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. -// TODO: Add IPv6 support func (r *router) addMSSClampingRules() error { - mss := r.mtu - ipTCPHeaderMinSize + overhead := uint16(ipv4TCPHeaderSize) + if r.v6 { + overhead = ipv6TCPHeaderSize + } + mss := r.mtu - overhead // Add jump rule from FORWARD chain in mangle table to our custom chain jumpRule := []string{ @@ -727,8 +749,13 @@ func (r *router) updateState() { currentState.Lock() defer currentState.Unlock() - currentState.RouteRules = r.rules - currentState.RouteIPsetCounter = r.ipsetCounter + if r.v6 { + currentState.RouteRules6 = r.rules + currentState.RouteIPsetCounter6 = r.ipsetCounter + } else { + currentState.RouteRules = r.rules + currentState.RouteIPsetCounter = r.ipsetCounter + } if err := r.stateManager.UpdateState(currentState); err != nil { log.Errorf("failed to update state: %v", err) @@ -856,7 +883,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error { } if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists { - if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil { + if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil { merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err)) } delete(r.rules, ruleKey+fwdSuffix) @@ -883,7 +910,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net rule = append(rule, destExp...) 
if params.Proto != firewall.ProtocolALL { - rule = append(rule, "-p", strings.ToLower(string(params.Proto))) + rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6))) rule = append(rule, applyPort("--sport", params.SPort)...) rule = append(rule, applyPort("--dport", params.DPort)...) } @@ -900,11 +927,12 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes [] } if network.IsSet() { - if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil { + name := r.ipsetName(network.Set.HashedName()) + if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil { return nil, fmt.Errorf("create or get ipset: %w", err) } - return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil + return []string{"-m", "set", matchSet, name, direction}, nil } if network.IsPrefix() { return []string{flag, network.Prefix.String()}, nil @@ -915,27 +943,23 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes [] } func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + name := r.ipsetName(set.HashedName()) var merr *multierror.Error for _, prefix := range prefixes { - // TODO: Implement IPv6 support - if prefix.Addr().Is6() { - log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) - continue - } - if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil { + if err := r.addPrefixToIPSet(name, prefix); err != nil { merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) } } if merr == nil { - log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) + log.Debugf("updated set %s with prefixes %v", name, prefixes) } return nberrors.FormatErrorOrNil(merr) } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. 
-func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { - ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) if _, exists := r.rules[ruleID]; exists { return nil @@ -943,12 +967,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol dnatRule := []string{ "-i", r.wgIface.Name(), - "-p", strings.ToLower(string(protocol)), - "--dport", strconv.Itoa(int(sourcePort)), + "-p", strings.ToLower(protoForFamily(protocol, r.v6)), + "--dport", strconv.Itoa(int(originalPort)), "-d", localAddr.String(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "DNAT", - "--to-destination", ":" + strconv.Itoa(int(targetPort)), + "--to-destination", ":" + strconv.Itoa(int(translatedPort)), } ruleInfo := ruleInfo{ @@ -967,8 +991,8 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol } // RemoveInboundDNAT removes an inbound DNAT rule. -func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { - ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) if dnatRule, exists := r.rules[ruleID]; exists { if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil { @@ -1013,8 +1037,8 @@ func (r *router) ensureNATOutputChain() error { } // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. 
-func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { - ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) +func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) if _, exists := r.rules[ruleID]; exists { return nil @@ -1025,11 +1049,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, } dnatRule := []string{ - "-p", strings.ToLower(string(protocol)), - "--dport", strconv.Itoa(int(sourcePort)), + "-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())), + "--dport", strconv.Itoa(int(originalPort)), "-d", localAddr.String(), "-j", "DNAT", - "--to-destination", ":" + strconv.Itoa(int(targetPort)), + "--to-destination", ":" + strconv.Itoa(int(translatedPort)), } if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil { @@ -1042,8 +1066,8 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, } // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. 
-func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { - ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) +func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) if dnatRule, exists := r.rules[ruleID]; exists { if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil { @@ -1076,10 +1100,22 @@ func applyPort(flag string, port *firewall.Port) []string { return []string{flag, strconv.Itoa(int(port.Values[0]))} } +// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router +// to avoid collisions since ipsets are global in the kernel. +func (r *router) ipsetName(name string) string { + if r.v6 { + return name + "-v6" + } + return name +} + func (r *router) createIPSet(name string) error { opts := ipset.CreateOptions{ Replace: true, } + if r.v6 { + opts.Family = ipset.FamilyIPV6 + } if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { return fmt.Errorf("create ipset %s: %w", name, err) diff --git a/client/firewall/iptables/rule.go b/client/firewall/iptables/rule.go index aa4d2d079..4f4eab167 100644 --- a/client/firewall/iptables/rule.go +++ b/client/firewall/iptables/rule.go @@ -9,6 +9,7 @@ type Rule struct { mangleSpecs []string ip string chain string + v6 bool } // GetRuleID returns the rule id diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index 121c755e9..f4be37d01 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -4,6 +4,8 @@ import ( "fmt" "sync" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -32,6 +34,12 @@ type 
ShutdownState struct { ACLEntries aclEntries `json:"acl_entries,omitempty"` ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"` + + // IPv6 counterparts + RouteRules6 routeRules `json:"route_rules_v6,omitempty"` + RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"` + ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"` + ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"` } func (s *ShutdownState) Name() string { @@ -62,6 +70,28 @@ func (s *ShutdownState) Cleanup() error { ipt.aclMgr.ipsetStore = s.ACLIPsetStore } + // Clean up v6 state even if the current run has no IPv6. + // The previous run may have left ip6tables rules behind. + if !ipt.hasIPv6() { + if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil { + log.Warnf("failed to create v6 components for cleanup: %v", err) + } + } + if ipt.hasIPv6() { + if s.RouteRules6 != nil { + ipt.router6.rules = s.RouteRules6 + } + if s.RouteIPsetCounter6 != nil { + ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6) + } + if s.ACLEntries6 != nil { + ipt.aclMgr6.entries = s.ACLEntries6 + } + if s.ACLIPsetStore6 != nil { + ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6 + } + } + if err := ipt.Close(nil); err != nil { return fmt.Errorf("reset iptables manager: %w", err) } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index d65d717b3..149c6db83 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -1,6 +1,7 @@ package manager import ( + "errors" "fmt" "net" "net/netip" @@ -11,6 +12,10 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) +// ErrIPv6NotInitialized is returned when an IPv6 address is passed to a firewall +// method but the IPv6 firewall components were not initialized. 
+var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized") + const ( ForwardingFormatPrefix = "netbird-fwd-" ForwardingFormat = "netbird-fwd-%s-%t" @@ -164,18 +169,16 @@ type Manager interface { UpdateSet(hash Set, prefixes []netip.Prefix) error // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services - AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + AddInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error // RemoveInboundDNAT removes inbound DNAT rule - RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. - // localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only. - AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + AddOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. - // localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only. - RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error // SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic. // This prevents conntrack from interfering with WireGuard proxy communication. 
diff --git a/client/firewall/manager/routerpair.go b/client/firewall/manager/routerpair.go index 079c051d9..096f8b9bb 100644 --- a/client/firewall/manager/routerpair.go +++ b/client/firewall/manager/routerpair.go @@ -1,6 +1,8 @@ package manager import ( + "net/netip" + "github.com/netbirdio/netbird/route" ) @@ -10,6 +12,10 @@ type RouterPair struct { Destination Network Masquerade bool Inverse bool + // Dynamic indicates the route is domain-based. NAT rules for dynamic + // routes are duplicated to the v6 table so that resolved AAAA records + // are masqueraded correctly. + Dynamic bool } func GetInversePair(pair RouterPair) RouterPair { @@ -20,5 +26,17 @@ func GetInversePair(pair RouterPair) RouterPair { Destination: pair.Source, Masquerade: pair.Masquerade, Inverse: true, + Dynamic: pair.Dynamic, } } + +// ToV6NatPair creates a v6 counterpart of a v4 NAT pair with `::/0` source +// and, for prefix destinations, `::/0` destination. +func ToV6NatPair(pair RouterPair) RouterPair { + v6 := pair + v6.Source = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + if v6.Destination.IsPrefix() { + v6.Destination = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)} + } + return v6 +} diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index a9d066e2f..9d2ea7264 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -33,15 +33,12 @@ const ( const flushError = "flush: %w" -var ( - anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} -) - type AclManager struct { rConn *nftables.Conn sConn *nftables.Conn wgIface iFaceMapper routingFwChainName string + af addrFamily workTable *nftables.Table chainInputRules *nftables.Chain @@ -67,6 +64,7 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam wgIface: wgIface, workTable: table, routingFwChainName: routingFwChainName, + af: familyForAddr(table.Family == nftables.TableFamilyIPv4), 
ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), @@ -145,7 +143,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { } if _, ok := ips[r.ip.String()]; ok { - err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}}) + err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}}) if err != nil { log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err) } @@ -254,11 +252,11 @@ func (m *AclManager) addIOFiltering( expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: uint32(9), + Offset: m.af.protoOffset, Len: uint32(1), }) - protoData, err := protoToInt(proto) + protoData, err := m.af.protoNum(proto) if err != nil { return nil, fmt.Errorf("convert protocol to number: %v", err) } @@ -270,19 +268,16 @@ func (m *AclManager) addIOFiltering( }) } - rawIP := ip.To4() + rawIP := ipToBytes(ip, m.af) // check if rawIP contains zeroed IPv4 0.0.0.0 value // in that case not add IP match expression into the rule definition - if !bytes.HasPrefix(anyIP, rawIP) { - // source address position - addrOffset := uint32(12) - + if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) { expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: addrOffset, - Len: 4, + Offset: m.af.srcAddrOffset, + Len: m.af.addrLen, }, ) // add individual IP for match if no ipset defined @@ -587,7 +582,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) { ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName) - rawIP := ip.To4() + rawIP := ipToBytes(ip, m.af) if err != nil { if ipset, err = m.createSet(m.workTable, ipsetName); err != nil { return nil, fmt.Errorf("get set name: %v", err) @@ -619,7 +614,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) 
(*nftables.Se Name: name, Table: table, Dynamic: true, - KeyType: nftables.TypeIPAddr, + KeyType: m.af.setKeyType, } if err := m.rConn.AddSet(ipset, nil); err != nil { @@ -707,15 +702,12 @@ func ifname(n string) []byte { return b } -func protoToInt(protocol firewall.Protocol) (uint8, error) { - switch protocol { - case firewall.ProtocolTCP: - return unix.IPPROTO_TCP, nil - case firewall.ProtocolUDP: - return unix.IPPROTO_UDP, nil - case firewall.ProtocolICMP: - return unix.IPPROTO_ICMP, nil - } - return 0, fmt.Errorf("unsupported protocol: %s", protocol) +// ipToBytes converts net.IP to the correct byte length for the address family. +func ipToBytes(ip net.IP, af addrFamily) []byte { + if af.addrLen == 4 { + return ip.To4() + } + return ip.To16() } + diff --git a/client/firewall/nftables/addr_family_linux.go b/client/firewall/nftables/addr_family_linux.go new file mode 100644 index 000000000..0c90d704a --- /dev/null +++ b/client/firewall/nftables/addr_family_linux.go @@ -0,0 +1,81 @@ +package nftables + +import ( + "fmt" + "net" + + "github.com/google/nftables" + "golang.org/x/sys/unix" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +var ( + // afIPv4 defines IPv4 header layout and nftables types. + afIPv4 = addrFamily{ + protoOffset: 9, + srcAddrOffset: 12, + dstAddrOffset: 16, + addrLen: net.IPv4len, + totalBits: 8 * net.IPv4len, + setKeyType: nftables.TypeIPAddr, + tableFamily: nftables.TableFamilyIPv4, + icmpProto: unix.IPPROTO_ICMP, + } + // afIPv6 defines IPv6 header layout and nftables types. + afIPv6 = addrFamily{ + protoOffset: 6, + srcAddrOffset: 8, + dstAddrOffset: 24, + addrLen: net.IPv6len, + totalBits: 8 * net.IPv6len, + setKeyType: nftables.TypeIP6Addr, + tableFamily: nftables.TableFamilyIPv6, + icmpProto: unix.IPPROTO_ICMPV6, + } +) + +// addrFamily holds protocol-specific constants for nftables expression building. 
+type addrFamily struct { + // protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6) + protoOffset uint32 + // srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6) + srcAddrOffset uint32 + // dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6) + dstAddrOffset uint32 + // addrLen is the byte length of addresses (4 for v4, 16 for v6) + addrLen uint32 + // totalBits is the address size in bits (32 for v4, 128 for v6) + totalBits int + // setKeyType is the nftables set data type for addresses + setKeyType nftables.SetDatatype + // tableFamily is the nftables table family + tableFamily nftables.TableFamily + // icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6) + icmpProto uint8 +} + +// familyForAddr returns the address family for the given IP. +func familyForAddr(is4 bool) addrFamily { + if is4 { + return afIPv4 + } + return afIPv6 +} + +// protoNum converts a firewall protocol to the IP protocol number, +// using the correct ICMP variant for the address family. 
+func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) { + switch protocol { + case firewall.ProtocolTCP: + return unix.IPPROTO_TCP, nil + case firewall.ProtocolUDP: + return unix.IPPROTO_UDP, nil + case firewall.ProtocolICMP: + return af.icmpProto, nil + case firewall.ProtocolALL: + return 0, nil + default: + return 0, fmt.Errorf("unsupported protocol: %s", protocol) + } +} diff --git a/client/firewall/nftables/external_chain_monitor_integration_linux_test.go b/client/firewall/nftables/external_chain_monitor_integration_linux_test.go new file mode 100644 index 000000000..3c4e3f44d --- /dev/null +++ b/client/firewall/nftables/external_chain_monitor_integration_linux_test.go @@ -0,0 +1,76 @@ +//go:build linux + +package nftables + +import ( + "os" + "sync/atomic" + "testing" + "time" + + "github.com/google/nftables" + "github.com/stretchr/testify/require" +) + +// TestExternalChainMonitorRootIntegration verifies that adding a new chain +// in an external (non-netbird) filter table triggers the reconciler. +// Requires CAP_NET_ADMIN; skip otherwise. +func TestExternalChainMonitorRootIntegration(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip("root required") + } + + calls := make(chan struct{}, 8) + var count atomic.Int32 + rec := &countingReconciler{calls: calls, count: &count} + + m := newExternalChainMonitor(rec) + m.start() + t.Cleanup(m.stop) + + // Give the netlink subscription a moment to register. 
+ time.Sleep(200 * time.Millisecond) + + conn := &nftables.Conn{} + table := conn.AddTable(&nftables.Table{ + Name: "nbmon_integration_test", + Family: nftables.TableFamilyINet, + }) + t.Cleanup(func() { + cleanup := &nftables.Conn{} + cleanup.DelTable(table) + _ = cleanup.Flush() + }) + + chain := conn.AddChain(&nftables.Chain{ + Name: "filter_INPUT", + Table: table, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + Type: nftables.ChainTypeFilter, + }) + _ = chain + require.NoError(t, conn.Flush(), "create external test chain") + + select { + case <-calls: + // success + case <-time.After(3 * time.Second): + t.Fatalf("reconcile was not invoked after creating an external chain") + } + require.GreaterOrEqual(t, count.Load(), int32(1)) +} + +type countingReconciler struct { + calls chan struct{} + count *atomic.Int32 +} + +func (c *countingReconciler) reconcileExternalChains() error { + c.count.Add(1) + select { + case c.calls <- struct{}{}: + default: + } + return nil +} diff --git a/client/firewall/nftables/external_chain_monitor_linux.go b/client/firewall/nftables/external_chain_monitor_linux.go new file mode 100644 index 000000000..2a2e04c09 --- /dev/null +++ b/client/firewall/nftables/external_chain_monitor_linux.go @@ -0,0 +1,199 @@ +package nftables + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/google/nftables" + log "github.com/sirupsen/logrus" +) + +const ( + externalMonitorReconcileDelay = 500 * time.Millisecond + externalMonitorInitInterval = 5 * time.Second + externalMonitorMaxInterval = 5 * time.Minute + externalMonitorRandomization = 0.5 +) + +// externalChainReconciler re-applies passthrough accept rules to external +// nftables chains. Implementations must be safe to call from the monitor +// goroutine; the Manager locks its mutex internally. 
+type externalChainReconciler interface { + reconcileExternalChains() error +} + +// externalChainMonitor watches nftables netlink events and triggers a +// reconcile when a new table or chain appears (e.g. after +// `firewall-cmd --reload`). Netlink errors trigger exponential-backoff +// reconnect. +type externalChainMonitor struct { + reconciler externalChainReconciler + + mu sync.Mutex + cancel context.CancelFunc + done chan struct{} +} + +func newExternalChainMonitor(r externalChainReconciler) *externalChainMonitor { + return &externalChainMonitor{reconciler: r} +} + +func (m *externalChainMonitor) start() { + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel != nil { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + m.cancel = cancel + m.done = make(chan struct{}) + + go m.run(ctx) +} + +func (m *externalChainMonitor) stop() { + m.mu.Lock() + cancel := m.cancel + done := m.done + m.cancel = nil + m.done = nil + m.mu.Unlock() + + if cancel == nil { + return + } + cancel() + <-done +} + +func (m *externalChainMonitor) run(ctx context.Context) { + defer close(m.done) + + bo := &backoff.ExponentialBackOff{ + InitialInterval: externalMonitorInitInterval, + RandomizationFactor: externalMonitorRandomization, + Multiplier: backoff.DefaultMultiplier, + MaxInterval: externalMonitorMaxInterval, + MaxElapsedTime: 0, + Clock: backoff.SystemClock, + } + bo.Reset() + + for ctx.Err() == nil { + err := m.watch(ctx) + if ctx.Err() != nil { + return + } + + delay := bo.NextBackOff() + log.Warnf("external chain monitor: %v, reconnecting in %s", err, delay) + select { + case <-ctx.Done(): + return + case <-time.After(delay): + } + } +} + +func (m *externalChainMonitor) watch(ctx context.Context) error { + events, closeMon, err := m.subscribe() + if err != nil { + return err + } + defer closeMon() + + debounce := time.NewTimer(time.Hour) + if !debounce.Stop() { + <-debounce.C + } + defer debounce.Stop() + + pending := false + for { + select { + case 
<-ctx.Done(): + return nil + case <-debounce.C: + pending = false + m.reconcile() + case ev, ok := <-events: + if !ok { + return errors.New("monitor channel closed") + } + if ev.Error != nil { + return fmt.Errorf("monitor event: %w", ev.Error) + } + if !isRelevantMonitorEvent(ev) { + continue + } + resetDebounce(debounce, pending) + pending = true + } + } +} + +func (m *externalChainMonitor) subscribe() (chan *nftables.MonitorEvent, func(), error) { + conn := &nftables.Conn{} + mon := nftables.NewMonitor( + nftables.WithMonitorAction(nftables.MonitorActionNew), + nftables.WithMonitorObject(nftables.MonitorObjectChains|nftables.MonitorObjectTables), + ) + events, err := conn.AddMonitor(mon) + if err != nil { + return nil, nil, fmt.Errorf("add netlink monitor: %w", err) + } + return events, func() { _ = mon.Close() }, nil +} + +// resetDebounce reschedules a pending debounce timer without leaking a stale +// fire on its channel. pending must reflect whether the timer is armed. +func resetDebounce(t *time.Timer, pending bool) { + if pending && !t.Stop() { + select { + case <-t.C: + default: + } + } + t.Reset(externalMonitorReconcileDelay) +} + +func (m *externalChainMonitor) reconcile() { + if err := m.reconciler.reconcileExternalChains(); err != nil { + log.Warnf("reconcile external chain rules: %v", err) + } +} + +// isRelevantMonitorEvent returns true for table/chain creation events on +// families we care about. The reconciler filters to actual external filter +// chains. 
+func isRelevantMonitorEvent(ev *nftables.MonitorEvent) bool { + switch ev.Type { + case nftables.MonitorEventTypeNewChain: + chain, ok := ev.Data.(*nftables.Chain) + if !ok || chain == nil || chain.Table == nil { + return false + } + return isMonitoredFamily(chain.Table.Family) + case nftables.MonitorEventTypeNewTable: + table, ok := ev.Data.(*nftables.Table) + if !ok || table == nil { + return false + } + return isMonitoredFamily(table.Family) + } + return false +} + +func isMonitoredFamily(family nftables.TableFamily) bool { + switch family { + case nftables.TableFamilyIPv4, nftables.TableFamilyIPv6, nftables.TableFamilyINet: + return true + } + return false +} diff --git a/client/firewall/nftables/external_chain_monitor_linux_test.go b/client/firewall/nftables/external_chain_monitor_linux_test.go new file mode 100644 index 000000000..1a37faca2 --- /dev/null +++ b/client/firewall/nftables/external_chain_monitor_linux_test.go @@ -0,0 +1,137 @@ +package nftables + +import ( + "testing" + + "github.com/google/nftables" + "github.com/stretchr/testify/assert" +) + +func TestIsMonitoredFamily(t *testing.T) { + tests := []struct { + family nftables.TableFamily + want bool + }{ + {nftables.TableFamilyIPv4, true}, + {nftables.TableFamilyIPv6, true}, + {nftables.TableFamilyINet, true}, + {nftables.TableFamilyARP, false}, + {nftables.TableFamilyBridge, false}, + {nftables.TableFamilyNetdev, false}, + {nftables.TableFamilyUnspecified, false}, + } + for _, tc := range tests { + assert.Equal(t, tc.want, isMonitoredFamily(tc.family), "family=%d", tc.family) + } +} + +func TestIsRelevantMonitorEvent(t *testing.T) { + inetTable := &nftables.Table{Name: "firewalld", Family: nftables.TableFamilyINet} + ipTable := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4} + arpTable := &nftables.Table{Name: "arp", Family: nftables.TableFamilyARP} + + tests := []struct { + name string + ev *nftables.MonitorEvent + want bool + }{ + { + name: "new chain in inet firewalld", + 
ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable}, + }, + want: true, + }, + { + name: "new chain in ip filter", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: &nftables.Chain{Name: "INPUT", Table: ipTable}, + }, + want: true, + }, + { + name: "new chain in unwatched arp family", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: &nftables.Chain{Name: "x", Table: arpTable}, + }, + want: false, + }, + { + name: "new table inet", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewTable, + Data: inetTable, + }, + want: true, + }, + { + name: "del chain (we only act on new)", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeDelChain, + Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable}, + }, + want: false, + }, + { + name: "chain with nil table", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: &nftables.Chain{Name: "x"}, + }, + want: false, + }, + { + name: "nil data", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: (*nftables.Chain)(nil), + }, + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, isRelevantMonitorEvent(tc.ev)) + }) + } +} + +// fakeReconciler records reconcile invocations for debounce tests. +type fakeReconciler struct { + calls chan struct{} +} + +func (f *fakeReconciler) reconcileExternalChains() error { + f.calls <- struct{}{} + return nil +} + +func TestExternalChainMonitorStopWithoutStart(t *testing.T) { + m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)}) + // Must not panic or block. + m.stop() +} + +func TestExternalChainMonitorDoubleStart(t *testing.T) { + // start() twice should be a no-op; stop() cleans up once. + // We avoid exercising the netlink watch loop here because it needs root. 
+ m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)}) + + // Replace run with a stub that just waits for cancel, so start() stays + // deterministic without opening a netlink socket. + origDone := make(chan struct{}) + m.done = origDone + m.cancel = func() { close(origDone) } + + // Second start should be a no-op (cancel already set). + m.start() + assert.NotNil(t, m.cancel) + + m.stop() + assert.Nil(t, m.cancel) + assert.Nil(t, m.done) +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 8cd5cc6b3..fdc7c2f3c 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -11,9 +11,11 @@ import ( "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" @@ -49,10 +51,17 @@ type Manager struct { rConn *nftables.Conn wgIface iFaceMapper - router *router - aclManager *AclManager + router *router + aclManager *AclManager + + // IPv6 counterparts, nil when no v6 overlay + router6 *router + aclManager6 *AclManager + notrackOutputChain *nftables.Chain notrackPreroutingChain *nftables.Chain + + extMonitor *externalChainMonitor } // Create nftables firewall manager @@ -62,7 +71,8 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { wgIface: wgIface, } - workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} + tableName := getTableName() + workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4} var err error m.router, err = newRouter(workTable, wgIface, mtu) @@ -75,35 +85,137 @@ func Create(wgIface iFaceMapper, mtu uint16) 
(*Manager, error) { return nil, fmt.Errorf("create acl manager: %w", err) } + if wgIface.Address().HasIPv6() { + if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil { + return nil, fmt.Errorf("create IPv6 firewall: %w", err) + } + } + + m.extMonitor = newExternalChainMonitor(m) + return m, nil } +func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error { + workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6} + + var err error + m.router6, err = newRouter(workTable6, wgIface, mtu) + if err != nil { + return fmt.Errorf("create v6 router: %w", err) + } + + // Share the same IP forwarding state with the v4 router, since + // EnableIPForwarding controls both v4 and v6 sysctls. + m.router6.ipFwdState = m.router.ipFwdState + + m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw) + if err != nil { + return fmt.Errorf("create v6 acl manager: %w", err) + } + + return nil +} + +// hasIPv6 reports whether the manager has IPv6 components initialized. +func (m *Manager) hasIPv6() bool { + return m.router6 != nil +} + +func (m *Manager) initIPv6() error { + workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6) + if err != nil { + return fmt.Errorf("create v6 work table: %w", err) + } + + if err := m.router6.init(workTable6); err != nil { + return fmt.Errorf("v6 router init: %w", err) + } + + if err := m.aclManager6.init(workTable6); err != nil { + return fmt.Errorf("v6 acl manager init: %w", err) + } + + return nil +} + // Init nftables firewall manager func (m *Manager) Init(stateManager *statemanager.Manager) error { + if err := m.initFirewall(); err != nil { + return err + } + + m.persistState(stateManager) + + // Start after initFirewall has installed the baseline external-chain + // accept rules. start() is idempotent across Init/Close/Init cycles. 
+ m.extMonitor.start() + + return nil +} + +// reconcileExternalChains re-applies passthrough accept rules to external +// filter chains for both IPv4 and IPv6 routers. Called by the monitor when +// tables or chains appear (e.g. after firewalld reloads). +func (m *Manager) reconcileExternalChains() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + var merr *multierror.Error + if m.router != nil { + if err := m.router.acceptExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("v4: %w", err)) + } + } + if m.hasIPv6() { + if err := m.router6.acceptExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("v6: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) +} + +func (m *Manager) initFirewall() (err error) { workTable, err := m.createWorkTable() if err != nil { return fmt.Errorf("create work table: %w", err) } + defer func() { + if err != nil { + m.rollbackInit() + } + }() + if err := m.router.init(workTable); err != nil { return fmt.Errorf("router init: %w", err) } if err := m.aclManager.init(workTable); err != nil { - // TODO: cleanup router return fmt.Errorf("acl manager init: %w", err) } + if m.hasIPv6() { + if err := m.initIPv6(); err != nil { + // Peer has a v6 address: v6 firewall MUST work or we risk fail-open. + return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err) + } + } + if err := m.initNoTrackChains(workTable); err != nil { log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err) } + return nil +} + +// persistState saves the current interface state for potential recreation on restart. +// Unlike iptables, which requires tracking individual rules, nftables maintains +// a known state (our netbird table plus a few static rules). This allows for easy +// cleanup using Close() without needing to store specific rules. 
+func (m *Manager) persistState(stateManager *statemanager.Manager) { stateManager.RegisterState(&ShutdownState{}) - // We only need to record minimal interface state for potential recreation. - // Unlike iptables, which requires tracking individual rules, nftables maintains - // a known state (our netbird table plus a few static rules). This allows for easy - // cleanup using Close() without needing to store specific rules. if err := stateManager.UpdateState(&ShutdownState{ InterfaceState: &InterfaceState{ NameStr: m.wgIface.Name(), @@ -114,14 +226,29 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Errorf("failed to update state: %v", err) } - // persist early go func() { if err := stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } }() +} - return nil +// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through. +func (m *Manager) rollbackInit() { + if err := m.router.Reset(); err != nil { + log.Warnf("rollback router: %v", err) + } + if m.hasIPv6() { + if err := m.router6.Reset(); err != nil { + log.Warnf("rollback v6 router: %v", err) + } + } + if err := m.cleanupNetbirdTables(); err != nil { + log.Warnf("cleanup tables: %v", err) + } + if err := m.rConn.Flush(); err != nil { + log.Warnf("flush: %v", err) + } } // AddPeerFiltering rule to the firewall @@ -140,12 +267,14 @@ func (m *Manager) AddPeerFiltering( m.mutex.Lock() defer m.mutex.Unlock() - rawIP := ip.To4() - if rawIP == nil { - return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) + if ip.To4() != nil { + return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) } - return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) + if !m.hasIPv6() { + return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized) + } + return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, 
action, ipsetName) } func (m *Manager) AddRouteFiltering( @@ -159,8 +288,11 @@ func (m *Manager) AddRouteFiltering( m.mutex.Lock() defer m.mutex.Unlock() - if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { - return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) + if isIPv6RouteRule(sources, destination) { + if !m.hasIPv6() { + return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) @@ -171,15 +303,66 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() + if m.hasIPv6() && isIPv6Rule(rule) { + return m.aclManager6.DeletePeerRule(rule) + } return m.aclManager.DeletePeerRule(rule) } -// DeleteRouteRule deletes a routing rule +func isIPv6Rule(rule firewall.Rule) bool { + r, ok := rule.(*Rule) + return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6 +} + +// isIPv6RouteRule determines whether a route rule belongs to the v6 table. +// For static routes, the destination prefix determines the family. For dynamic +// routes (DomainSet), the sources determine the family since management +// duplicates dynamic rules per family. +func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool { + if destination.IsPrefix() { + return destination.Prefix.Addr().Is6() + } + return len(sources) > 0 && sources[0].Addr().Is6() +} + +// DeleteRouteRule deletes a routing rule. Route rules live in exactly one +// router; the cached maps are normally authoritative, so the kernel is only +// consulted when neither map knows about the rule. 
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.DeleteRouteRule(rule) + id := rule.ID() + r, err := m.routerForRuleID(id, (*router).hasRule) + if err != nil { + return err + } + return r.DeleteRouteRule(rule) +} + +// routerForRuleID picks the router holding the rule with the given id, using +// the supplied lookup. If the cached maps disagree (or both miss), it refreshes +// from the kernel once and re-checks before falling back to the v4 router. +func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) { + if has(m.router, id) { + return m.router, nil + } + if m.hasIPv6() && has(m.router6, id) { + return m.router6, nil + } + if !m.hasIPv6() { + return m.router, nil + } + if err := m.router.refreshRulesMap(); err != nil { + return nil, fmt.Errorf("refresh v4 rules: %w", err) + } + if err := m.router6.refreshRulesMap(); err != nil { + return nil, fmt.Errorf("refresh v6 rules: %w", err) + } + if has(m.router6, id) && !has(m.router, id) { + return m.router6, nil + } + return m.router, nil } func (m *Manager) IsServerRouteSupported() bool { @@ -194,19 +377,70 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddNatRule(pair) + } + + if err := m.router.AddNatRule(pair); err != nil { + return err + } + + // Dynamic routes need NAT in both tables since resolved IPs can be + // either v4 or v6. This covers both DomainSet (modern) and the legacy + // wildcard 0.0.0.0/0 destination where the client resolves DNS. 
+ // On v6 failure we keep the v4 NAT rule rather than rolling back: half + // connectivity is better than none, and RemoveNatRule is content-keyed + // so the eventual cleanup still works. + if m.hasIPv6() && pair.Dynamic { + v6Pair := firewall.ToV6NatPair(pair) + if err := m.router6.AddNatRule(v6Pair); err != nil { + return fmt.Errorf("add v6 NAT rule: %w", err) + } + } + + return nil } func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveNatRule(pair) + if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() { + if !m.hasIPv6() { + return nil + } + return m.router6.RemoveNatRule(pair) + } + + var merr *multierror.Error + + if err := m.router.RemoveNatRule(pair); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err)) + } + + if m.hasIPv6() && pair.Dynamic { + v6Pair := firewall.ToV6NatPair(pair) + if err := m.router6.RemoveNatRule(v6Pair); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) } // AllowNetbird allows netbird interface traffic. // This is called when USPFilter wraps the native firewall, adding blanket accept // rules so that packet filtering is handled in userspace instead of by netfilter. +// +// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains, +// which doesn't override DROP rules in external tables (e.g. firewalld). +// Should add passthrough rules to external chains (like the native mode router's +// addExternalChainsRules does) for both the netbird table family and inet tables. +// The netbird table itself is fine (routing chains already exist there), but +// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic. 
func (m *Manager) AllowNetbird() error { m.mutex.Lock() defer m.mutex.Unlock() @@ -214,6 +448,11 @@ func (m *Manager) AllowNetbird() error { if err := m.aclManager.createDefaultAllowRules(); err != nil { return fmt.Errorf("create default allow rules: %w", err) } + if m.hasIPv6() { + if err := m.aclManager6.createDefaultAllowRules(); err != nil { + return fmt.Errorf("create v6 default allow rules: %w", err) + } + } if err := m.rConn.Flush(); err != nil { return fmt.Errorf("flush allow input netbird rules: %w", err) } @@ -227,31 +466,47 @@ func (m *Manager) AllowNetbird() error { // SetLegacyManagement sets the route manager to use legacy management func (m *Manager) SetLegacyManagement(isLegacy bool) error { - return firewall.SetLegacyManagement(m.router, isLegacy) + if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil { + return err + } + if m.hasIPv6() { + return firewall.SetLegacyManagement(m.router6, isLegacy) + } + return nil } // Close closes the firewall manager func (m *Manager) Close(stateManager *statemanager.Manager) error { + m.extMonitor.stop() + m.mutex.Lock() defer m.mutex.Unlock() + var merr *multierror.Error + if err := m.router.Reset(); err != nil { - return fmt.Errorf("reset router: %v", err) + merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err)) + } + + if m.hasIPv6() { + if err := m.router6.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err)) + } } if err := m.cleanupNetbirdTables(); err != nil { - return fmt.Errorf("cleanup netbird tables: %v", err) + merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err)) } if err := m.rConn.Flush(); err != nil { - return fmt.Errorf(flushError, err) + merr = multierror.Append(merr, fmt.Errorf(flushError, err)) } if err := stateManager.DeleteState(&ShutdownState{}); err != nil { - return fmt.Errorf("delete state: %v", err) + merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err)) } - return nil + return 
nberrors.FormatErrorOrNil(merr) } func (m *Manager) cleanupNetbirdTables() error { @@ -300,6 +555,12 @@ func (m *Manager) Flush() error { return err } + if m.hasIPv6() { + if err := m.aclManager6.Flush(); err != nil { + return fmt.Errorf("flush v6 acl: %w", err) + } + } + if err := m.refreshNoTrackChains(); err != nil { log.Errorf("failed to refresh notrack chains: %v", err) } @@ -312,6 +573,12 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) m.mutex.Lock() defer m.mutex.Unlock() + if rule.TranslatedAddress.Is6() { + if !m.hasIPv6() { + return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddDNATRule(rule) + } return m.router.AddDNATRule(rule) } @@ -320,7 +587,11 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.DeleteDNATRule(rule) + r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule) + if err != nil { + return err + } + return r.DeleteDNATRule(rule) } // UpdateSet updates the set with the given prefixes @@ -328,39 +599,82 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.UpdateSet(set, prefixes) + var v4Prefixes, v6Prefixes []netip.Prefix + for _, p := range prefixes { + if p.Addr().Is6() { + v6Prefixes = append(v6Prefixes, p) + } else { + v4Prefixes = append(v4Prefixes, p) + } + } + + if err := m.router.UpdateSet(set, v4Prefixes); err != nil { + return err + } + + if m.hasIPv6() && len(v6Prefixes) > 0 { + if err := m.router6.UpdateSet(set, v6Prefixes); err != nil { + return fmt.Errorf("update v6 set: %w", err) + } + } + + return nil } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. 
-func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort) } // RemoveInboundDNAT removes an inbound DNAT rule. -func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort) } // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. 
-func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort) } // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. -func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) + if localAddr.Is6() { + if !m.hasIPv6() { + return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized) + } + return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort) + } + return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort) } const ( @@ -534,7 +848,11 @@ func (m *Manager) refreshNoTrackChains() error { } func (m *Manager) createWorkTable() (*nftables.Table, error) { - tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) + return m.createWorkTableFamily(nftables.TableFamilyIPv4) +} + +func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) { + tables, err := m.rConn.ListTablesOfFamily(family) if err != nil { return nil, fmt.Errorf("list of tables: %w", err) } @@ -546,7 +864,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { } } - table := 
m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family}) err = m.rConn.Flush() return table, err } diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index d48e4ba88..be4f65881 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -383,10 +383,138 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { err = manager.AddNatRule(pair) require.NoError(t, err, "failed to add NAT rule") + dnatRule, err := manager.AddDNATRule(fw.ForwardRule{ + Protocol: fw.ProtocolTCP, + DestinationPort: fw.Port{Values: []uint16{8080}}, + TranslatedAddress: netip.MustParseAddr("100.96.0.2"), + TranslatedPort: fw.Port{Values: []uint16{80}}, + }) + require.NoError(t, err, "failed to add DNAT rule") + + t.Cleanup(func() { + require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule") + }) + stdout, stderr = runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) } +func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} { + if _, err := exec.LookPath(bin); err != nil { + t.Skipf("%s not available on this system: %v", bin, err) + } + } + + // Seed ip6 tables in the nft backend. Docker may not create them. 
+ seedIp6tables(t) + + ifaceMockV6 := &iFaceMock{ + NameFunc: func() string { return "wt-test" }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.96.0.1"), + Network: netip.MustParsePrefix("100.96.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + } + }, + } + + manager, err := Create(ifaceMockV6, iface.DefaultMTU) + require.NoError(t, err, "create manager") + require.NoError(t, manager.Init(nil)) + + t.Cleanup(func() { + require.NoError(t, manager.Close(nil), "close manager") + + stdout, stderr := runIp6tablesSave(t) + verifyIp6tablesOutput(t, stdout, stderr) + }) + + ip := netip.MustParseAddr("fd00::2") + _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") + require.NoError(t, err, "add v6 peer filtering rule") + + _, err = manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("fd00:1::/64")}, + fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")}, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{443}}, + fw.ActionAccept, + ) + require.NoError(t, err, "add v6 route filtering rule") + + err = manager.AddNatRule(fw.RouterPair{ + Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")}, + Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")}, + Masquerade: true, + }) + require.NoError(t, err, "add v6 NAT rule") + + dnatRule, err := manager.AddDNATRule(fw.ForwardRule{ + Protocol: fw.ProtocolTCP, + DestinationPort: fw.Port{Values: []uint16{8080}}, + TranslatedAddress: netip.MustParseAddr("fd00::2"), + TranslatedPort: fw.Port{Values: []uint16{80}}, + }) + require.NoError(t, err, "add v6 DNAT rule") + + t.Cleanup(func() { + require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule") + }) + + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + + stdout, stderr = runIp6tablesSave(t) + verifyIp6tablesOutput(t, 
stdout, stderr) +} + +func seedIp6tables(t *testing.T) { + t.Helper() + for _, tc := range []struct{ table, chain string }{ + {"filter", "FORWARD"}, + {"nat", "POSTROUTING"}, + {"mangle", "FORWARD"}, + } { + add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT") + require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table) + del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT") + require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table) + } +} + +func runIp6tablesSave(t *testing.T) (string, string) { + t.Helper() + var stdout, stderr bytes.Buffer + cmd := exec.Command("ip6tables-save") + cmd.Stdout = &stdout + cmd.Stderr = &stderr + require.NoError(t, cmd.Run(), "ip6tables-save failed") + return stdout.String(), stderr.String() +} + +func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) { + t.Helper() + for _, msg := range []string{ + "Table `nat' is incompatible", + "Table `mangle' is incompatible", + "Table `filter' is incompatible", + } { + require.NotContains(t, stdout, msg, + "ip6tables-save stdout reports incompatibility: %s", stdout) + require.NotContains(t, stderr, msg, + "ip6tables-save stderr reports incompatibility: %s", stderr) + } +} + func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this system") diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 8cc0d2792..4214455a9 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -50,8 +50,10 @@ const ( dnatSuffix = "_dnat" snatSuffix = "_snat" - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation. 
+ ipv4TCPHeaderSize = 40 + // ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation. + ipv6TCPHeaderSize = 60 // maxPrefixesSet 1638 prefixes start to fail, taking some margin maxPrefixesSet = 1500 @@ -76,6 +78,7 @@ type router struct { rules map[string]*nftables.Rule ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set] + af addrFamily wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState legacyManagement bool @@ -88,6 +91,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou workTable: workTable, chains: make(map[string]*nftables.Chain), rules: make(map[string]*nftables.Rule), + af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4), wgIface: wgIface, ipFwdState: ipfwdstate.NewIPForwardingState(), mtu: mtu, @@ -150,7 +154,7 @@ func (r *router) Reset() error { func (r *router) removeNatPreroutingRules() error { table := &nftables.Table{ Name: tableNat, - Family: nftables.TableFamilyIPv4, + Family: r.af.tableFamily, } chain := &nftables.Chain{ Name: chainNameNatPrerouting, @@ -183,7 +187,7 @@ func (r *router) removeNatPreroutingRules() error { } func (r *router) loadFilterTable() (*nftables.Table, error) { - tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) + tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily) if err != nil { return nil, fmt.Errorf("list tables: %w", err) } @@ -419,7 +423,7 @@ func (r *router) AddRouteFiltering( // Handle protocol if proto != firewall.ProtocolALL { - protoNum, err := protoToInt(proto) + protoNum, err := r.af.protoNum(proto) if err != nil { return nil, fmt.Errorf("convert protocol to number: %w", err) } @@ -479,7 +483,24 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo return nil, fmt.Errorf("create or get ipset: %w", err) } - return getIpSetExprs(ref, isSource) + return r.getIpSetExprs(ref, isSource) +} + +func (r *router) iptablesProto() iptables.Protocol { + if r.af.tableFamily == 
nftables.TableFamilyIPv6 { + return iptables.ProtocolIPv6 + } + return iptables.ProtocolIPv4 +} + +func (r *router) hasRule(id string) bool { + _, ok := r.rules[id] + return ok +} + +func (r *router) hasDNATRule(id string) bool { + _, ok := r.rules[id+dnatSuffix] + return ok } func (r *router) DeleteRouteRule(rule firewall.Rule) error { @@ -528,10 +549,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err Table: r.workTable, // required for prefixes Interval: true, - KeyType: nftables.TypeIPAddr, + KeyType: r.af.setKeyType, } - elements := convertPrefixesToSet(prefixes) + elements := r.convertPrefixesToSet(prefixes) nElements := len(elements) maxElements := maxPrefixesSet * 2 @@ -564,23 +585,17 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err return nfset, nil } -func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { +func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { var elements []nftables.SetElement for _, prefix := range prefixes { - // TODO: Implement IPv6 support - if prefix.Addr().Is6() { - log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) - continue - } - // nftables needs half-open intervals [firstIP, lastIP) for prefixes // e.g. 
10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc firstIP := prefix.Addr() lastIP := calculateLastIP(prefix).Next() elements = append(elements, - // the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247 - // nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true}, + // the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247 + // nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true}, nftables.SetElement{Key: firstIP.AsSlice()}, nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, ) @@ -590,10 +605,20 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { // calculateLastIP determines the last IP in a given prefix. func calculateLastIP(prefix netip.Prefix) netip.Addr { - hostMask := ^uint32(0) >> prefix.Masked().Bits() - lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask + masked := prefix.Masked() + if masked.Addr().Is4() { + hostMask := ^uint32(0) >> masked.Bits() + lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask + return netip.AddrFrom4(uint32ToBytes(lastIP)) + } - return netip.AddrFrom4(uint32ToBytes(lastIP)) + // IPv6: set host bits to all 1s + b := masked.Addr().As16() + bits := masked.Bits() + for i := bits; i < 128; i++ { + b[i/8] |= 1 << (7 - i%8) + } + return netip.AddrFrom16(b) } // Utility function to convert netip.Addr to uint32. @@ -845,9 +870,16 @@ func (r *router) addPostroutingRules() { } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. 
-// TODO: Add IPv6 support func (r *router) addMSSClampingRules() error { - mss := r.mtu - ipTCPHeaderMinSize + overhead := uint16(ipv4TCPHeaderSize) + if r.af.tableFamily == nftables.TableFamilyIPv6 { + overhead = ipv6TCPHeaderSize + } + if r.mtu <= overhead { + log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead) + return nil + } + mss := r.mtu - overhead exprsOut := []expr.Any{ &expr.Meta{ @@ -1054,17 +1086,22 @@ func (r *router) acceptFilterTableRules() error { log.Debugf("Used %s to add accept forward and input rules", fw) }() - // Try iptables first and fallback to nftables if iptables is not available - ipt, err := iptables.New() + // Try iptables first and fallback to nftables if iptables is not available. + // Use the correct protocol (iptables vs ip6tables) for the address family. + ipt, err := iptables.NewWithProtocol(r.iptablesProto()) if err != nil { - // iptables is not available but the filter table exists log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" return r.acceptFilterRulesNftables(r.filterTable) } - return r.acceptFilterRulesIptables(ipt) + if err := r.acceptFilterRulesIptables(ipt); err != nil { + log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err) + fw = "nftables" + return r.acceptFilterRulesNftables(r.filterTable) + } + return nil } func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { @@ -1135,83 +1172,122 @@ func (r *router) acceptExternalChainsRules() error { } intf := ifname(r.wgIface.Name()) - for _, chain := range chains { - if chain.Hooknum == nil { - log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name) - continue - } - - log.Debugf("adding accept rules to external %s chain: %s %s/%s", - hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name) - - switch *chain.Hooknum { - case 
*nftables.ChainHookForward: - r.insertForwardAcceptRules(chain, intf) - case *nftables.ChainHookInput: - r.insertInputAcceptRule(chain, intf) - } + r.applyExternalChainAccept(chain, intf) } if err := r.conn.Flush(); err != nil { return fmt.Errorf("flush external chain rules: %w", err) } - return nil } +func (r *router) applyExternalChainAccept(chain *nftables.Chain, intf []byte) { + if chain.Hooknum == nil { + log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name) + return + } + + log.Debugf("adding accept rules to external %s chain: %s %s/%s", + hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name) + + switch *chain.Hooknum { + case *nftables.ChainHookForward: + r.insertForwardAcceptRules(chain, intf) + case *nftables.ChainHookInput: + r.insertInputAcceptRule(chain, intf) + } +} + func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) { - iifRule := &nftables.Rule{ + existing, err := r.existingNetbirdRulesInChain(chain) + if err != nil { + log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err) + return + } + r.insertForwardIifRule(chain, intf, existing) + r.insertForwardOifEstablishedRule(chain, intf, existing) +} + +func (r *router) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) { + if existing[userDataAcceptForwardRuleIif] { + return + } + r.conn.InsertRule(&nftables.Rule{ Table: chain.Table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf}, &expr.Counter{}, &expr.Verdict{Kind: expr.VerdictAccept}, }, UserData: []byte(userDataAcceptForwardRuleIif), - } - r.conn.InsertRule(iifRule) + }) +} - oifExprs := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, +func (r *router) 
insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) { + if existing[userDataAcceptForwardRuleOif] { + return } - oifRule := &nftables.Rule{ + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf}, + } + r.conn.InsertRule(&nftables.Rule{ Table: chain.Table, Chain: chain, - Exprs: append(oifExprs, getEstablishedExprs(2)...), + Exprs: append(exprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), - } - r.conn.InsertRule(oifRule) + }) } func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) { - inputRule := &nftables.Rule{ + existing, err := r.existingNetbirdRulesInChain(chain) + if err != nil { + log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err) + return + } + if existing[userDataAcceptInputRule] { + return + } + r.conn.InsertRule(&nftables.Rule{ Table: chain.Table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf}, &expr.Counter{}, &expr.Verdict{Kind: expr.VerdictAccept}, }, UserData: []byte(userDataAcceptInputRule), + }) +} + +// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive. 
+func (r *router) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) { + rules, err := r.conn.GetRules(chain.Table, chain) + if err != nil { + return nil, fmt.Errorf("list rules: %w", err) } - r.conn.InsertRule(inputRule) + present := map[string]bool{} + for _, rule := range rules { + if !isNetbirdAcceptRuleTag(rule.UserData) { + continue + } + present[string(rule.UserData)] = true + } + return present, nil +} + +func isNetbirdAcceptRuleTag(userData []byte) bool { + switch string(userData) { + case userDataAcceptForwardRuleIif, + userDataAcceptForwardRuleOif, + userDataAcceptInputRule: + return true + } + return false } func (r *router) removeAcceptFilterRules() error { @@ -1233,13 +1309,17 @@ func (r *router) removeFilterTableRules() error { return nil } - ipt, err := iptables.New() + ipt, err := iptables.NewWithProtocol(r.iptablesProto()) if err != nil { log.Debugf("iptables not available, using nftables to remove filter rules: %v", err) return r.removeAcceptRulesFromTable(r.filterTable) } - return r.removeAcceptFilterRulesIptables(ipt) + if err := r.removeAcceptFilterRulesIptables(ipt); err != nil { + log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err) + return r.removeAcceptRulesFromTable(r.filterTable) + } + return nil } func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error { @@ -1306,7 +1386,7 @@ func (r *router) removeExternalChainsRules() error { func (r *router) findExternalChains() []*nftables.Chain { var chains []*nftables.Chain - families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet} + families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet} for _, family := range families { allChains, err := r.conn.ListChainsOfTableFamily(family) @@ -1337,8 +1417,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool { return false } - // Skip all iptables-managed tables in the ip family - if chain.Table.Family == 
nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { + // Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat) + if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) { return false } @@ -1479,7 +1559,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { return rule, nil } - protoNum, err := protoToInt(rule.Protocol) + protoNum, err := r.af.protoNum(rule.Protocol) if err != nil { return nil, fmt.Errorf("convert protocol to number: %w", err) } @@ -1542,7 +1622,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule dnatExprs = append(dnatExprs, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: regProtoMin, RegProtoMax: regProtoMax, @@ -1635,14 +1715,15 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f }, ) + natTable := &nftables.Table{ + Name: tableNat, + Family: r.af.tableFamily, + } dnatRule := &nftables.Rule{ - Table: &nftables.Table{ - Name: tableNat, - Family: nftables.TableFamilyIPv4, - }, + Table: natTable, Chain: &nftables.Chain{ Name: chainNameNatPrerouting, - Table: r.filterTable, + Table: natTable, Type: nftables.ChainTypeNAT, Hooknum: nftables.ChainHookPrerouting, Priority: nftables.ChainPriorityNATDest, @@ -1673,8 +1754,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, + Offset: r.af.dstAddrOffset, + Len: r.af.addrLen, }, &expr.Cmp{ Op: expr.CmpOpEq, @@ -1752,7 +1833,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return fmt.Errorf("get set %s: %w", set.HashedName(), err) } - elements := convertPrefixesToSet(prefixes) + elements := r.convertPrefixesToSet(prefixes) if err 
:= r.conn.SetAddElements(nfset, elements); err != nil { return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err) } @@ -1767,14 +1848,14 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { } // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. -func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { - ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) if _, exists := r.rules[ruleID]; exists { return nil } - protoNum, err := protoToInt(protocol) + protoNum, err := r.af.protoNum(protocol) if err != nil { return fmt.Errorf("convert protocol to number: %w", err) } @@ -1801,11 +1882,15 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol &expr.Cmp{ Op: expr.CmpOpEq, Register: 3, - Data: binaryutil.BigEndian.PutUint16(sourcePort), + Data: binaryutil.BigEndian.PutUint16(originalPort), }, } - exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + bits := 32 + if localAddr.Is6() { + bits = 128 + } + exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...) 
exprs = append(exprs, &expr.Immediate{ @@ -1814,11 +1899,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol }, &expr.Immediate{ Register: 2, - Data: binaryutil.BigEndian.PutUint16(targetPort), + Data: binaryutil.BigEndian.PutUint16(translatedPort), }, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: 2, RegProtoMax: 0, @@ -1843,12 +1928,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol } // RemoveInboundDNAT removes an inbound DNAT rule. -func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) } - ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) rule, exists := r.rules[ruleID] if !exists { @@ -1894,8 +1979,8 @@ func (r *router) ensureNATOutputChain() error { } // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. 
-func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { - ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) +func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) if _, exists := r.rules[ruleID]; exists { return nil @@ -1905,7 +1990,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, return err } - protoNum, err := protoToInt(protocol) + protoNum, err := r.af.protoNum(protocol) if err != nil { return fmt.Errorf("convert protocol to number: %w", err) } @@ -1926,11 +2011,15 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, &expr.Cmp{ Op: expr.CmpOpEq, Register: 2, - Data: binaryutil.BigEndian.PutUint16(sourcePort), + Data: binaryutil.BigEndian.PutUint16(originalPort), }, } - exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + bits := 32 + if localAddr.Is6() { + bits = 128 + } + exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...) exprs = append(exprs, &expr.Immediate{ @@ -1939,11 +2028,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, }, &expr.Immediate{ Register: 2, - Data: binaryutil.BigEndian.PutUint16(targetPort), + Data: binaryutil.BigEndian.PutUint16(translatedPort), }, &expr.NAT{ Type: expr.NATTypeDestNAT, - Family: uint32(nftables.TableFamilyIPv4), + Family: uint32(r.af.tableFamily), RegAddrMin: 1, RegProtoMin: 2, }, @@ -1967,12 +2056,12 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, } // RemoveOutputDNAT removes an OUTPUT chain DNAT rule. 
-func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) } - ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort) rule, exists := r.rules[ruleID] if !exists { @@ -2011,45 +2100,44 @@ func (r *router) applyNetwork( } if network.IsPrefix() { - return applyPrefix(network.Prefix, isSource), nil + return r.applyPrefix(network.Prefix, isSource), nil } return nil, nil } // applyPrefix generates nftables expressions for a CIDR prefix -func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { - // dst offset - offset := uint32(16) +func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { + // dst offset by default + offset := r.af.dstAddrOffset if isSource { // src offset - offset = 12 + offset = r.af.srcAddrOffset } ones := prefix.Bits() - // 0.0.0.0/0 doesn't need extra expressions + // unspecified address (/0) doesn't need extra expressions if ones == 0 { return nil } - mask := net.CIDRMask(ones, 32) + mask := net.CIDRMask(ones, r.af.totalBits) + xor := make([]byte, r.af.addrLen) return []expr.Any{ &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: offset, - Len: 4, + Len: r.af.addrLen, }, - // netmask &expr.Bitwise{ DestRegister: 1, SourceRegister: 1, - Len: 4, + Len: r.af.addrLen, Mask: mask, - Xor: []byte{0, 0, 0, 0}, + Xor: xor, }, - // net address &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, @@ -2132,13 +2220,12 @@ func getCtNewExprs() []expr.Any { } } -func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { - - // dst offset - offset := 
uint32(16) +func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { + // dst offset by default + offset := r.af.dstAddrOffset if isSource { // src offset - offset = 12 + offset = r.af.srcAddrOffset } return []expr.Any{ @@ -2146,7 +2233,7 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: offset, - Len: 4, + Len: r.af.addrLen, }, &expr.Lookup{ SourceRegister: 1, diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index f0e34d211..c5d6729d9 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -90,8 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) { } // Build CIDR matching expressions - sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true) - destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false) + testRouter := &router{af: afIPv4} + sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true) + destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false) // Combine all expressions in the correct order // nolint:gocritic @@ -508,6 +509,136 @@ func TestNftablesCreateIpSet(t *testing.T) { } } +func TestNftablesCreateIpSet_IPv6(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTableIPv6() + require.NoError(t, err, "Failed to create v6 work table") + defer deleteWorkTableIPv6() + + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) + require.NoError(t, err, "Failed to create router") + require.NoError(t, r.init(workTable)) + defer func() { + require.NoError(t, r.Reset(), "Failed to reset router") + }() + + tests := []struct { + name string + sources []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Single IPv6", + sources: 
[]netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")}, + }, + { + name: "Multiple IPv6 Subnets", + sources: []netip.Prefix{ + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("2001:db8::/48"), + netip.MustParsePrefix("fe80::/10"), + }, + }, + { + name: "Overlapping IPv6", + sources: []netip.Prefix{ + netip.MustParsePrefix("fd00::/48"), + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("fd00::1/128"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("fd00::/48"), + }, + }, + { + name: "Mixed prefix lengths", + sources: []netip.Prefix{ + netip.MustParsePrefix("2001:db8:1::/48"), + netip.MustParsePrefix("2001:db8:2::1/128"), + netip.MustParsePrefix("fd00:abcd::/32"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setName := firewall.NewPrefixSet(tt.sources).HashedName() + set, err := r.createIpSet(setName, setInput{prefixes: tt.sources}) + require.NoError(t, err, "Failed to create IPv6 set") + require.NotNil(t, set) + + assert.Equal(t, setName, set.Name) + assert.True(t, set.Interval) + assert.Equal(t, nftables.TypeIP6Addr, set.KeyType) + + fetchedSet, err := r.conn.GetSetByName(r.workTable, setName) + require.NoError(t, err, "Failed to fetch created set") + + elements, err := r.conn.GetSetElements(fetchedSet) + require.NoError(t, err, "Failed to get set elements") + + uniquePrefixes := make(map[string]bool) + for _, elem := range elements { + if !elem.IntervalEnd && len(elem.Key) == 16 { + ip := netip.AddrFrom16([16]byte(elem.Key)) + uniquePrefixes[ip.String()] = true + } + } + + expectedCount := len(tt.expected) + if expectedCount == 0 { + expectedCount = len(tt.sources) + } + assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch") + + r.conn.DelSet(set) + require.NoError(t, r.conn.Flush()) + }) + } +} + +func createWorkTableIPv6() (*nftables.Table, error) { + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return nil, err + } + + tables, err := 
sConn.ListTablesOfFamily(nftables.TableFamilyIPv6) + if err != nil { + return nil, err + } + for _, t := range tables { + if t.Name == tableNameNetbird { + sConn.DelTable(t) + } + } + + table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6}) + err = sConn.Flush() + return table, err +} + +func deleteWorkTableIPv6() { + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return + } + + tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6) + if err != nil { + return + } + for _, t := range tables { + if t.Name == tableNameNetbird { + sConn.DelTable(t) + _ = sConn.Flush() + } + } +} + func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) { t.Helper() @@ -627,7 +758,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool { func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { var metaFound, cmpFound bool - expectedProto, _ := protoToInt(proto) + expectedProto, _ := afIPv4.protoNum(proto) for _, e := range exprs { switch ex := e.(type) { case *expr.Meta: @@ -854,3 +985,55 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) { } assert.Equal(t, 1, found, "NAT rule should exist in kernel") } + +func TestCalculateLastIP(t *testing.T) { + tests := []struct { + prefix string + want string + }{ + {"10.0.0.0/24", "10.0.0.255"}, + {"10.0.0.0/32", "10.0.0.0"}, + {"0.0.0.0/0", "255.255.255.255"}, + {"192.168.1.0/28", "192.168.1.15"}, + {"fd00::/64", "fd00::ffff:ffff:ffff:ffff"}, + {"fd00::/128", "fd00::"}, + {"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"}, + {"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"}, + } + for _, tt := range tests { + t.Run(tt.prefix, func(t *testing.T) { + prefix := netip.MustParsePrefix(tt.prefix) + got := calculateLastIP(prefix) + assert.Equal(t, 
tt.want, got.String()) + }) + } +} + +func TestConvertPrefixesToSet_IPv6(t *testing.T) { + r := &router{af: afIPv6} + prefixes := []netip.Prefix{ + netip.MustParsePrefix("fd00::/64"), + netip.MustParsePrefix("2001:db8::1/128"), + } + + elements := r.convertPrefixesToSet(prefixes) + + // Each prefix produces 2 elements (start + end) + require.Len(t, elements, 4) + + // fd00::/64 start + assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key)) + assert.False(t, elements[0].IntervalEnd) + + // fd00::/64 end (fd00:0:0:1::, one past the last) + assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key)) + assert.True(t, elements[1].IntervalEnd) + + // 2001:db8::1/128 start + assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key)) + assert.False(t, elements[2].IntervalEnd) + + // 2001:db8::1/128 end (2001:db8::2) + assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key)) + assert.True(t, elements[3].IntervalEnd) +} diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 6aef2ecfd..10a2b9116 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -5,8 +5,10 @@ import ( "os/exec" "syscall" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -29,15 +31,20 @@ func (m *Manager) Close(*statemanager.Manager) error { return nil } - if !isFirewallRuleActive(firewallRuleName) { - return nil + var merr *multierror.Error + if isFirewallRuleActive(firewallRuleName) { + if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err)) + } } - if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { 
- return fmt.Errorf("couldn't remove windows firewall: %w", err) + if isFirewallRuleActive(firewallRuleName + "-v6") { + if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err)) + } } - return nil + return nberrors.FormatErrorOrNil(merr) } // AllowNetbird allows netbird interface traffic @@ -46,17 +53,33 @@ func (m *Manager) AllowNetbird() error { return nil } - if isFirewallRuleActive(firewallRuleName) { - return nil + if !isFirewallRuleActive(firewallRuleName) { + if err := manageFirewallRule(firewallRuleName, + addRule, + "dir=in", + "enable=yes", + "action=allow", + "profile=any", + "localip="+m.wgIface.Address().IP.String(), + ); err != nil { + return err + } } - return manageFirewallRule(firewallRuleName, - addRule, - "dir=in", - "enable=yes", - "action=allow", - "profile=any", - "localip="+m.wgIface.Address().IP.String(), - ) + + if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") { + if err := manageFirewallRule(firewallRuleName+"-v6", + addRule, + "dir=in", + "enable=yes", + "action=allow", + "profile=any", + "localip="+v6.String(), + ); err != nil { + return err + } + } + + return nil } func manageFirewallRule(ruleName string, action action, extraArgs ...string) error { diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 7be0dd78f..88e90317c 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -1,8 +1,9 @@ package conntrack import ( - "fmt" + "net" "net/netip" + "strconv" "sync/atomic" "time" @@ -64,5 +65,7 @@ type ConnKey struct { } func (c ConnKey) String() string { - return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) + return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) + + " → " + + 
net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort))) } diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index d868dd1fb..7e67b98fa 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -13,6 +13,54 @@ import ( var logger = log.NewFromLogrus(logrus.StandardLogger()) var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() +func TestConnKey_String(t *testing.T) { + tests := []struct { + name string + key ConnKey + expect string + }{ + { + name: "IPv4", + key: ConnKey{ + SrcIP: netip.MustParseAddr("192.168.1.1"), + DstIP: netip.MustParseAddr("10.0.0.1"), + SrcPort: 12345, + DstPort: 80, + }, + expect: "192.168.1.1:12345 → 10.0.0.1:80", + }, + { + name: "IPv6", + key: ConnKey{ + SrcIP: netip.MustParseAddr("2001:db8::1"), + DstIP: netip.MustParseAddr("2001:db8::2"), + SrcPort: 54321, + DstPort: 443, + }, + expect: "[2001:db8::1]:54321 → [2001:db8::2]:443", + }, + { + name: "IPv4-mapped IPv6 unmaps", + key: ConnKey{ + SrcIP: netip.MustParseAddr("::ffff:10.0.0.1"), + DstIP: netip.MustParseAddr("::ffff:10.0.0.2"), + SrcPort: 1000, + DstPort: 2000, + }, + expect: "10.0.0.1:1000 → 10.0.0.2:2000", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.key.String() + if got != tc.expect { + t.Errorf("got %q, want %q", got, tc.expect) + } + }) + } +} + // Memory pressure tests func BenchmarkMemoryPressure(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) { diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 50b663642..a48215ca9 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "strconv" "sync" "time" @@ -21,9 +22,14 @@ const ( // ICMPCleanupInterval is how often we check for stale ICMP connections 
ICMPCleanupInterval = 15 * time.Second - // MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info, - // which includes the IP header (20 bytes) and transport header (8 bytes) - MaxICMPPayloadLength = 28 + // MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info. + // IPv4: 20-byte header + 8-byte transport = 28 bytes. + // IPv6: 40-byte header + 8-byte transport = 48 bytes. + MaxICMPPayloadLength = 48 + // minICMPPayloadIPv4 is the minimum embedded packet length for IPv4 ICMP errors. + minICMPPayloadIPv4 = 28 + // minICMPPayloadIPv6 is the minimum embedded packet length for IPv6 ICMP errors. + minICMPPayloadIPv6 = 48 ) // ICMPConnKey uniquely identifies an ICMP connection @@ -65,7 +71,7 @@ type ICMPInfo struct { // String implements fmt.Stringer for lazy evaluation in log messages func (info ICMPInfo) String() string { - if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength { + if info.isErrorMessage() && info.PayloadLen >= minICMPPayloadIPv4 { if origInfo := info.parseOriginalPacket(); origInfo != "" { return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo) } @@ -74,42 +80,72 @@ func (info ICMPInfo) String() string { return info.TypeCode.String() } -// isErrorMessage returns true if this ICMP type carries original packet info +// isErrorMessage returns true if this ICMP type carries original packet info. +// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match +// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's +// kept as a literal. 
func (info ICMPInfo) isErrorMessage() bool { typ := info.TypeCode.Type() - return typ == 3 || // Destination Unreachable - typ == 5 || // Redirect - typ == 11 || // Time Exceeded - typ == 12 // Parameter Problem + // ICMPv4 error types + if typ == layers.ICMPv4TypeDestinationUnreachable || + typ == layers.ICMPv4TypeRedirect || + typ == layers.ICMPv4TypeTimeExceeded || + typ == layers.ICMPv4TypeParameterProblem { + return true + } + // ICMPv6 error types (type 3 already matched above as v4 DestUnreachable) + if typ == layers.ICMPv6TypeDestinationUnreachable || + typ == layers.ICMPv6TypePacketTooBig || + typ == layers.ICMPv6TypeParameterProblem { + return true + } + return false } // parseOriginalPacket extracts info about the original packet from ICMP payload func (info ICMPInfo) parseOriginalPacket() string { - if info.PayloadLen < MaxICMPPayloadLength { + if info.PayloadLen == 0 { return "" } - // TODO: handle IPv6 - if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 { + version := (info.PayloadData[0] >> 4) & 0xF + + var protocol uint8 + var srcIP, dstIP net.IP + var transportData []byte + + switch version { + case 4: + if info.PayloadLen < minICMPPayloadIPv4 { + return "" + } + protocol = info.PayloadData[9] + srcIP = net.IP(info.PayloadData[12:16]) + dstIP = net.IP(info.PayloadData[16:20]) + transportData = info.PayloadData[20:] + case 6: + if info.PayloadLen < minICMPPayloadIPv6 { + return "" + } + // Next Header field in IPv6 header + protocol = info.PayloadData[6] + srcIP = net.IP(info.PayloadData[8:24]) + dstIP = net.IP(info.PayloadData[24:40]) + transportData = info.PayloadData[40:] + default: return "" } - protocol := info.PayloadData[9] - srcIP := net.IP(info.PayloadData[12:16]) - dstIP := net.IP(info.PayloadData[16:20]) - - transportData := info.PayloadData[20:] - switch nftypes.Protocol(protocol) { case nftypes.TCP: srcPort := uint16(transportData[0])<<8 | uint16(transportData[1]) dstPort := uint16(transportData[2])<<8 | 
uint16(transportData[3]) - return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort) + return "TCP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))) case nftypes.UDP: srcPort := uint16(transportData[0])<<8 | uint16(transportData[1]) dstPort := uint16(transportData[2])<<8 | uint16(transportData[3]) - return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort) + return "UDP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))) case nftypes.ICMP: icmpType := transportData[0] @@ -247,9 +283,10 @@ func (t *ICMPTracker) track( t.sendEvent(nftypes.TypeStart, conn, ruleId) } -// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request +// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request. +// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies. 
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool { - if icmpType != uint8(layers.ICMPv4TypeEchoReply) { + if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) { return false } @@ -301,6 +338,13 @@ func (t *ICMPTracker) cleanup() { } } +func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol { + if ip.Is6() { + return nftypes.ICMPv6 + } + return nftypes.ICMP +} + // Close stops the cleanup routine and releases resources func (t *ICMPTracker) Close() { t.tickerCancel() @@ -316,7 +360,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID [] Type: typ, RuleID: ruleID, Direction: conn.Direction, - Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 + Protocol: icmpProtocolForAddr(conn.SourceIP), SourceIP: conn.SourceIP, DestIP: conn.DestIP, ICMPType: conn.ICMPType, @@ -334,7 +378,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad Type: nftypes.TypeStart, RuleID: ruleID, Direction: direction, - Protocol: nftypes.ICMP, + Protocol: icmpProtocolForAddr(srcIP), SourceIP: srcIP, DestIP: dstIP, ICMPType: typ, diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index b15b42cf0..6d1f87162 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -5,6 +5,42 @@ import ( "testing" ) +func TestICMPConnKey_String(t *testing.T) { + tests := []struct { + name string + key ICMPConnKey + expect string + }{ + { + name: "IPv4", + key: ICMPConnKey{ + SrcIP: netip.MustParseAddr("192.168.1.1"), + DstIP: netip.MustParseAddr("10.0.0.1"), + ID: 1234, + }, + expect: "192.168.1.1 → 10.0.0.1 (id 1234)", + }, + { + name: "IPv6", + key: ICMPConnKey{ + SrcIP: netip.MustParseAddr("2001:db8::1"), + DstIP: netip.MustParseAddr("2001:db8::2"), + ID: 5678, + }, + expect: "2001:db8::1 → 2001:db8::2 (id 5678)", 
+ }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.key.String() + if got != tc.expect { + t.Errorf("got %q, want %q", got, tc.expect) + } + }) + } +} + func BenchmarkICMPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 3787e63a8..5ecd08abf 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -18,9 +18,10 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/google/uuid" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" @@ -35,8 +36,10 @@ import ( const ( layerTypeAll = 255 - // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation - ipTCPHeaderMinSize = 40 + // ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation + ipv4TCPHeaderMinSize = 40 + // ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation + ipv6TCPHeaderMinSize = 60 ) // serviceKey represents a protocol/port combination for netstack service registry @@ -123,7 +126,7 @@ type Manager struct { logger *nblog.Logger flowLogger nftypes.FlowLogger - blockRule firewall.Rule + blockRules []firewall.Rule // Internal 1:1 DNAT dnatEnabled atomic.Bool @@ -138,9 +141,10 @@ type Manager struct { netstackServices map[serviceKey]struct{} netstackServiceMutex sync.RWMutex - mtu uint16 - mssClampValue uint16 - mssClampEnabled bool + mtu uint16 + mssClampValueIPv4 uint16 + mssClampValueIPv6 uint16 + mssClampEnabled bool // Only one 
hook per protocol is supported. Outbound direction only. udpHookOut atomic.Pointer[common.PacketHook] @@ -157,11 +161,28 @@ type decoder struct { icmp4 layers.ICMPv4 icmp6 layers.ICMPv6 decoded []gopacket.LayerType - parser *gopacket.DecodingLayerParser + parser4 *gopacket.DecodingLayerParser + parser6 *gopacket.DecodingLayerParser dnatOrigPort uint16 } +// decodePacket decodes packet data using the appropriate parser based on IP version. +func (d *decoder) decodePacket(data []byte) error { + if len(data) == 0 { + return errors.New("empty packet") + } + version := data[0] >> 4 + switch version { + case 4: + return d.parser4.DecodeLayers(data, &d.decoded) + case 6: + return d.parser6.DecodeLayers(data, &d.decoded) + default: + return fmt.Errorf("unknown IP version %d", version) + } +} + // Create userspace firewall manager constructor func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { return create(iface, nil, disableServerRoutes, flowLogger, mtu) @@ -219,11 +240,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true return d }, }, @@ -249,7 +276,12 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe if !disableMSSClamping { m.mssClampEnabled = true - m.mssClampValue = mtu - ipTCPHeaderMinSize + if mtu > ipv4TCPHeaderMinSize { + m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize + } + if mtu > ipv6TCPHeaderMinSize { + m.mssClampValueIPv6 = mtu - 
ipv6TCPHeaderMinSize + } } if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { return nil, fmt.Errorf("update local IPs: %w", err) @@ -272,13 +304,25 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe return m, nil } -func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) { +// blockInvalidRouted installs drop rules for traffic to the wg overlay that +// arrives via the routing path. v4 and v6 are independent: a v6 install +// failure leaves v4 protection in place (and vice versa) so the returned +// slice always contains whatever was successfully installed, even on error. +// Callers must persist the slice so DisableRouting can clean partial state. +func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule, error) { wgPrefix := iface.Address().Network log.Debugf("blocking invalid routed traffic for %s", wgPrefix) - rule, err := m.addRouteFiltering( + sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)} + v6Net := iface.Address().IPv6Net + if v6Net.IsValid() { + sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0)) + } + + var rules []firewall.Rule + v4Rule, err := m.addRouteFiltering( nil, - []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, + sources, firewall.Network{Prefix: wgPrefix}, firewall.ProtocolALL, nil, @@ -286,12 +330,30 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e firewall.ActionDrop, ) if err != nil { - return nil, fmt.Errorf("block wg nte : %w", err) + return rules, fmt.Errorf("block wg v4 net: %w", err) + } + rules = append(rules, v4Rule) + + if v6Net.IsValid() { + log.Debugf("blocking invalid routed traffic for %s", v6Net) + v6Rule, err := m.addRouteFiltering( + nil, + sources, + firewall.Network{Prefix: v6Net}, + firewall.ProtocolALL, + nil, + nil, + firewall.ActionDrop, + ) + if err != nil { + return rules, fmt.Errorf("block wg v6 net: %w", err) + } + rules = 
append(rules, v6Rule) } // TODO: Block networks that we're a client of - return rule, nil + return rules, nil } func (m *Manager) determineRouting() error { @@ -521,7 +583,7 @@ func (m *Manager) addRouteFiltering( mgmtId: id, sources: sources, dstSet: destination.Set, - protoLayer: protoToLayer(proto, layers.LayerTypeIPv4), + protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)), srcPort: sPort, dstPort: dPort, action: action, @@ -612,10 +674,10 @@ func (m *Manager) Flush() error { return nil } // resetState clears all firewall rules and closes connection trackers. // Must be called with m.mutex held. func (m *Manager) resetState() { - maps.Clear(m.outgoingRules) - maps.Clear(m.incomingDenyRules) - maps.Clear(m.incomingRules) - maps.Clear(m.routeRulesMap) + clear(m.outgoingRules) + clear(m.incomingDenyRules) + clear(m.incomingRules) + clear(m.routeRulesMap) m.routeRules = m.routeRules[:0] m.udpHookOut.Store(nil) m.tcpHookOut.Store(nil) @@ -676,11 +738,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { } destinations := matches[0].destinations - for _, prefix := range prefixes { - if prefix.Addr().Is4() { - destinations = append(destinations, prefix) - } - } + destinations = append(destinations, prefixes...) 
slices.SortFunc(destinations, func(a, b netip.Prefix) int { cmp := a.Addr().Compare(b.Addr()) @@ -719,7 +777,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { return false } @@ -803,12 +861,32 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool { return false } + var mssClampValue uint16 + var ipHeaderSize int + switch d.decoded[0] { + case layers.LayerTypeIPv4: + mssClampValue = m.mssClampValueIPv4 + ipHeaderSize = int(d.ip4.IHL) * 4 + if ipHeaderSize < 20 { + return false + } + case layers.LayerTypeIPv6: + mssClampValue = m.mssClampValueIPv6 + ipHeaderSize = 40 + default: + return false + } + + if mssClampValue == 0 { + return false + } + mssOptionIndex := -1 var currentMSS uint16 for i, opt := range d.tcp.Options { if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 { currentMSS = binary.BigEndian.Uint16(opt.OptionData) - if currentMSS > m.mssClampValue { + if currentMSS > mssClampValue { mssOptionIndex = i break } @@ -819,20 +897,15 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool { return false } - ipHeaderSize := int(d.ip4.IHL) * 4 - if ipHeaderSize < 20 { + if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) { return false } - if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) { - return false - } - - m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue) + m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue) return true } -func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool { +func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool { tcpHeaderStart := ipHeaderSize tcpOptionsStart := 
tcpHeaderStart + 20 @@ -847,7 +920,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, } mssValueOffset := optOffset + 2 - binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue) + binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue) m.recalculateTCPChecksum(packetData, d, tcpHeaderStart) return true @@ -857,18 +930,32 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade tcpLayer := packetData[tcpHeaderStart:] tcpLength := len(packetData) - tcpHeaderStart + // Zero out existing checksum tcpLayer[16] = 0 tcpLayer[17] = 0 + // Build pseudo-header checksum based on IP version var pseudoSum uint32 - pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) - pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) - pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) - pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) - pseudoSum += uint32(d.ip4.Protocol) - pseudoSum += uint32(tcpLength) + switch d.decoded[0] { + case layers.LayerTypeIPv4: + pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) + pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) + pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) + pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) + pseudoSum += uint32(d.ip4.Protocol) + pseudoSum += uint32(tcpLength) + case layers.LayerTypeIPv6: + for i := 0; i < 16; i += 2 { + pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1]) + } + for i := 0; i < 16; i += 2 { + pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1]) + } + pseudoSum += uint32(tcpLength) + pseudoSum += uint32(layers.IPProtocolTCP) + } - var sum = pseudoSum + sum := pseudoSum for i := 0; i < tcpLength-1; i += 2 { sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) } @@ -906,6 +993,9 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, 
packetData } case layers.LayerTypeICMPv4: m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) + case layers.LayerTypeICMPv6: + id, tc := icmpv6EchoFields(d) + m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size) } } @@ -919,6 +1009,9 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) case layers.LayerTypeICMPv4: m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) + case layers.LayerTypeICMPv6: + id, tc := icmpv6EchoFields(d) + m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size) } d.dnatOrigPort = 0 @@ -951,15 +1044,19 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { // TODO: pass fragments of routed packets to forwarder if fragment { - m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", - srcIP, dstIP, d.ip4.Id, d.ip4.Flags) + if d.decoded[0] == layers.LayerTypeIPv4 { + m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", + srcIP, dstIP, d.ip4.Id, d.ip4.Flags) + } else { + m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP) + } return false } // TODO: optimize port DNAT by caching matched rules in conntrack if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { // Re-decode after port DNAT translation to update port information - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) return true } @@ -968,7 +1065,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { if translated := m.translateInboundReverse(packetData, d); translated { // Re-decode after translation to get original addresses - if err := 
d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) return true } @@ -1100,6 +1197,48 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe return true } +// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps +// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle +// both families uniformly. The echo ID is in the first two payload bytes. +func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) { + if len(d.icmp6.Payload) >= 2 { + id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1]) + } + // Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking. + switch d.icmp6.TypeCode.Type() { + case layers.ICMPv6TypeEchoRequest: + tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0) + case layers.ICMPv6TypeEchoReply: + tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0) + default: + tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code()) + } + return id, tc +} + +// protoLayerMatches checks if a packet's protocol layer matches a rule's expected +// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching +// ICMP rules since management sends a single ICMP rule for both families. 
+func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool { + if ruleLayer == packetLayer { + return true + } + if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 { + return true + } + if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 { + return true + } + return false +} + +func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType { + if p.Addr().Is6() { + return layers.LayerTypeIPv6 + } + return layers.LayerTypeIPv4 +} + func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType { switch proto { case firewall.ProtocolTCP: @@ -1123,8 +1262,10 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol { return nftypes.TCP case layers.LayerTypeUDP: return nftypes.UDP - case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: + case layers.LayerTypeICMPv4: return nftypes.ICMP + case layers.LayerTypeICMPv6: + return nftypes.ICMPv6 default: return nftypes.ProtocolUnknown } @@ -1145,7 +1286,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) { // It returns true, false if the packet is valid and not a fragment. // It returns true, true if the packet is a fragment and valid. 
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { m.logger.Trace1("couldn't decode packet, err: %s", err) return false, false } @@ -1158,10 +1299,21 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { } // Fragments are also valid - if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 { - ip4 := d.ip4 - if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 { - return true, true + if l == 1 { + switch d.decoded[0] { + case layers.LayerTypeIPv4: + if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 { + return true, true + } + case layers.LayerTypeIPv6: + // IPv6 uses Fragment extension header (NextHeader=44). If gopacket + // only decoded the IPv6 layer, the transport is in a fragment. + // TODO: handle non-Fragment extension headers (HopByHop, Routing, + // DestOpts) by walking the chain. gopacket's parser does not + // support them as DecodingLayers; today we drop such packets. + if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment { + return true, true + } } } @@ -1199,21 +1351,35 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size, ) - // TODO: ICMPv6 + case layers.LayerTypeICMPv6: + id, _ := icmpv6EchoFields(d) + return m.icmpTracker.IsValidInbound( + srcIP, + dstIP, + id, + d.icmp6.TypeCode.Type(), + size, + ) } return false } -// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed +// isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed. 
func (m *Manager) isSpecialICMP(d *decoder) bool { - if d.decoded[1] != layers.LayerTypeICMPv4 { - return false + switch d.decoded[1] { + case layers.LayerTypeICMPv4: + icmpType := d.icmp4.TypeCode.Type() + return icmpType == layers.ICMPv4TypeDestinationUnreachable || + icmpType == layers.ICMPv4TypeTimeExceeded + case layers.LayerTypeICMPv6: + icmpType := d.icmp6.TypeCode.Type() + return icmpType == layers.ICMPv6TypeDestinationUnreachable || + icmpType == layers.ICMPv6TypePacketTooBig || + icmpType == layers.ICMPv6TypeTimeExceeded || + icmpType == layers.ICMPv6TypeParameterProblem } - - icmpType := d.icmp4.TypeCode.Type() - return icmpType == layers.ICMPv4TypeDestinationUnreachable || - icmpType == layers.ICMPv4TypeTimeExceeded + return false } func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) { @@ -1270,7 +1436,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d return rule.mgmtId, rule.drop, true } - if payloadLayer != rule.protoLayer { + if !protoLayerMatches(rule.protoLayer, payloadLayer) { continue } @@ -1305,8 +1471,7 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay } func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool { - // TODO: handle ipv6 vs ipv4 icmp rules - if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer { + if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) { return false } @@ -1367,13 +1532,14 @@ func (m *Manager) EnableRouting() error { return nil } - rule, err := m.blockInvalidRouted(m.wgIface) + rules, err := m.blockInvalidRouted(m.wgIface) + // Persist whatever was installed even on partial failure, so DisableRouting + // can clean it up later. 
+ m.blockRules = rules if err != nil { return fmt.Errorf("block invalid routed: %w", err) } - m.blockRule = rule - return nil } @@ -1389,9 +1555,16 @@ func (m *Manager) DisableRouting() error { m.routingEnabled.Store(false) m.nativeRouter.Store(false) - // don't stop forwarder if in use by netstack + var merr *multierror.Error + for _, rule := range m.blockRules { + if err := m.deleteRouteRule(rule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete block rule: %w", err)) + } + } + m.blockRules = nil + if m.netstack && m.localForwarding { - return nil + return nberrors.FormatErrorOrNil(merr) } fwder.Stop() @@ -1399,14 +1572,7 @@ func (m *Manager) DisableRouting() error { log.Debug("forwarder stopped") - if m.blockRule != nil { - if err := m.deleteRouteRule(m.blockRule); err != nil { - return fmt.Errorf("delete block rule: %w", err) - } - m.blockRule = nil - } - - return nil + return nberrors.FormatErrorOrNil(merr) } // RegisterNetstackService registers a service as listening on the netstack for the given protocol and port @@ -1460,7 +1626,8 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool { } // traffic to our other local interfaces (not NetBird IP) - always forward - if dstIP != m.wgIface.Address().IP { + addr := m.wgIface.Address() + if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) { return true } diff --git a/client/firewall/uspfilter/filter_bench_test.go b/client/firewall/uspfilter/filter_bench_test.go index 10ff62ed3..4dccb0f65 100644 --- a/client/firewall/uspfilter/filter_bench_test.go +++ b/client/firewall/uspfilter/filter_bench_test.go @@ -1023,7 +1023,8 @@ func BenchmarkMSSClamping(b *testing.B) { }() manager.mssClampEnabled = true - manager.mssClampValue = 1240 + manager.mssClampValueIPv4 = 1240 + manager.mssClampValueIPv6 = 1220 srcIP := net.ParseIP("100.64.0.2") dstIP := net.ParseIP("8.8.8.8") @@ -1088,7 +1089,8 @@ func BenchmarkMSSClampingOverhead(b *testing.B) { manager.mssClampEnabled = 
sc.enabled if sc.enabled { - manager.mssClampValue = 1240 + manager.mssClampValueIPv4 = 1240 + manager.mssClampValueIPv6 = 1220 } srcIP := net.ParseIP("100.64.0.2") @@ -1141,7 +1143,8 @@ func BenchmarkMSSClampingMemory(b *testing.B) { }() manager.mssClampEnabled = true - manager.mssClampValue = 1240 + manager.mssClampValueIPv4 = 1240 + manager.mssClampValueIPv6 = 1220 srcIP := net.ParseIP("100.64.0.2") dstIP := net.ParseIP("8.8.8.8") diff --git a/client/firewall/uspfilter/filter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go index a8efbac1c..a64c83138 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -539,53 +539,236 @@ func TestPeerACLFiltering(t *testing.T) { } } +func TestPeerACLFilteringIPv6(t *testing.T) { + localIP := netip.MustParseAddr("100.10.0.100") + localIPv6 := netip.MustParseAddr("fd00::100") + wgNet := netip.MustParsePrefix("100.10.0.0/16") + wgNetV6 := netip.MustParsePrefix("fd00::/64") + + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: localIP, + Network: wgNet, + IPv6: localIPv6, + IPv6Net: wgNetV6, + } + }, + } + + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) + + err = manager.UpdateLocalIPs() + require.NoError(t, err) + + testCases := []struct { + name string + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + ruleIP string + ruleProto fw.Protocol + ruleDstPort *fw.Port + ruleAction fw.Action + shouldBeBlocked bool + }{ + { + name: "IPv6: allow TCP from peer", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: 
false, + }, + { + name: "IPv6: allow UDP from peer", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 53, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolUDP, + ruleDstPort: &fw.Port{Values: []uint16{53}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "IPv6: allow ICMPv6 from peer", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolICMP, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolICMP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "IPv6: block TCP without rule", + srcIP: "fd00::2", + dstIP: "fd00::100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "IPv6: drop rule", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 22, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{22}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "IPv6: allow all protocols", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 9999, + ruleIP: "fd00::1", + ruleProto: fw.ProtocolALL, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches", + srcIP: "fd00::1", + dstIP: "fd00::100", + proto: fw.ProtocolICMP, + ruleIP: "0.0.0.0", + ruleProto: fw.ProtocolICMP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + } + + t.Run("IPv6 implicit DROP (no rules)", func(t *testing.T) { + packet := createTestPacket(t, "fd00::1", "fd00::100", fw.ProtocolTCP, 12345, 443) + isDropped := manager.FilterInbound(packet, 0) + require.True(t, isDropped, "IPv6 packet should be dropped when no rules exist") + }) + + for _, tc := range testCases { + t.Run(tc.name, func(t 
*testing.T) { + if tc.ruleAction == fw.ActionDrop { + rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "") + require.NoError(t, err) + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeletePeerRule(rule)) + } + }) + } + + rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "") + require.NoError(t, err) + require.NotEmpty(t, rules) + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeletePeerRule(rule)) + } + }) + + packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) + isDropped := manager.FilterInbound(packet, 0) + require.Equal(t, tc.shouldBeBlocked, isDropped, "packet filter result mismatch") + }) + } +} + func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte { t.Helper() + src := net.ParseIP(srcIP) + dst := net.ParseIP(dstIP) + buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ ComputeChecksums: true, FixLengths: true, } - ipLayer := &layers.IPv4{ - Version: 4, - TTL: 64, - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - } + // Detect address family + isV6 := src.To4() == nil var err error - switch proto { - case fw.ProtocolTCP: - ipLayer.Protocol = layers.IPProtocolTCP - tcp := &layers.TCP{ - SrcPort: layers.TCPPort(srcPort), - DstPort: layers.TCPPort(dstPort), - } - err = tcp.SetNetworkLayerForChecksum(ipLayer) - require.NoError(t, err) - err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp) - case fw.ProtocolUDP: - ipLayer.Protocol = layers.IPProtocolUDP - udp := &layers.UDP{ - SrcPort: layers.UDPPort(srcPort), - DstPort: layers.UDPPort(dstPort), + if isV6 { + ip6 := &layers.IPv6{ + Version: 6, + HopLimit: 64, + SrcIP: src, + DstIP: dst, } - err = udp.SetNetworkLayerForChecksum(ipLayer) - require.NoError(t, err) - err = gopacket.SerializeLayers(buf, opts, 
ipLayer, udp) - case fw.ProtocolICMP: - ipLayer.Protocol = layers.IPProtocolICMPv4 - icmp := &layers.ICMPv4{ - TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0), + switch proto { + case fw.ProtocolTCP: + ip6.NextHeader = layers.IPProtocolTCP + tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)} + _ = tcp.SetNetworkLayerForChecksum(ip6) + err = gopacket.SerializeLayers(buf, opts, ip6, tcp) + case fw.ProtocolUDP: + ip6.NextHeader = layers.IPProtocolUDP + udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)} + _ = udp.SetNetworkLayerForChecksum(ip6) + err = gopacket.SerializeLayers(buf, opts, ip6, udp) + case fw.ProtocolICMP: + ip6.NextHeader = layers.IPProtocolICMPv6 + icmp := &layers.ICMPv6{ + TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0), + } + _ = icmp.SetNetworkLayerForChecksum(ip6) + err = gopacket.SerializeLayers(buf, opts, ip6, icmp) + default: + err = gopacket.SerializeLayers(buf, opts, ip6) + } + } else { + ip4 := &layers.IPv4{ + Version: 4, + TTL: 64, + SrcIP: src, + DstIP: dst, } - err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp) - default: - err = gopacket.SerializeLayers(buf, opts, ipLayer) + switch proto { + case fw.ProtocolTCP: + ip4.Protocol = layers.IPProtocolTCP + tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)} + _ = tcp.SetNetworkLayerForChecksum(ip4) + err = gopacket.SerializeLayers(buf, opts, ip4, tcp) + case fw.ProtocolUDP: + ip4.Protocol = layers.IPProtocolUDP + udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)} + _ = udp.SetNetworkLayerForChecksum(ip4) + err = gopacket.SerializeLayers(buf, opts, ip4, udp) + case fw.ProtocolICMP: + ip4.Protocol = layers.IPProtocolICMPv4 + icmp := &layers.ICMPv4{TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)} + err = gopacket.SerializeLayers(buf, opts, ip4, icmp) + default: + err = 
gopacket.SerializeLayers(buf, opts, ip4) + } } require.NoError(t, err) @@ -1498,3 +1681,103 @@ func TestRouteACLSet(t *testing.T) { _, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") } + +// TestRouteACLFilteringIPv6 tests IPv6 route ACL matching directly via routeACLsPass. +// Note: full FilterInbound for routed IPv6 traffic drops at the forwarder stage (IPv4-only) +// but the ACL decision itself is correct. +func TestRouteACLFilteringIPv6(t *testing.T) { + manager := setupRoutedManager(t, "10.10.0.100/16") + + v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48") + _, err := manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("fd00::/16")}, + fw.Network{Prefix: v6Dst}, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{80}}, + fw.ActionAccept, + ) + require.NoError(t, err) + + _, err = manager.AddRouteFiltering( + nil, + []netip.Prefix{netip.MustParsePrefix("fd00::/16")}, + fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")}, + fw.ProtocolALL, + nil, + nil, + fw.ActionDrop, + ) + require.NoError(t, err) + + tests := []struct { + name string + srcIP netip.Addr + dstIP netip.Addr + proto gopacket.LayerType + srcPort uint16 + dstPort uint16 + allowed bool + }{ + { + name: "IPv6 TCP to allowed dest", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeTCP, + srcPort: 12345, + dstPort: 80, + allowed: true, + }, + { + name: "IPv6 TCP wrong port", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeTCP, + srcPort: 12345, + dstPort: 443, + allowed: false, + }, + { + name: "IPv6 UDP not matched by TCP rule", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeUDP, + srcPort: 12345, + 
dstPort: 80, + allowed: false, + }, + { + name: "IPv6 ICMPv6 matches ICMP rule via protoLayerMatches", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeICMPv6, + allowed: false, + }, + { + name: "IPv6 to denied subnet", + srcIP: netip.MustParseAddr("fd00::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef:1::1"), + proto: layers.LayerTypeTCP, + srcPort: 12345, + dstPort: 80, + allowed: false, + }, + { + name: "IPv6 source outside allowed range", + srcIP: netip.MustParseAddr("fe80::1"), + dstIP: netip.MustParseAddr("fd00:dead:beef::80"), + proto: layers.LayerTypeTCP, + srcPort: 12345, + dstPort: 80, + allowed: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, pass := manager.routeACLsPass(tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) + require.Equal(t, tc.allowed, pass, "route ACL result mismatch") + }) + } +} diff --git a/client/firewall/uspfilter/filter_routeacl_test.go b/client/firewall/uspfilter/filter_routeacl_test.go index 68572a01c..449554d8b 100644 --- a/client/firewall/uspfilter/filter_routeacl_test.go +++ b/client/firewall/uspfilter/filter_routeacl_test.go @@ -189,21 +189,21 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) { }) // Call blockInvalidRouted directly multiple times - rule1, err := manager.blockInvalidRouted(ifaceMock) + rules1, err := manager.blockInvalidRouted(ifaceMock) require.NoError(t, err) - require.NotNil(t, rule1) + require.NotEmpty(t, rules1) - rule2, err := manager.blockInvalidRouted(ifaceMock) + rules2, err := manager.blockInvalidRouted(ifaceMock) require.NoError(t, err) - require.NotNil(t, rule2) + require.NotEmpty(t, rules2) - rule3, err := manager.blockInvalidRouted(ifaceMock) + rules3, err := manager.blockInvalidRouted(ifaceMock) require.NoError(t, err) - require.NotNil(t, rule3) + require.NotEmpty(t, rules3) - // All should return the same rule - assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should 
return same rule") - assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule") + // All calls should return the same v4 block rule (idempotent install). + assert.Equal(t, rules1[0].ID(), rules2[0].ID(), "Second call should return same v4 rule") + assert.Equal(t, rules2[0].ID(), rules3[0].ID(), "Third call should return same v4 rule") // Should have exactly 1 route rule manager.mutex.RLock() diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 5fb9fef0e..f19c4bb56 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -535,11 +535,16 @@ func TestProcessOutgoingHooks(t *testing.T) { d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true return d }, } @@ -638,11 +643,16 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true return d }, } @@ -1048,8 +1058,8 @@ func TestMSSClamping(t *testing.T) { }() require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default") - expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize) - require.Equal(t, 
expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40") + require.Equal(t, uint16(1280-ipv4TCPHeaderMinSize), manager.mssClampValueIPv4, "IPv4 MSS clamp value should be MTU - 40") + require.Equal(t, uint16(1280-ipv6TCPHeaderMinSize), manager.mssClampValueIPv6, "IPv6 MSS clamp value should be MTU - 60") err = manager.UpdateLocalIPs() require.NoError(t, err) @@ -1067,7 +1077,7 @@ func TestMSSClamping(t *testing.T) { require.Len(t, d.tcp.Options, 1, "Should have MSS option") require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType)) actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) - require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40") + require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS should be clamped to MTU - 40") }) t.Run("SYN packet with low MSS unchanged", func(t *testing.T) { @@ -1091,7 +1101,7 @@ func TestMSSClamping(t *testing.T) { d := parsePacket(t, packet) require.Len(t, d.tcp.Options, 1, "Should have MSS option") actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) - require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped") + require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS in SYN-ACK should be clamped") }) t.Run("Non-SYN packet unchanged", func(t *testing.T) { @@ -1263,13 +1273,18 @@ func TestShouldForward(t *testing.T) { d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true - err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded) + err = d.decodePacket(buf.Bytes()) require.NoError(t, err) 
return d @@ -1329,6 +1344,44 @@ func TestShouldForward(t *testing.T) { }, } + // Add IPv6 to the interface and test dual-stack cases + wgIPv6 := netip.MustParseAddr("fd00::1") + otherIPv6 := netip.MustParseAddr("fd00::2") + ifaceMock.AddressFunc = func() wgaddr.Address { + return wgaddr.Address{ + IP: wgIP, + Network: netip.PrefixFrom(wgIP, 24), + IPv6: wgIPv6, + IPv6Net: netip.PrefixFrom(wgIPv6, 64), + } + } + + // Re-create manager to pick up the new address with IPv6 + require.NoError(t, manager.Close(nil)) + manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + + v6Cases := []struct { + name string + dstIP netip.Addr + expected bool + description string + }{ + {"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"}, + {"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"}, + {"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"}, + {"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"}, + } + for _, tt := range v6Cases { + t.Run(tt.name, func(t *testing.T) { + manager.localForwarding = true + manager.netstack = false + decoder := createTCPDecoder(8080) + result := manager.shouldForward(decoder, tt.dstIP) + require.Equal(t, tt.expected, result, tt.description) + }) + } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Configure manager diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index 96ab89af8..fab776f2a 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -1,7 +1,8 @@ package forwarder import ( - "fmt" + "net" + "strconv" "sync/atomic" wgdevice "golang.zx2c4.com/wireguard/device" @@ -54,16 +55,23 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { func (e *endpoint) 
WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { var written int for _, pkt := range pkts.AsSlice() { - netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice()) - data := stack.PayloadSince(pkt.NetworkHeader()) if data == nil { continue } - pktBytes := data.AsSlice() + raw := pkt.NetworkHeader().View().AsSlice() + if len(raw) == 0 { + continue + } + var address tcpip.Address + if raw[0]>>4 == 6 { + address = header.IPv6(raw).DestinationAddress() + } else { + address = header.IPv4(raw).DestinationAddress() + } - address := netHeader.DestinationAddress() + pktBytes := data.AsSlice() if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil { e.logger.Error1("CreateOutboundPacket: %v", err) continue @@ -114,5 +122,7 @@ type epID stack.TransportEndpointID func (i epID) String() string { // src and remote is swapped - return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) + return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) + + " → " + + net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort))) } diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 925273f24..6291eb285 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -14,6 +14,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -36,25 +37,31 @@ type Forwarder struct { logger *nblog.Logger flowLogger nftypes.FlowLogger // ruleIdMap is used to store the rule ID for a given connection - ruleIdMap sync.Map - stack *stack.Stack - endpoint *endpoint - udpForwarder *udpForwarder - ctx context.Context - cancel 
context.CancelFunc - ip tcpip.Address - netstack bool - hasRawICMPAccess bool - pingSemaphore chan struct{} + ruleIdMap sync.Map + stack *stack.Stack + endpoint *endpoint + udpForwarder *udpForwarder + ctx context.Context + cancel context.CancelFunc + ip tcpip.Address + ipv6 tcpip.Address + netstack bool + hasRawICMPAccess bool + hasRawICMPv6Access bool + pingSemaphore chan struct{} } func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, + }, TransportProtocols: []stack.TransportProtocolFactory{ tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, + icmp.NewProtocol6, }, HandleLocal: false, }) @@ -73,7 +80,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow protoAddr := tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), + Address: tcpip.AddrFrom4(iface.Address().IP.As4()), PrefixLen: iface.Address().Network.Bits(), }, } @@ -82,6 +89,19 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow return nil, fmt.Errorf("failed to add protocol address: %s", err) } + if v6 := iface.Address().IPv6; v6.IsValid() { + v6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFrom16(v6.As16()), + PrefixLen: iface.Address().IPv6Net.Bits(), + }, + } + if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil { + return nil, fmt.Errorf("add IPv6 protocol address: %s", err) + } + } + defaultSubnet, err := tcpip.NewSubnet( tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), tcpip.MaskFromBytes([]byte{0, 0, 0, 0}), @@ -90,6 +110,14 @@ func New(iface 
common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow return nil, fmt.Errorf("creating default subnet: %w", err) } + defaultSubnetV6, err := tcpip.NewSubnet( + tcpip.AddrFrom16([16]byte{}), + tcpip.MaskFromBytes(make([]byte, 16)), + ) + if err != nil { + return nil, fmt.Errorf("creating default v6 subnet: %w", err) + } + if err := s.SetPromiscuousMode(nicID, true); err != nil { return nil, fmt.Errorf("set promiscuous mode: %s", err) } @@ -98,10 +126,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow } s.SetRouteTable([]tcpip.Route{ - { - Destination: defaultSubnet, - NIC: nicID, - }, + {Destination: defaultSubnet, NIC: nicID}, + {Destination: defaultSubnetV6, NIC: nicID}, }) ctx, cancel := context.WithCancel(context.Background()) @@ -114,7 +140,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow ctx: ctx, cancel: cancel, netstack: netstack, - ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), + ip: tcpip.AddrFrom4(iface.Address().IP.As4()), + ipv6: addrFromNetipAddr(iface.Address().IPv6), pingSemaphore: make(chan struct{}, 3), } @@ -131,7 +158,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow udpForwarder := udp.NewForwarder(s, f.handleUDP) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) - s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) + // ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's + // network layer. This avoids duplicate echo replies (v4) and the v6 + // auto-reply bug where gVisor responds at the network layer before + // our transport handler fires. 
f.checkICMPCapability() @@ -150,8 +180,30 @@ func (f *Forwarder) SetCapture(pc PacketCapture) { } func (f *Forwarder) InjectIncomingPacket(payload []byte) error { - if len(payload) < header.IPv4MinimumSize { - return fmt.Errorf("packet too small: %d bytes", len(payload)) + if len(payload) == 0 { + return fmt.Errorf("empty packet") + } + + var protoNum tcpip.NetworkProtocolNumber + switch payload[0] >> 4 { + case 4: + if len(payload) < header.IPv4MinimumSize { + return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload)) + } + if f.handleICMPDirect(payload) { + return nil + } + protoNum = ipv4.ProtocolNumber + case 6: + if len(payload) < header.IPv6MinimumSize { + return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload)) + } + if f.handleICMPDirect(payload) { + return nil + } + protoNum = ipv6.ProtocolNumber + default: + return fmt.Errorf("unknown IP version: %d", payload[0]>>4) } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -160,11 +212,160 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error { defer pkt.DecRef() if f.endpoint.dispatcher != nil { - f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt) + f.endpoint.dispatcher.DeliverNetworkPacket(protoNum, pkt) } return nil } +// handleICMPDirect intercepts ICMP packets from raw IP payloads before they +// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that +// the existing handlers expect, then dispatches to handleICMP/handleICMPv6. +// This bypasses gVisor's network layer which causes duplicate v4 echo replies +// and auto-replies to all v6 echo requests in promiscuous mode. +// +// Unlike gVisor's network layer, this does not validate ICMP checksums or +// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor. 
+func parseICMPv4(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) { + if len(payload) < header.IPv4MinimumSize { + return 0, 0, src, dst, false + } + ip := header.IPv4(payload) + if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) { + return 0, 0, src, dst, false + } + if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 { + return 0, 0, src, dst, false + } + ipHdrLen = int(ip.HeaderLength()) + totalLen := int(ip.TotalLength()) + if ipHdrLen < header.IPv4MinimumSize || ipHdrLen > totalLen || totalLen > len(payload) { + return 0, 0, src, dst, false + } + icmpLen = totalLen - ipHdrLen + if icmpLen < header.ICMPv4MinimumSize { + return 0, 0, src, dst, false + } + return ipHdrLen, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true +} + +func parseICMPv6(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) { + if len(payload) < header.IPv6MinimumSize { + return 0, 0, src, dst, false + } + ip := header.IPv6(payload) + declaredLen := int(ip.PayloadLength()) + hdrEnd := header.IPv6MinimumSize + declaredLen + if hdrEnd > len(payload) { + return 0, 0, src, dst, false + } + icmpStart, ok := skipIPv6ExtensionsToICMPv6(payload, ip.NextHeader(), hdrEnd) + if !ok { + return 0, 0, src, dst, false + } + icmpLen = hdrEnd - icmpStart + if icmpLen < header.ICMPv6MinimumSize { + return 0, 0, src, dst, false + } + return icmpStart, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true +} + +// skipIPv6ExtensionsToICMPv6 walks the IPv6 extension-header chain starting +// after the fixed header. It advances past Hop-by-Hop, Routing, and +// Destination Options headers (which share the NextHeader+ExtLen+6+ExtLen*8 +// layout) and returns the offset of the ICMPv6 payload. Fragment, ESP, AH, +// and unknown identifiers are reported as not handleable so the caller can +// defer to gVisor. 
+func skipIPv6ExtensionsToICMPv6(payload []byte, next uint8, hdrEnd int) (int, bool) { + off := header.IPv6MinimumSize + for { + if next == uint8(header.ICMPv6ProtocolNumber) { + return off, true + } + if !isWalkableIPv6ExtHdr(next) { + return 0, false + } + newOff, newNext, ok := advanceIPv6ExtHdr(payload, off, hdrEnd) + if !ok { + return 0, false + } + off = newOff + next = newNext + } +} + +func isWalkableIPv6ExtHdr(id uint8) bool { + switch id { + case uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), + uint8(header.IPv6RoutingExtHdrIdentifier), + uint8(header.IPv6DestinationOptionsExtHdrIdentifier): + return true + } + return false +} + +func advanceIPv6ExtHdr(payload []byte, off, hdrEnd int) (int, uint8, bool) { + if off+8 > hdrEnd { + return 0, 0, false + } + extLen := (int(payload[off+1]) + 1) * 8 + if off+extLen > hdrEnd { + return 0, 0, false + } + return off + extLen, payload[off], true +} + +func (f *Forwarder) handleICMPDirect(payload []byte) bool { + if len(payload) == 0 { + return false + } + var ( + ipHdrLen int + icmpLen int + srcAddr tcpip.Address + dstAddr tcpip.Address + ok bool + ) + version := payload[0] >> 4 + switch version { + case 4: + ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv4(payload) + case 6: + ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv6(payload) + } + if !ok { + return false + } + + // Let gVisor handle ICMP destined for our own addresses natively. + // Its network-layer auto-reply is correct and efficient for local traffic. + if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) { + return false + } + + id := stack.TransportEndpointID{ + LocalAddress: dstAddr, + RemoteAddress: srcAddr, + } + + // Trim the buffer to the IP-declared length so gVisor doesn't see padding. 
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(payload[:ipHdrLen+icmpLen]), + }) + defer pkt.DecRef() + + if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok { + return false + } + if _, ok := pkt.TransportHeader().Consume(icmpLen); !ok { + return false + } + + if version == 6 { + return f.handleICMPv6(id, pkt) + } + return f.handleICMP(id, pkt) +} + // Stop gracefully shuts down the forwarder func (f *Forwarder) Stop() { f.cancel() @@ -177,11 +378,14 @@ func (f *Forwarder) Stop() { f.stack.Wait() } -func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { +func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr { if f.netstack && f.ip.Equal(addr) { - return net.IPv4(127, 0, 0, 1) + return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } - return addr.AsSlice() + if f.netstack && f.ipv6.Equal(addr) { + return netip.IPv6Loopback() + } + return addrToNetipAddr(addr) } func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) { @@ -215,23 +419,50 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe } } +// addrFromNetipAddr converts a netip.Addr to a gvisor tcpip.Address without allocating. +func addrFromNetipAddr(addr netip.Addr) tcpip.Address { + if !addr.IsValid() { + return tcpip.Address{} + } + if addr.Is4() { + return tcpip.AddrFrom4(addr.As4()) + } + return tcpip.AddrFrom16(addr.As16()) +} + +// addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating. +func addrToNetipAddr(addr tcpip.Address) netip.Addr { + switch addr.Len() { + case 4: + return netip.AddrFrom4(addr.As4()) + case 16: + return netip.AddrFrom16(addr.As16()) + default: + return netip.Addr{} + } +} + // checkICMPCapability tests whether we have raw ICMP socket access at startup. 
func (f *Forwarder) checkICMPCapability() { + f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger) + f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger) +} + +func probeRawICMP(network, addr string, logger *nblog.Logger) bool { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() lc := net.ListenConfig{} - conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + conn, err := lc.ListenPacket(ctx, network, addr) if err != nil { - f.hasRawICMPAccess = false - f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback") - return + logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network) + return false } if err := conn.Close(); err != nil { - f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err) + logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err) } - f.hasRawICMPAccess = true - f.logger.Debug("forwarder: Raw ICMP socket access available") + logger.Debug1("forwarder: raw %s socket access available", network) + return true } diff --git a/client/firewall/uspfilter/forwarder/forwarder_test.go b/client/firewall/uspfilter/forwarder/forwarder_test.go new file mode 100644 index 000000000..ad74e8493 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/forwarder_test.go @@ -0,0 +1,162 @@ +package forwarder + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const echoRequestSize = 8 + +func makeIPv6(t *testing.T, src, dst netip.Addr, nextHdr uint8, payload []byte) []byte { + t.Helper() + buf := make([]byte, header.IPv6MinimumSize+len(payload)) + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(payload)), + TransportProtocol: 0, // overwritten below to allow any value + HopLimit: 64, + SrcAddr: 
tcpipAddrFromNetip(src), + DstAddr: tcpipAddrFromNetip(dst), + }) + buf[6] = nextHdr + copy(buf[header.IPv6MinimumSize:], payload) + return buf +} + +func tcpipAddrFromNetip(a netip.Addr) tcpip.Address { + b := a.As16() + return tcpip.AddrFrom16(b) +} + +func echoRequest() []byte { + icmp := make([]byte, echoRequestSize) + icmp[0] = uint8(header.ICMPv6EchoRequest) + return icmp +} + +// extHdr builds a generic IPv6 extension header (HBH/Routing/DestOpts) of the +// given total octet length (must be multiple of 8, >= 8) with the given next +// header. +func extHdr(t *testing.T, next uint8, totalLen int) []byte { + t.Helper() + require.GreaterOrEqual(t, totalLen, 8) + require.Equal(t, 0, totalLen%8) + buf := make([]byte, totalLen) + buf[0] = next + buf[1] = uint8(totalLen/8 - 1) + return buf +} + +func TestParseICMPv6_NoExtensions(t *testing.T) { + src := netip.MustParseAddr("fd00::1") + dst := netip.MustParseAddr("fd00::2") + pkt := makeIPv6(t, src, dst, uint8(header.ICMPv6ProtocolNumber), echoRequest()) + + off, icmpLen, _, _, ok := parseICMPv6(pkt) + require.True(t, ok) + assert.Equal(t, header.IPv6MinimumSize, off) + assert.Equal(t, echoRequestSize, icmpLen) +} + +func TestParseICMPv6_SingleExtension(t *testing.T) { + src := netip.MustParseAddr("fd00::1") + dst := netip.MustParseAddr("fd00::2") + hbh := extHdr(t, uint8(header.ICMPv6ProtocolNumber), 8) + payload := append([]byte{}, hbh...) + payload = append(payload, echoRequest()...) 
+ pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), payload) + + off, icmpLen, _, _, ok := parseICMPv6(pkt) + require.True(t, ok) + assert.Equal(t, header.IPv6MinimumSize+8, off) + assert.Equal(t, echoRequestSize, icmpLen) +} + +func TestParseICMPv6_ChainedExtensions(t *testing.T) { + src := netip.MustParseAddr("fd00::1") + dst := netip.MustParseAddr("fd00::2") + dest := extHdr(t, uint8(header.ICMPv6ProtocolNumber), 16) + rt := extHdr(t, uint8(header.IPv6DestinationOptionsExtHdrIdentifier), 8) + hbh := extHdr(t, uint8(header.IPv6RoutingExtHdrIdentifier), 8) + payload := append(append(append([]byte{}, hbh...), rt...), dest...) + payload = append(payload, echoRequest()...) + pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), payload) + + off, icmpLen, _, _, ok := parseICMPv6(pkt) + require.True(t, ok) + assert.Equal(t, header.IPv6MinimumSize+8+8+16, off) + assert.Equal(t, echoRequestSize, icmpLen) +} + +func TestParseICMPv6_FragmentDefersToGVisor(t *testing.T) { + src := netip.MustParseAddr("fd00::1") + dst := netip.MustParseAddr("fd00::2") + pkt := makeIPv6(t, src, dst, uint8(header.IPv6FragmentExtHdrIdentifier), make([]byte, 8)) + + _, _, _, _, ok := parseICMPv6(pkt) + assert.False(t, ok) +} + +func TestParseICMPv6_TruncatedExtension(t *testing.T) { + src := netip.MustParseAddr("fd00::1") + dst := netip.MustParseAddr("fd00::2") + // Extension claims 16 bytes but only 8 remain after the IP header. + hbh := []byte{uint8(header.ICMPv6ProtocolNumber), 1, 0, 0, 0, 0, 0, 0} + pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), hbh) + + _, _, _, _, ok := parseICMPv6(pkt) + assert.False(t, ok) +} + +func TestParseICMPv6_TruncatedICMPPayload(t *testing.T) { + src := netip.MustParseAddr("fd00::1") + dst := netip.MustParseAddr("fd00::2") + // PayloadLength claims 8 bytes of ICMPv6 but the buffer only holds 4. 
+ pkt := makeIPv6(t, src, dst, uint8(header.ICMPv6ProtocolNumber), make([]byte, 8)) + pkt = pkt[:header.IPv6MinimumSize+4] + + _, _, _, _, ok := parseICMPv6(pkt) + assert.False(t, ok) +} + +func TestParseICMPv4_RejectsShortIHL(t *testing.T) { + pkt := make([]byte, 28) + pkt[0] = 0x44 // version 4, IHL 4 (16 bytes - below minimum) + pkt[9] = uint8(header.ICMPv4ProtocolNumber) + header.IPv4(pkt).SetTotalLength(28) + + _, _, _, _, ok := parseICMPv4(pkt) + assert.False(t, ok) +} + +func TestParseICMPv4_RejectsTotalLenOverBuffer(t *testing.T) { + pkt := make([]byte, header.IPv4MinimumSize+header.ICMPv4MinimumSize) + ip := header.IPv4(pkt) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(pkt) + 16), + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 64, + }) + + _, _, _, _, ok := parseICMPv4(pkt) + assert.False(t, ok) +} + +func TestParseICMPv4_RejectsFragment(t *testing.T) { + pkt := make([]byte, header.IPv4MinimumSize+header.ICMPv4MinimumSize) + ip := header.IPv4(pkt) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(pkt)), + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 64, + Flags: header.IPv4FlagMoreFragments, + }) + + _, _, _, _, ok := parseICMPv4(pkt) + assert.False(t, ok) +} diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index 217423901..3922c2052 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -35,7 +35,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu } icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() - conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond) + conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond) if err != nil { f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err) return true @@ -58,7 +58,7 @@ 
func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI defer func() { <-f.pingSemaphore }() if f.hasRawICMPAccess { - f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes) + f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, false) } else { f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes) } @@ -72,18 +72,23 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI // forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection. // The caller is responsible for closing the returned connection. -func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) { +func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) { ctx, cancel := context.WithTimeout(f.ctx, timeout) defer cancel() + network, listenAddr := "ip4:icmp", "0.0.0.0" + if v6 { + network, listenAddr = "ip6:ipv6-icmp", "::" + } + lc := net.ListenConfig{} - conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + conn, err := lc.ListenPacket(ctx, network, listenAddr) if err != nil { return nil, fmt.Errorf("create ICMP socket: %w", err) } dstIP := f.determineDialAddr(id.LocalAddress) - dst := &net.IPAddr{IP: dstIP} + dst := &net.IPAddr{IP: dstIP.AsSlice()} if _, err = conn.WriteTo(payload, dst); err != nil { if closeErr := conn.Close(); closeErr != nil { @@ -98,11 +103,11 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by return conn, nil } -// handleICMPViaSocket handles ICMP echo requests using raw sockets. 
-func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) { +// handleICMPViaSocket handles ICMP echo requests using raw sockets for both v4 and v6. +func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int, v6 bool) { sendTime := time.Now() - conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second) + conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, v6, 5*time.Second) if err != nil { f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err) return @@ -113,16 +118,20 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp } }() - txBytes := f.handleEchoResponse(conn, id) + txBytes := f.handleEchoResponse(conn, id, v6) rtt := time.Since(sendTime).Round(10 * time.Microsecond) - f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)", - epID(id), icmpType, icmpCode, rtt) + proto := "ICMP" + if v6 { + proto = "ICMPv6" + } + f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)", + proto, epID(id), icmpType, icmpCode, rtt) f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } -func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int { +func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID, v6 bool) int { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err) return 0 @@ -137,6 +146,19 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn return 0 } + if v6 { + // Recompute checksum: the raw socket response has a checksum computed + // over the real endpoint addresses, but we 
inject with overlay addresses. + icmpHdr := header.ICMPv6(response[:n]) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr, + Src: id.LocalAddress, + Dst: id.RemoteAddress, + })) + return f.injectICMPv6Reply(id, response[:n]) + } + return f.injectICMPReply(id, response[:n]) } @@ -150,19 +172,23 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T txPackets = 1 } - srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) - dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + srcIp := addrToNetipAddr(id.RemoteAddress) + dstIp := addrToNetipAddr(id.LocalAddress) + + proto := nftypes.ICMP + if srcIp.Is6() { + proto = nftypes.ICMPv6 + } fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, - Protocol: nftypes.ICMP, - // TODO: handle ipv6 - SourceIP: srcIp, - DestIP: dstIp, - ICMPType: icmpType, - ICMPCode: icmpCode, + Protocol: proto, + SourceIP: srcIp, + DestIP: dstIp, + ICMPType: icmpType, + ICMPCode: icmpCode, RxBytes: rxBytes, TxBytes: txBytes, @@ -209,26 +235,164 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } +// handleICMPv6 handles ICMPv6 packets from the network stack. 
+func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + icmpHdr := header.ICMPv6(pkt.TransportHeader().View().AsSlice()) + + flowID := uuid.New() + f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0) + + if icmpHdr.Type() == header.ICMPv6EchoRequest { + return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code())) + } + + // For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting + if !f.hasRawICMPv6Access { + f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id)) + return false + } + + icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() + conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond) + if err != nil { + f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err) + return true + } + if err := conn.Close(); err != nil { + f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err) + } + + return true +} + +// handleICMPv6Echo handles ICMPv6 echo requests via raw socket or ping binary fallback. 
+func (f *Forwarder) handleICMPv6Echo(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool { + select { + case f.pingSemaphore <- struct{}{}: + icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice() + rxBytes := pkt.Size() + + go func() { + defer func() { <-f.pingSemaphore }() + + if f.hasRawICMPv6Access { + f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, true) + } else { + f.handleICMPv6ViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes) + } + }() + default: + f.logger.Debug3("forwarder: ICMPv6 rate limit exceeded for %v type %v code %v", epID(id), icmpType, icmpCode) + } + return true +} + +// handleICMPv6ViaPing uses the system ping6 binary for ICMPv6 echo. +func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) { + ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) + defer cancel() + + dstIP := f.determineDialAddr(id.LocalAddress) + cmd := buildPingCommand(ctx, dstIP, 5*time.Second) + + pingStart := time.Now() + if err := cmd.Run(); err != nil { + f.logger.Warn4("forwarder: Ping6 failed for %v type %v code %v: %v", epID(id), icmpType, icmpCode, err) + return + } + rtt := time.Since(pingStart).Round(10 * time.Microsecond) + + f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v", + epID(id), icmpType, icmpCode) + + txBytes := f.synthesizeICMPv6EchoReply(id, icmpData) + + f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)", + epID(id), icmpType, icmpCode, rtt) + + f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) +} + +// synthesizeICMPv6EchoReply creates an ICMPv6 echo reply and injects it back. 
+func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmpData []byte) int { + replyICMP := make([]byte, len(icmpData)) + copy(replyICMP, icmpData) + + replyHdr := header.ICMPv6(replyICMP) + replyHdr.SetType(header.ICMPv6EchoReply) + replyHdr.SetChecksum(0) + // ICMPv6Checksum computes the pseudo-header internally from Src/Dst. + // Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero. + replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: replyHdr, + Src: id.LocalAddress, + Dst: id.RemoteAddress, + })) + + return f.injectICMPv6Reply(id, replyICMP) +} + +// injectICMPv6Reply wraps an ICMPv6 payload in an IPv6 header and sends to the peer. +func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload []byte) int { + ipHdr := make([]byte, header.IPv6MinimumSize) + ip := header.IPv6(ipHdr) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmpPayload)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 64, + SrcAddr: id.LocalAddress, + DstAddr: id.RemoteAddress, + }) + + fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload)) + fullPacket = append(fullPacket, ipHdr...) + fullPacket = append(fullPacket, icmpPayload...) + + if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil { + f.logger.Error1("forwarder: Failed to send ICMPv6 reply to peer: %v", err) + return 0 + } + + return len(fullPacket) +} + +const ( + pingBin = "ping" + ping6Bin = "ping6" +) + // buildPingCommand creates a platform-specific ping command. -func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd { +// Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6. 
+func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd { timeoutSec := int(timeout.Seconds()) if timeoutSec < 1 { timeoutSec = 1 } + isV6 := target.Is6() + timeoutStr := fmt.Sprintf("%d", timeoutSec) + switch runtime.GOOS { case "linux", "android": - return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String()) + return exec.CommandContext(ctx, pingBin, "-c", "1", "-W", timeoutStr, "-q", target.String()) case "darwin", "ios": - return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String()) + bin := pingBin + if isV6 { + bin = ping6Bin + } + return exec.CommandContext(ctx, bin, "-c", "1", "-t", timeoutStr, "-q", target.String()) case "freebsd": - return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String()) + return exec.CommandContext(ctx, pingBin, "-c", "1", "-t", timeoutStr, target.String()) case "openbsd", "netbsd": - return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String()) + bin := pingBin + if isV6 { + bin = ping6Bin + } + return exec.CommandContext(ctx, bin, "-c", "1", "-w", timeoutStr, target.String()) case "windows": - return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String()) + return exec.CommandContext(ctx, pingBin, "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String()) default: - return exec.CommandContext(ctx, "ping", "-c", "1", target.String()) + return exec.CommandContext(ctx, pingBin, "-c", "1", target.String()) } } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index aef420061..8844463f5 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -2,10 +2,9 @@ package forwarder import ( "context" - "fmt" "io" "net" - "net/netip" + "strconv" "sync" 
"github.com/google/uuid" @@ -33,7 +32,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { } }() - dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort))) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) if err != nil { @@ -133,15 +132,14 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn } func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { - srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) - dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + srcIp := addrToNetipAddr(id.RemoteAddress) + dstIp := addrToNetipAddr(id.LocalAddress) fields := nftypes.EventFields{ - FlowID: flowID, - Type: typ, - Direction: nftypes.Ingress, - Protocol: nftypes.TCP, - // TODO: handle ipv6 + FlowID: flowID, + Type: typ, + Direction: nftypes.Ingress, + Protocol: nftypes.TCP, SourceIP: srcIp, DestIP: dstIp, SourcePort: id.RemotePort, diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index f175e275b..c92fa1f32 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -6,7 +6,7 @@ import ( "fmt" "io" "net" - "net/netip" + "strconv" "sync" "sync/atomic" "time" @@ -158,7 +158,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool { } }() - dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort))) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) if err != nil { f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err) @@ -276,15 +276,14 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack // 
sendUDPEvent stores flow events for UDP connections func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { - srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) - dstIp := netip.AddrFrom4(id.LocalAddress.As4()) + srcIp := addrToNetipAddr(id.RemoteAddress) + dstIp := addrToNetipAddr(id.LocalAddress) fields := nftypes.EventFields{ - FlowID: flowID, - Type: typ, - Direction: nftypes.Ingress, - Protocol: nftypes.UDP, - // TODO: handle ipv6 + FlowID: flowID, + Type: typ, + Direction: nftypes.Ingress, + Protocol: nftypes.UDP, SourceIP: srcIp, DestIP: dstIp, SourcePort: id.RemotePort, diff --git a/client/firewall/uspfilter/hooks_filter.go b/client/firewall/uspfilter/hooks_filter.go index 8d3cc0f5c..f3adf5f8b 100644 --- a/client/firewall/uspfilter/hooks_filter.go +++ b/client/firewall/uspfilter/hooks_filter.go @@ -13,7 +13,6 @@ const ( ipv4HeaderMinLen = 20 ipv4ProtoOffset = 9 ipv4FlagsOffset = 6 - ipv4DstOffset = 16 ipProtoUDP = 17 ipProtoTCP = 6 ipv4FragOffMask = 0x1fff diff --git a/client/firewall/uspfilter/localip.go b/client/firewall/uspfilter/localip.go index f63fe3e45..b35be56c6 100644 --- a/client/firewall/uspfilter/localip.go +++ b/client/firewall/uspfilter/localip.go @@ -4,89 +4,32 @@ import ( "fmt" "net" "net/netip" - "sync" + "sync/atomic" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/firewall/uspfilter/common" ) -type localIPManager struct { - mu sync.RWMutex - - // fixed-size high array for upper byte of a IPv4 address - ipv4Bitmap [256]*ipv4LowBitmap +// localIPSnapshot is an immutable snapshot of local IP addresses, swapped +// atomically so reads are lock-free. 
+type localIPSnapshot struct { + ips map[netip.Addr]struct{} } -// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address -type ipv4LowBitmap struct { - bitmap [8192]uint32 +type localIPManager struct { + snapshot atomic.Pointer[localIPSnapshot] } func newLocalIPManager() *localIPManager { - return &localIPManager{} + m := &localIPManager{} + m.snapshot.Store(&localIPSnapshot{ + ips: make(map[netip.Addr]struct{}), + }) + return m } -func (m *localIPManager) setBitmapBit(ip net.IP) { - ipv4 := ip.To4() - if ipv4 == nil { - return - } - high := uint16(ipv4[0]) - low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) - - index := low / 32 - bit := low % 32 - - if m.ipv4Bitmap[high] == nil { - m.ipv4Bitmap[high] = &ipv4LowBitmap{} - } - - m.ipv4Bitmap[high].bitmap[index] |= 1 << bit -} - -func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) { - if !ip.Is4() { - return - } - ipv4 := ip.AsSlice() - - high := uint16(ipv4[0]) - low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) - - if bitmap[high] == nil { - bitmap[high] = &ipv4LowBitmap{} - } - - index := low / 32 - bit := low % 32 - bitmap[high].bitmap[index] |= 1 << bit - - if _, exists := ipv4Set[ip]; !exists { - ipv4Set[ip] = struct{}{} - *ipv4Addresses = append(*ipv4Addresses, ip) - } -} - -func (m *localIPManager) checkBitmapBit(ip []byte) bool { - high := uint16(ip[0]) - low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3]) - - if m.ipv4Bitmap[high] == nil { - return false - } - - index := low / 32 - bit := low % 32 - return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0 -} - -func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error { - m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses) - return nil -} - -func (m *localIPManager) processInterface(iface net.Interface, 
bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) { +func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresses *[]netip.Addr) { addrs, err := iface.Addrs() if err != nil { log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) @@ -104,18 +47,19 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv continue } - addr, ok := netip.AddrFromSlice(ip) + parsed, ok := netip.AddrFromSlice(ip) if !ok { log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name) continue } - if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil { - log.Debugf("process IP failed: %v", err) - } + parsed = parsed.Unmap() + ips[parsed] = struct{}{} + *addresses = append(*addresses, parsed) } } +// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically. func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { defer func() { if r := recover(); r != nil { @@ -123,20 +67,20 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { } }() - var newIPv4Bitmap [256]*ipv4LowBitmap - ipv4Set := make(map[netip.Addr]struct{}) - var ipv4Addresses []netip.Addr + ips := make(map[netip.Addr]struct{}) + var addresses []netip.Addr - // 127.0.0.0/8 - newIPv4Bitmap[127] = &ipv4LowBitmap{} - for i := 0; i < 8192; i++ { - // #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct - newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF - } + // loopback + ips[netip.AddrFrom4([4]byte{127, 0, 0, 1})] = struct{}{} + ips[netip.IPv6Loopback()] = struct{}{} if iface != nil { - if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil { - return err + ip := iface.Address().IP + ips[ip] = struct{}{} + addresses = append(addresses, ip) + if v6 := iface.Address().IPv6; v6.IsValid() { + ips[v6] = struct{}{} + addresses = append(addresses, v6) } } @@ -147,25 +91,24 @@ func 
(m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { // TODO: filter out down interfaces (net.FlagUp). Also handle the reverse // case where an interface comes up between refreshes. for _, intf := range interfaces { - m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses) + processInterface(intf, ips, &addresses) } } - m.mu.Lock() - m.ipv4Bitmap = newIPv4Bitmap - m.mu.Unlock() + m.snapshot.Store(&localIPSnapshot{ips: ips}) - log.Debugf("Local IPv4 addresses: %v", ipv4Addresses) + log.Debugf("Local IP addresses: %v", addresses) return nil } +// IsLocalIP checks if the given IP is a local address. Lock-free on the read path. func (m *localIPManager) IsLocalIP(ip netip.Addr) bool { - if !ip.Is4() { - return false + s := m.snapshot.Load() + + if ip.Is4() && ip.As4()[0] == 127 { + return true } - m.mu.RLock() - defer m.mu.RUnlock() - - return m.checkBitmapBit(ip.AsSlice()) + _, found := s.ips[ip] + return found } diff --git a/client/firewall/uspfilter/localip_bench_test.go b/client/firewall/uspfilter/localip_bench_test.go new file mode 100644 index 000000000..14e12bd08 --- /dev/null +++ b/client/firewall/uspfilter/localip_bench_test.go @@ -0,0 +1,72 @@ +package uspfilter + +import ( + "net/netip" + "testing" + + "github.com/netbirdio/netbird/client/iface/wgaddr" +) + +func setupManager(b *testing.B) *localIPManager { + b.Helper() + m := newLocalIPManager() + mock := &IFaceMock{ + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + } + }, + } + if err := m.UpdateLocalIPs(mock); err != nil { + b.Fatalf("UpdateLocalIPs: %v", err) + } + return m +} + +func BenchmarkIsLocalIP_v4_hit(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("100.64.0.1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} + +func 
BenchmarkIsLocalIP_v4_miss(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("8.8.8.8") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} + +func BenchmarkIsLocalIP_v6_hit(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("fd00::1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} + +func BenchmarkIsLocalIP_v6_miss(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("2001:db8::1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} + +func BenchmarkIsLocalIP_loopback(b *testing.B) { + m := setupManager(b) + ip := netip.MustParseAddr("127.0.0.1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.IsLocalIP(ip) + } +} diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go index 6653947fa..0dc524c41 100644 --- a/client/firewall/uspfilter/localip_test.go +++ b/client/firewall/uspfilter/localip_test.go @@ -72,14 +72,45 @@ func TestLocalIPManager(t *testing.T) { expected: false, }, { - name: "IPv6 address", + name: "IPv6 address matches", setupAddr: wgaddr.Address{ - IP: netip.MustParseAddr("fe80::1"), + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + }, + testIP: netip.MustParseAddr("fd00::1"), + expected: true, + }, + { + name: "IPv6 address does not match", + setupAddr: wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + IPv6: netip.MustParseAddr("fd00::1"), + IPv6Net: netip.MustParsePrefix("fd00::/64"), + }, + testIP: netip.MustParseAddr("fd00::99"), + expected: false, + }, + { + name: "No aliasing between similar IPs", + setupAddr: wgaddr.Address{ + IP: netip.MustParseAddr("192.168.1.1"), Network: netip.MustParsePrefix("192.168.1.0/24"), }, - testIP: netip.MustParseAddr("fe80::1"), + testIP: netip.MustParseAddr("192.168.0.17"), 
expected: false, }, + { + name: "IPv6 loopback", + setupAddr: wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + }, + testIP: netip.MustParseAddr("::1"), + expected: true, + }, } for _, tt := range tests { @@ -171,90 +202,3 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) { }) } } - -// MapImplementation is a version using map[string]struct{} -type MapImplementation struct { - localIPs map[string]struct{} -} - -func BenchmarkIPChecks(b *testing.B) { - interfaces := make([]net.IP, 16) - for i := range interfaces { - interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i)) - } - - // Setup bitmap - bitmapManager := newLocalIPManager() - for _, ip := range interfaces[:8] { // Add half of IPs - bitmapManager.setBitmapBit(ip) - } - - // Setup map version - mapManager := &MapImplementation{ - localIPs: make(map[string]struct{}), - } - for _, ip := range interfaces[:8] { - mapManager.localIPs[ip.String()] = struct{}{} - } - - b.Run("Bitmap_Hit", func(b *testing.B) { - ip := interfaces[4] - b.ResetTimer() - for i := 0; i < b.N; i++ { - bitmapManager.checkBitmapBit(ip) - } - }) - - b.Run("Bitmap_Miss", func(b *testing.B) { - ip := interfaces[12] - b.ResetTimer() - for i := 0; i < b.N; i++ { - bitmapManager.checkBitmapBit(ip) - } - }) - - b.Run("Map_Hit", func(b *testing.B) { - ip := interfaces[4] - b.ResetTimer() - for i := 0; i < b.N; i++ { - // nolint:gosimple - _ = mapManager.localIPs[ip.String()] - } - }) - - b.Run("Map_Miss", func(b *testing.B) { - ip := interfaces[12] - b.ResetTimer() - for i := 0; i < b.N; i++ { - // nolint:gosimple - _ = mapManager.localIPs[ip.String()] - } - }) -} - -func BenchmarkWGPosition(b *testing.B) { - wgIP := net.ParseIP("10.10.0.1") - - // Create two managers - one checks WG IP first, other checks it last - b.Run("WG_First", func(b *testing.B) { - bm := newLocalIPManager() - bm.setBitmapBit(wgIP) - b.ResetTimer() - for i := 0; i < b.N; i++ { - bm.checkBitmapBit(wgIP) - } - }) - - 
b.Run("WG_Last", func(b *testing.B) { - bm := newLocalIPManager() - // Fill with other IPs first - for i := 0; i < 15; i++ { - bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i))) - } - bm.setBitmapBit(wgIP) // Add WG IP last - b.ResetTimer() - for i := 0; i < b.N; i++ { - bm.checkBitmapBit(wgIP) - } - }) -} diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 8ed32eb5e..0d411c21e 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -13,8 +13,6 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" ) -var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") - var ( errInvalidIPHeaderLength = errors.New("invalid IP header length") ) @@ -25,10 +23,33 @@ const ( destinationPortOffset = 2 // IP address offsets in IPv4 header - sourceIPOffset = 12 - destinationIPOffset = 16 + ipv4SrcOffset = 12 + ipv4DstOffset = 16 + + // IP address offsets in IPv6 header + ipv6SrcOffset = 8 + ipv6DstOffset = 24 + + // IPv6 fixed header length + ipv6HeaderLen = 40 ) +// ipHeaderLen returns the IP header length based on the decoded layer type. +func ipHeaderLen(d *decoder) (int, error) { + switch d.decoded[0] { + case layers.LayerTypeIPv4: + n := int(d.ip4.IHL) * 4 + if n < 20 { + return 0, errInvalidIPHeaderLength + } + return n, nil + case layers.LayerTypeIPv6: + return ipv6HeaderLen, nil + default: + return 0, fmt.Errorf("unknown IP layer: %v", d.decoded[0]) + } +} + // ipv4Checksum calculates IPv4 header checksum. 
func ipv4Checksum(header []byte) uint16 { if len(header) < 20 { @@ -234,14 +255,13 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return false } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) - + _, dstIP := extractPacketIPs(packetData, d) translatedIP, exists := m.getDNATTranslation(dstIP) if !exists { return false } - if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { + if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil { m.logger.Error1("failed to rewrite packet destination: %v", err) return false } @@ -256,14 +276,13 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return false } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) - + srcIP, _ := extractPacketIPs(packetData, d) originalIP, exists := m.findReverseDNATMapping(srcIP) if !exists { return false } - if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { + if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil { m.logger.Error1("failed to rewrite packet source: %v", err) return false } @@ -272,38 +291,96 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return true } -// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. -func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { +// extractPacketIPs extracts src and dst IP addresses directly from raw packet bytes. 
+func extractPacketIPs(packetData []byte, d *decoder) (src, dst netip.Addr) { + switch d.decoded[0] { + case layers.LayerTypeIPv4: + src = netip.AddrFrom4([4]byte{packetData[ipv4SrcOffset], packetData[ipv4SrcOffset+1], packetData[ipv4SrcOffset+2], packetData[ipv4SrcOffset+3]}) + dst = netip.AddrFrom4([4]byte{packetData[ipv4DstOffset], packetData[ipv4DstOffset+1], packetData[ipv4DstOffset+2], packetData[ipv4DstOffset+3]}) + case layers.LayerTypeIPv6: + src = netip.AddrFrom16([16]byte(packetData[ipv6SrcOffset : ipv6SrcOffset+16])) + dst = netip.AddrFrom16([16]byte(packetData[ipv6DstOffset : ipv6DstOffset+16])) + } + return src, dst +} + +// rewritePacketIP replaces a source (isSource=true) or destination IP address in the packet and updates checksums. +func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, isSource bool) error { + hdrLen, err := ipHeaderLen(d) + if err != nil { + return err + } + + switch d.decoded[0] { + case layers.LayerTypeIPv4: + return m.rewriteIPv4(packetData, d, newIP, hdrLen, isSource) + case layers.LayerTypeIPv6: + return m.rewriteIPv6(packetData, d, newIP, hdrLen, isSource) + default: + return fmt.Errorf("unknown IP layer: %v", d.decoded[0]) + } +} + +func (m *Manager) rewriteIPv4(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error { if !newIP.Is4() { - return ErrIPv4Only + return fmt.Errorf("cannot write IPv6 address into IPv4 packet") + } + + offset := ipv4DstOffset + if isSource { + offset = ipv4SrcOffset } var oldIP [4]byte - copy(oldIP[:], packetData[ipOffset:ipOffset+4]) + copy(oldIP[:], packetData[offset:offset+4]) newIPBytes := newIP.As4() + copy(packetData[offset:offset+4], newIPBytes[:]) - copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) - - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return errInvalidIPHeaderLength - } - + // Recalculate IPv4 header checksum binary.BigEndian.PutUint16(packetData[10:12], 0) - ipChecksum 
:= ipv4Checksum(packetData[:ipHeaderLen]) - binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) + binary.BigEndian.PutUint16(packetData[10:12], ipv4Checksum(packetData[:hdrLen])) + // Update transport checksums incrementally if len(d.decoded) > 1 { switch d.decoded[1] { case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) + m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) + m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeICMPv4: - m.updateICMPChecksum(packetData, ipHeaderLen) + m.updateICMPChecksum(packetData, hdrLen) } } + return nil +} +func (m *Manager) rewriteIPv6(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error { + if !newIP.Is6() { + return fmt.Errorf("cannot write IPv4 address into IPv6 packet") + } + + offset := ipv6DstOffset + if isSource { + offset = ipv6SrcOffset + } + + var oldIP [16]byte + copy(oldIP[:], packetData[offset:offset+16]) + newIPBytes := newIP.As16() + copy(packetData[offset:offset+16], newIPBytes[:]) + + // IPv6 has no header checksum, only update transport checksums + if len(d.decoded) > 1 { + switch d.decoded[1] { + case layers.LayerTypeTCP: + m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) + case layers.LayerTypeUDP: + m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) + case layers.LayerTypeICMPv6: + // ICMPv6 checksum includes pseudo-header with addresses, use incremental update + m.updateICMPv6Checksum(packetData, hdrLen, oldIP[:], newIPBytes[:]) + } + } return nil } @@ -351,6 +428,20 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { binary.BigEndian.PutUint16(icmpData[2:4], checksum) } +// updateICMPv6Checksum updates ICMPv6 checksum after address change. +// ICMPv6 uses a pseudo-header (like TCP/UDP), so incremental update applies. 
+func (m *Manager) updateICMPv6Checksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { + icmpStart := ipHeaderLen + if len(packetData) < icmpStart+4 { + return + } + + checksumOffset := icmpStart + 2 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) +} + // incrementalUpdate performs incremental checksum update per RFC 1624. func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) @@ -403,14 +494,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { } // addPortRedirection adds a port redirection rule. -func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { +func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, originalPort, translatedPort uint16) error { m.portDNATMutex.Lock() defer m.portDNATMutex.Unlock() rule := portDNATRule{ protocol: protocol, - origPort: sourcePort, - targetPort: targetPort, + origPort: originalPort, + targetPort: translatedPort, targetIP: targetIP, } @@ -422,7 +513,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. 
// TODO: also delegate to nativeFirewall when available for kernel WG mode -func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { var layerType gopacket.LayerType switch protocol { case firewall.ProtocolTCP: @@ -433,16 +524,16 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco return fmt.Errorf("unsupported protocol: %s", protocol) } - return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort) + return m.addPortRedirection(localAddr, layerType, originalPort, translatedPort) } // removePortRedirection removes a port redirection rule. -func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { +func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, originalPort, translatedPort uint16) error { m.portDNATMutex.Lock() defer m.portDNATMutex.Unlock() m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool { - return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0 + return rule.protocol == protocol && rule.origPort == originalPort && rule.targetPort == translatedPort && rule.targetIP.Compare(targetIP) == 0 }) if len(m.portDNATRules) == 0 { @@ -453,7 +544,7 @@ func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.L } // RemoveInboundDNAT removes an inbound DNAT rule. 
-func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { var layerType gopacket.LayerType switch protocol { case firewall.ProtocolTCP: @@ -464,23 +555,23 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return fmt.Errorf("unsupported protocol: %s", protocol) } - return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) + return m.removePortRedirection(localAddr, layerType, originalPort, translatedPort) } // AddOutputDNAT delegates to the native firewall if available. -func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { if m.nativeFirewall == nil { return fmt.Errorf("output DNAT not supported without native firewall") } - return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) + return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort) } // RemoveOutputDNAT delegates to the native firewall if available. -func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { +func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error { if m.nativeFirewall == nil { return nil } - return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) + return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort) } // translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. 
@@ -532,12 +623,12 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti // rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return errInvalidIPHeaderLength + hdrLen, err := ipHeaderLen(d) + if err != nil { + return err } - tcpStart := ipHeaderLen + tcpStart := hdrLen if len(packetData) < tcpStart+4 { return fmt.Errorf("packet too short for TCP header") } @@ -563,12 +654,12 @@ func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, // rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return errInvalidIPHeaderLength + hdrLen, err := ipHeaderLen(d) + if err != nil { + return err } - udpStart := ipHeaderLen + udpStart := hdrLen if len(packetData) < udpStart+8 { return fmt.Errorf("packet too short for UDP header") } diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index d2599e577..1e15c8c0c 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -342,12 +342,17 @@ func BenchmarkDNATMemoryAllocations(b *testing.B) { // Parse the packet fresh each time to get a clean decoder d := &decoder{decoded: []gopacket.LayerType{}} - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true - err = d.parser.DecodeLayers(testPacket, &d.decoded) + d.parser4.IgnoreUnsupported = true + d.parser6 = 
gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true + err = d.decodePacket(testPacket) assert.NoError(b, err) manager.translateOutboundDNAT(testPacket, d) @@ -371,12 +376,17 @@ func BenchmarkDirectIPExtraction(b *testing.B) { b.Run("decoder_extraction", func(b *testing.B) { // Create decoder once for comparison d := &decoder{decoded: []gopacket.LayerType{}} - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true - err := d.parser.DecodeLayers(packet, &d.decoded) + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true + err := d.decodePacket(packet) assert.NoError(b, err) for i := 0; i < b.N; i++ { diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 50743d006..4598c3901 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -86,13 +86,18 @@ func parsePacket(t testing.TB, packetData []byte) *decoder { d := &decoder{ decoded: []gopacket.LayerType{}, } - d.parser = gopacket.NewDecodingLayerParser( + d.parser4 = gopacket.NewDecodingLayerParser( layers.LayerTypeIPv4, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, ) - d.parser.IgnoreUnsupported = true + d.parser4.IgnoreUnsupported = true + d.parser6 = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv6, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser6.IgnoreUnsupported = true - err := d.parser.DecodeLayers(packetData, &d.decoded) + err := d.decodePacket(packetData) require.NoError(t, err) return d } diff --git a/client/firewall/uspfilter/tracer.go 
b/client/firewall/uspfilter/tracer.go index 69c2519bf..696489e95 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -2,7 +2,9 @@ package uspfilter import ( "fmt" + "net" "net/netip" + "strconv" "time" "github.com/google/gopacket" @@ -112,10 +114,13 @@ func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, } func (p *PacketBuilder) Build() ([]byte, error) { - ip := p.buildIPLayer() - pktLayers := []gopacket.SerializableLayer{ip} + ipLayer, err := p.buildIPLayer() + if err != nil { + return nil, err + } + pktLayers := []gopacket.SerializableLayer{ipLayer} - transportLayer, err := p.buildTransportLayer(ip) + transportLayer, err := p.buildTransportLayer(ipLayer) if err != nil { return nil, err } @@ -129,30 +134,43 @@ func (p *PacketBuilder) Build() ([]byte, error) { return serializePacket(pktLayers) } -func (p *PacketBuilder) buildIPLayer() *layers.IPv4 { +func (p *PacketBuilder) buildIPLayer() (gopacket.SerializableLayer, error) { + if p.SrcIP.Is4() != p.DstIP.Is4() { + return nil, fmt.Errorf("mixed address families: src=%s dst=%s", p.SrcIP, p.DstIP) + } + proto := getIPProtocolNumber(p.Protocol, p.SrcIP.Is6()) + if p.SrcIP.Is6() { + return &layers.IPv6{ + Version: 6, + HopLimit: 64, + NextHeader: proto, + SrcIP: p.SrcIP.AsSlice(), + DstIP: p.DstIP.AsSlice(), + }, nil + } return &layers.IPv4{ Version: 4, TTL: 64, - Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), + Protocol: proto, SrcIP: p.SrcIP.AsSlice(), DstIP: p.DstIP.AsSlice(), - } + }, nil } -func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { switch p.Protocol { case "tcp": - return p.buildTCPLayer(ip) + return p.buildTCPLayer(ipLayer) case "udp": - return p.buildUDPLayer(ip) + return p.buildUDPLayer(ipLayer) case "icmp": - return p.buildICMPLayer() + return 
p.buildICMPLayer(ipLayer) default: return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol) } } -func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { tcp := &layers.TCP{ SrcPort: layers.TCPPort(p.SrcPort), DstPort: layers.TCPPort(p.DstPort), @@ -164,24 +182,44 @@ func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableL PSH: p.TCPState != nil && p.TCPState.PSH, URG: p.TCPState != nil && p.TCPState.URG, } - if err := tcp.SetNetworkLayerForChecksum(ip); err != nil { - return nil, fmt.Errorf("set network layer for TCP checksum: %w", err) + if nl, ok := ipLayer.(gopacket.NetworkLayer); ok { + if err := tcp.SetNetworkLayerForChecksum(nl); err != nil { + return nil, fmt.Errorf("set network layer for TCP checksum: %w", err) + } } return []gopacket.SerializableLayer{tcp}, nil } -func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { udp := &layers.UDP{ SrcPort: layers.UDPPort(p.SrcPort), DstPort: layers.UDPPort(p.DstPort), } - if err := udp.SetNetworkLayerForChecksum(ip); err != nil { - return nil, fmt.Errorf("set network layer for UDP checksum: %w", err) + if nl, ok := ipLayer.(gopacket.NetworkLayer); ok { + if err := udp.SetNetworkLayerForChecksum(nl); err != nil { + return nil, fmt.Errorf("set network layer for UDP checksum: %w", err) + } } return []gopacket.SerializableLayer{udp}, nil } -func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) { +func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) { + if p.SrcIP.Is6() || p.DstIP.Is6() { + icmp := &layers.ICMPv6{ + TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode), + } + if nl, ok := 
ipLayer.(gopacket.NetworkLayer); ok { + _ = icmp.SetNetworkLayerForChecksum(nl) + } + if p.ICMPType == layers.ICMPv6TypeEchoRequest || p.ICMPType == layers.ICMPv6TypeEchoReply { + echo := &layers.ICMPv6Echo{ + Identifier: 1, + SeqNumber: 1, + } + return []gopacket.SerializableLayer{icmp, echo}, nil + } + return []gopacket.SerializableLayer{icmp}, nil + } icmp := &layers.ICMPv4{ TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode), } @@ -204,14 +242,17 @@ func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) { return buf.Bytes(), nil } -func getIPProtocolNumber(protocol fw.Protocol) int { +func getIPProtocolNumber(protocol fw.Protocol, isV6 bool) layers.IPProtocol { switch protocol { case fw.ProtocolTCP: - return int(layers.IPProtocolTCP) + return layers.IPProtocolTCP case fw.ProtocolUDP: - return int(layers.IPProtocolUDP) + return layers.IPProtocolUDP case fw.ProtocolICMP: - return int(layers.IPProtocolICMPv4) + if isV6 { + return layers.IPProtocolICMPv6 + } + return layers.IPProtocolICMPv4 default: return 0 } @@ -234,7 +275,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa trace := &PacketTrace{Direction: direction} // Initial packet decoding - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false) return trace } @@ -256,6 +297,8 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa trace.DestinationPort = uint16(d.udp.DstPort) case layers.LayerTypeICMPv4: trace.Protocol = "ICMP" + case layers.LayerTypeICMPv6: + trace.Protocol = "ICMPv6" } trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d", @@ -319,6 +362,13 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string { flags&conntrack.TCPFin != 0) case layers.LayerTypeICMPv4: msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, 
d.icmp4.Seq) + case layers.LayerTypeICMPv6: + var id, seq uint16 + if len(d.icmp6.Payload) >= 4 { + id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1]) + seq = uint16(d.icmp6.Payload[2])<<8 | uint16(d.icmp6.Payload[3]) + } + msg += fmt.Sprintf(" (ICMPv6 ID=%d, Seq=%d)", id, seq) } return msg } @@ -395,7 +445,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n trace.AddResult(StageRouteACL, msg, allowed) if allowed && m.forwarder.Load() != nil { - m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true) + m.addForwardingResult(trace, "proxy-remote", net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))), true) } trace.AddResult(StageCompleted, msgProcessingCompleted, allowed) @@ -415,7 +465,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { trace.AddResult(StageCompleted, "Packet dropped - decode error", false) return trace } @@ -434,7 +484,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) if portDNATApplied { - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) return true } @@ -444,7 +494,7 @@ func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *de nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) if nat1to1Applied { - if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + if err := d.decodePacket(packetData); err != nil { 
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) return true } @@ -509,7 +559,7 @@ func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d * return false } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + srcIP, _ := extractPacketIPs(packetData, d) translated := m.translateInboundReverse(packetData, d) if translated { @@ -539,7 +589,7 @@ func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d return false } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + _, dstIP := extractPacketIPs(packetData, d) translated := m.translateOutboundDNAT(packetData, d) if translated { diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index e3a96590c..9b070aab8 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, if err != nil { return fmt.Errorf("failed to parse endpoint address: %w", err) } - addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port)) + addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port)) c.activityRecorder.UpsertAddress(peerKey, addrPort) } return nil diff --git a/client/iface/device/adapter.go b/client/iface/device/adapter.go index 6ebc05390..e3caaf930 100644 --- a/client/iface/device/adapter.go +++ b/client/iface/device/adapter.go @@ -2,7 +2,7 @@ package device // TunAdapter is an interface for create tun device from external service type TunAdapter interface { - ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error) + ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error) UpdateAddr(address string) error ProtectSocket(fd int32) bool } diff --git a/client/iface/device/device_android.go 
b/client/iface/device/device_android.go index 198343fbd..cbe88c10c 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string searchDomainsToString = "" } - fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString) + fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString) if err != nil { log.Errorf("failed to create Android interface: %s", err) return nil, err diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index acd5f6f11..ac8f8a51b 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -131,23 +131,32 @@ func (t *TunDevice) Device() *device.Device { // assignAddr Adds IP address to the tunnel interface and network route based on the range provided func (t *TunDevice) assignAddr() error { - cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) - if out, err := cmd.CombinedOutput(); err != nil { - log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out) - return err + if out, err := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()).CombinedOutput(); err != nil { + return fmt.Errorf("add v4 address: %s: %w", string(out), err) } - // dummy ipv6 so routing works - cmd = exec.Command("ifconfig", t.name, "inet6", "fe80::/64") - if out, err := cmd.CombinedOutput(); err != nil { - log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out) + // Assign a dummy link-local so macOS enables IPv6 on the tun device. + // When a real overlay v6 is present, use that instead. 
+ v6Addr := "fe80::/64" + if t.address.HasIPv6() { + v6Addr = t.address.IPv6String() + } + if out, err := exec.Command("ifconfig", t.name, "inet6", v6Addr).CombinedOutput(); err != nil { + log.Warnf("failed to assign IPv6 address %s, continuing v4-only: %s: %v", v6Addr, string(out), err) + t.address.ClearIPv6() } - routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name) - if out, err := routeCmd.CombinedOutput(); err != nil { - log.Errorf("adding route command '%v' failed with output: %s", routeCmd.String(), out) - return err + if out, err := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name).CombinedOutput(); err != nil { + return fmt.Errorf("add route %s via %s: %s: %w", t.address.Network, t.name, string(out), err) } + + if t.address.HasIPv6() { + if out, err := exec.Command("route", "add", "-inet6", "-net", t.address.IPv6Net.String(), "-interface", t.name).CombinedOutput(); err != nil { + log.Warnf("failed to add route %s via %s, continuing v4-only: %s: %v", t.address.IPv6Net, t.name, string(out), err) + t.address.ClearIPv6() + } + } + return nil } diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index aa77cee45..8368c8dce 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -151,8 +151,11 @@ func (t *TunDevice) MTU() uint16 { return t.mtu } -func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error { - // todo implement +// UpdateAddr updates the device address. On iOS the tunnel is managed by the +// NetworkExtension, so we only store the new value. The extension picks up the +// change on the next tunnel reconfiguration. 
+func (t *TunDevice) UpdateAddr(addr wgaddr.Address) error { + t.address = addr return nil } diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 2a836f846..25c4148a6 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -173,7 +173,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { // assignAddr Adds IP address to the tunnel interface func (t *TunKernelDevice) assignAddr() error { - return t.link.assignAddr(t.address) + return t.link.assignAddr(&t.address) } func (t *TunKernelDevice) GetNet() *netstack.Net { diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index 1a92b148f..b3bce3925 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -3,6 +3,7 @@ package device import ( "errors" "fmt" + "net/netip" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/conn" @@ -63,8 +64,12 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { return nil, fmt.Errorf("last ip: %w", err) } - log.Debugf("netstack using address: %s", t.address.IP) - t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu)) + addresses := []netip.Addr{t.address.IP} + if t.address.HasIPv6() { + addresses = append(addresses, t.address.IPv6) + } + log.Debugf("netstack using addresses: %v", addresses) + t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, addresses, dnsAddr, int(t.mtu)) log.Debugf("netstack using dns address: %s", dnsAddr) tunIface, net, err := t.nsTun.Create() if err != nil { diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 24654fc03..04c265c49 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -16,7 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/wgaddr" ) -type USPDevice struct { +type TunDevice struct { name string address 
wgaddr.Address port int @@ -30,10 +30,10 @@ type USPDevice struct { configurer WGConfigurer } -func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice { +func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice { log.Infof("using userspace bind mode") - return &USPDevice{ + return &TunDevice{ name: name, address: address, port: port, @@ -43,7 +43,7 @@ func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu } } -func (t *USPDevice) Create() (WGConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { log.Info("create tun interface") tunIface, err := tun.CreateTUN(t.name, int(t.mtu)) if err != nil { @@ -75,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -95,12 +95,12 @@ func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *USPDevice) UpdateAddr(address wgaddr.Address) error { +func (t *TunDevice) UpdateAddr(address wgaddr.Address) error { t.address = address return t.assignAddr() } -func (t *USPDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { t.configurer.Close() } @@ -115,39 +115,39 @@ func (t *USPDevice) Close() error { return nil } -func (t *USPDevice) WgAddress() wgaddr.Address { +func (t *TunDevice) WgAddress() wgaddr.Address { return t.address } -func (t *USPDevice) MTU() uint16 { +func (t *TunDevice) MTU() uint16 { return t.mtu } -func (t *USPDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *USPDevice) FilteredDevice() *FilteredDevice { +func (t *TunDevice) FilteredDevice() *FilteredDevice { return 
t.filteredDevice } // Device returns the wireguard device -func (t *USPDevice) Device() *device.Device { +func (t *TunDevice) Device() *device.Device { return t.device } // assignAddr Adds IP address to the tunnel interface -func (t *USPDevice) assignAddr() error { +func (t *TunDevice) assignAddr() error { link := newWGLink(t.name) - return link.assignAddr(t.address) + return link.assignAddr(&t.address) } -func (t *USPDevice) GetNet() *netstack.Net { +func (t *TunDevice) GetNet() *netstack.Net { return nil } // GetICEBind returns the ICEBind instance -func (t *USPDevice) GetICEBind() EndpointManager { +func (t *TunDevice) GetICEBind() EndpointManager { return t.iceBind } diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index 96350df8a..f52392fa2 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -87,7 +87,21 @@ func (t *TunDevice) Create() (WGConfigurer, error) { err = nbiface.Set() if err != nil { t.device.Close() - return nil, fmt.Errorf("got error when getting setting the interface mtu: %s", err) + return nil, fmt.Errorf("set IPv4 interface MTU: %s", err) + } + + if t.address.HasIPv6() { + nbiface6, err := luid.IPInterface(windows.AF_INET6) + if err != nil { + log.Warnf("failed to get IPv6 interface for MTU, continuing v4-only: %v", err) + t.address.ClearIPv6() + } else { + nbiface6.NLMTU = uint32(t.mtu) + if err := nbiface6.Set(); err != nil { + log.Warnf("failed to set IPv6 interface MTU, continuing v4-only: %v", err) + t.address.ClearIPv6() + } + } } err = t.assignAddr() if err != nil { @@ -178,8 +192,21 @@ func (t *TunDevice) GetInterfaceGUIDString() (string, error) { // assignAddr Adds IP address to the tunnel interface and network route based on the range provided func (t *TunDevice) assignAddr() error { luid := winipcfg.LUID(t.nativeTunDevice.LUID()) - log.Debugf("adding address %s to interface: %s", t.address.IP, t.name) - return 
luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())}) + + v4Prefix := t.address.Prefix() + if t.address.HasIPv6() { + v6Prefix := t.address.IPv6Prefix() + log.Debugf("adding addresses %s, %s to interface: %s", v4Prefix, v6Prefix, t.name) + if err := luid.SetIPAddresses([]netip.Prefix{v4Prefix, v6Prefix}); err != nil { + log.Warnf("failed to assign dual-stack addresses, retrying v4-only: %v", err) + t.address.ClearIPv6() + return luid.SetIPAddresses([]netip.Prefix{v4Prefix}) + } + return nil + } + + log.Debugf("adding address %s to interface: %s", v4Prefix, t.name) + return luid.SetIPAddresses([]netip.Prefix{v4Prefix}) } func (t *TunDevice) GetNet() *netstack.Net { diff --git a/client/iface/device/kernel_module.go b/client/iface/device/kernel_module.go deleted file mode 100644 index 1bdd6f7c6..000000000 --- a/client/iface/device/kernel_module.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build (!linux && !freebsd) || android - -package device - -// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only) -func WireGuardModuleIsLoaded() bool { - return false -} diff --git a/client/iface/device/kernel_module_freebsd.go b/client/iface/device/kernel_module_freebsd.go deleted file mode 100644 index dd6c8b408..000000000 --- a/client/iface/device/kernel_module_freebsd.go +++ /dev/null @@ -1,18 +0,0 @@ -package device - -// WireGuardModuleIsLoaded check if kernel support wireguard -func WireGuardModuleIsLoaded() bool { - // Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd) - // we are currently do not use it, since it is required to add wireguard kernel support to - // - https://github.com/netbirdio/netbird/tree/main/sharedsock - // - https://github.com/mdlayher/socket - // TODO: implement kernel space - return false -} - -// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it -func ModuleTunIsLoaded() bool { - // Assume tun supported by freebsd kernel by default - // 
TODO: implement check for module loaded in kernel or build-it - return true -} diff --git a/client/iface/device/kernel_module_nonlinux.go b/client/iface/device/kernel_module_nonlinux.go new file mode 100644 index 000000000..58d97080b --- /dev/null +++ b/client/iface/device/kernel_module_nonlinux.go @@ -0,0 +1,13 @@ +//go:build !linux || android + +package device + +// WireGuardModuleIsLoaded reports whether the kernel WireGuard module is available. +func WireGuardModuleIsLoaded() bool { + return false +} + +// ModuleTunIsLoaded reports whether the tun device is available. +func ModuleTunIsLoaded() bool { + return true +} diff --git a/client/iface/device/wg_link_freebsd.go b/client/iface/device/wg_link_freebsd.go index 1b06e0e15..87df89183 100644 --- a/client/iface/device/wg_link_freebsd.go +++ b/client/iface/device/wg_link_freebsd.go @@ -2,6 +2,7 @@ package device import ( "fmt" + "os/exec" log "github.com/sirupsen/logrus" @@ -57,32 +58,32 @@ func (l *wgLink) up() error { return nil } -func (l *wgLink) assignAddr(address wgaddr.Address) error { +func (l *wgLink) assignAddr(address *wgaddr.Address) error { link, err := freebsd.LinkByName(l.name) if err != nil { return fmt.Errorf("link by name: %w", err) } - ip := address.IP.String() - - // Convert prefix length to hex netmask prefixLen := address.Network.Bits() - if !address.IP.Is4() { - return fmt.Errorf("IPv6 not supported for interface assignment") - } - maskBits := uint32(0xffffffff) << (32 - prefixLen) mask := fmt.Sprintf("0x%08x", maskBits) - log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name) + log.Infof("assign addr %s mask %s to %s interface", address.IP, mask, l.name) - err = link.AssignAddr(ip, mask) - if err != nil { + if err := link.AssignAddr(address.IP.String(), mask); err != nil { return fmt.Errorf("assign addr: %w", err) } - err = link.Up() - if err != nil { + if address.HasIPv6() { + log.Infof("assign IPv6 addr %s to %s interface", address.IPv6String(), l.name) + cmd := 
exec.Command("ifconfig", l.name, "inet6", address.IPv6String()) + if out, err := cmd.CombinedOutput(); err != nil { + log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %s: %v", address.IPv6String(), l.name, string(out), err) + address.ClearIPv6() + } + } + + if err := link.Up(); err != nil { return fmt.Errorf("up: %w", err) } diff --git a/client/iface/device/wg_link_linux.go b/client/iface/device/wg_link_linux.go index d941cd022..6a02cb356 100644 --- a/client/iface/device/wg_link_linux.go +++ b/client/iface/device/wg_link_linux.go @@ -4,6 +4,8 @@ package device import ( "fmt" + "net" + "net/netip" "os" log "github.com/sirupsen/logrus" @@ -92,7 +94,7 @@ func (l *wgLink) up() error { return nil } -func (l *wgLink) assignAddr(address wgaddr.Address) error { +func (l *wgLink) assignAddr(address *wgaddr.Address) error { //delete existing addresses list, err := netlink.AddrList(l, 0) if err != nil { @@ -110,20 +112,16 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error { } name := l.attrs.Name - addrStr := address.String() - log.Debugf("adding address %s to interface: %s", addrStr, name) - - addr, err := netlink.ParseAddr(addrStr) - if err != nil { - return fmt.Errorf("parse addr: %w", err) + if err := l.addAddr(name, address.Prefix()); err != nil { + return err } - err = netlink.AddrAdd(l, addr) - if os.IsExist(err) { - log.Infof("interface %s already has the address: %s", name, addrStr) - } else if err != nil { - return fmt.Errorf("add addr: %w", err) + if address.HasIPv6() { + if err := l.addAddr(name, address.IPv6Prefix()); err != nil { + log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %v", address.IPv6Prefix(), name, err) + address.ClearIPv6() + } } // On linux, the link must be brought up @@ -133,3 +131,22 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error { return nil } + +func (l *wgLink) addAddr(ifaceName string, prefix netip.Prefix) error { + log.Debugf("adding address %s to interface: %s", prefix, 
ifaceName) + + addr := &netlink.Addr{ + IPNet: &net.IPNet{ + IP: prefix.Addr().AsSlice(), + Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), + }, + } + + if err := netlink.AddrAdd(l, addr); os.IsExist(err) { + log.Infof("interface %s already has the address: %s", ifaceName, prefix) + } else if err != nil { + return fmt.Errorf("add addr %s: %w", prefix, err) + } + + return nil +} diff --git a/client/iface/iface.go b/client/iface/iface.go index 655dd1682..78c5080e7 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -57,7 +57,7 @@ type wgProxyFactory interface { type WGIFaceOpts struct { IFaceName string - Address string + Address wgaddr.Address WGPort int WGPrivKey string MTU uint16 @@ -141,16 +141,11 @@ func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { } // UpdateAddr updates address of the interface -func (w *WGIface) UpdateAddr(newAddr string) error { +func (w *WGIface) UpdateAddr(newAddr wgaddr.Address) error { w.mu.Lock() defer w.mu.Unlock() - addr, err := wgaddr.ParseWGAddress(newAddr) - if err != nil { - return err - } - - return w.tun.UpdateAddr(addr) + return w.tun.UpdateAddr(newAddr) } // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist diff --git a/client/iface/iface_new_windows.go b/client/iface/iface_new.go similarity index 50% rename from client/iface/iface_new_windows.go rename to client/iface/iface_new.go index dfd9028e7..28f350e3f 100644 --- a/client/iface/iface_new_windows.go +++ b/client/iface/iface_new.go @@ -1,33 +1,28 @@ +//go:build !linux && !ios && !android && !js + package iface import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" - wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, 
err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU) var tun WGTunDevice if netstack.IsEnabled() { - tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + tun = device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) } else { - tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) + tun = device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) } - wgIFace := &WGIface{ + return &WGIface{ userspaceBind: true, tun: tun, wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), - } - return wgIFace, nil - + }, nil } diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go index 3b68f63f2..e28dcc0de 100644 --- a/client/iface/iface_new_android.go +++ b/client/iface/iface_new_android.go @@ -4,23 +4,17 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU) if netstack.IsEnabled() { wgIFace := &WGIface{ userspaceBind: true, - tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, 
netstack.ListenAddr()), + tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil @@ -28,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, - tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), + tun: device.NewTunDevice(opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil diff --git a/client/iface/iface_new_darwin.go b/client/iface/iface_new_darwin.go deleted file mode 100644 index 9f21ec950..000000000 --- a/client/iface/iface_new_darwin.go +++ /dev/null @@ -1,35 +0,0 @@ -//go:build !ios - -package iface - -import ( - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/wgaddr" - "github.com/netbirdio/netbird/client/iface/wgproxy" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - - var tun WGTunDevice - if netstack.IsEnabled() { - tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) - } else { - tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) - } - - wgIFace := &WGIface{ - userspaceBind: true, - tun: tun, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), - } - return wgIFace, nil -} diff --git 
a/client/iface/iface_new_freebsd.go b/client/iface/iface_new_freebsd.go deleted file mode 100644 index a342bd579..000000000 --- a/client/iface/iface_new_freebsd.go +++ /dev/null @@ -1,41 +0,0 @@ -//go:build freebsd - -package iface - -import ( - "fmt" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/wgaddr" - "github.com/netbirdio/netbird/client/iface/wgproxy" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{} - - if netstack.IsEnabled() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) - wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) - return wgIFace, nil - } - - if device.ModuleTunIsLoaded() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) - wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) - return wgIFace, nil - } - - return nil, fmt.Errorf("couldn't check or load tun module") -} diff --git a/client/iface/iface_new_ios.go b/client/iface/iface_new_ios.go index 5d6a32e39..41e0022b2 100644 --- a/client/iface/iface_new_ios.go +++ b/client/iface/iface_new_ios.go @@ -5,21 +5,15 @@ package iface import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // 
NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU) wgIFace := &WGIface{ - tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), + tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), userspaceBind: true, wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } diff --git a/client/iface/iface_new_js.go b/client/iface/iface_new_js.go index ad913ab04..9f7a3ba62 100644 --- a/client/iface/iface_new_js.go +++ b/client/iface/iface_new_js.go @@ -4,21 +4,15 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode) func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - relayBind := bind.NewRelayBindJS() wgIface := &WGIface{ - tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()), + tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()), userspaceBind: true, wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU), } diff --git a/client/iface/iface_new_linux.go b/client/iface/iface_new_linux.go index d84035403..65ce67e88 100644 --- a/client/iface/iface_new_linux.go +++ 
b/client/iface/iface_new_linux.go @@ -3,44 +3,40 @@ package iface import ( - "fmt" + "errors" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{} - if netstack.IsEnabled() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) - wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) - return wgIFace, nil + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU) + return &WGIface{ + tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), + userspaceBind: true, + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), + }, nil } if device.WireGuardModuleIsLoaded() { - wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet) - wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU) - return wgIFace, nil - } - if device.ModuleTunIsLoaded() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) - wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) - return wgIFace, nil + return &WGIface{ + tun: device.NewKernelDevice(opts.IFaceName, opts.Address, opts.WGPort, 
opts.WGPrivKey, opts.MTU, opts.TransportNet), + wgProxyFactory: wgproxy.NewKernelFactory(opts.WGPort, opts.MTU), + }, nil } - return nil, fmt.Errorf("couldn't check or load tun module") + if device.ModuleTunIsLoaded() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU) + return &WGIface{ + tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind), + userspaceBind: true, + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), + }, nil + } + + return nil, errors.New("tun module not available") } diff --git a/client/iface/iface_test.go b/client/iface/iface_test.go index 6bbfeaa63..dbeb69bc6 100644 --- a/client/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -16,6 +16,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/stdnet" ) @@ -48,7 +49,7 @@ func TestWGIface_UpdateAddr(t *testing.T) { opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: addr, + Address: wgaddr.MustParseWGAddress(addr), WGPort: wgPort, WGPrivKey: key, MTU: DefaultMTU, @@ -84,7 +85,7 @@ func TestWGIface_UpdateAddr(t *testing.T) { //update WireGuard address addr = "100.64.0.2/8" - err = iface.UpdateAddr(addr) + err = iface.UpdateAddr(wgaddr.MustParseWGAddress(addr)) if err != nil { t.Fatal(err) } @@ -130,7 +131,7 @@ func Test_CreateInterface(t *testing.T) { } opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: wgIP, + Address: wgaddr.MustParseWGAddress(wgIP), WGPort: 33100, WGPrivKey: key, MTU: DefaultMTU, @@ -174,7 +175,7 @@ func Test_Close(t *testing.T) { opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: wgIP, + Address: wgaddr.MustParseWGAddress(wgIP), WGPort: wgPort, WGPrivKey: key, MTU: DefaultMTU, @@ -219,7 +220,7 @@ func TestRecreation(t *testing.T) { opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: wgIP, + Address: 
wgaddr.MustParseWGAddress(wgIP), WGPort: wgPort, WGPrivKey: key, MTU: DefaultMTU, @@ -291,7 +292,7 @@ func Test_ConfigureInterface(t *testing.T) { } opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: wgIP, + Address: wgaddr.MustParseWGAddress(wgIP), WGPort: wgPort, WGPrivKey: key, MTU: DefaultMTU, @@ -347,7 +348,7 @@ func Test_UpdatePeer(t *testing.T) { opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: wgIP, + Address: wgaddr.MustParseWGAddress(wgIP), WGPort: 33100, WGPrivKey: key, MTU: DefaultMTU, @@ -417,7 +418,7 @@ func Test_RemovePeer(t *testing.T) { opts := WGIFaceOpts{ IFaceName: ifaceName, - Address: wgIP, + Address: wgaddr.MustParseWGAddress(wgIP), WGPort: 33100, WGPrivKey: key, MTU: DefaultMTU, @@ -482,7 +483,7 @@ func Test_ConnectPeers(t *testing.T) { optsPeer1 := WGIFaceOpts{ IFaceName: peer1ifaceName, - Address: peer1wgIP.String(), + Address: wgaddr.MustParseWGAddress(peer1wgIP.String()), WGPort: peer1wgPort, WGPrivKey: peer1Key.String(), MTU: DefaultMTU, @@ -522,7 +523,7 @@ func Test_ConnectPeers(t *testing.T) { optsPeer2 := WGIFaceOpts{ IFaceName: peer2ifaceName, - Address: peer2wgIP.String(), + Address: wgaddr.MustParseWGAddress(peer2wgIP.String()), WGPort: peer2wgPort, WGPrivKey: peer2Key.String(), MTU: DefaultMTU, diff --git a/client/iface/netstack/tun.go b/client/iface/netstack/tun.go index 346ae29ec..8c7526bbb 100644 --- a/client/iface/netstack/tun.go +++ b/client/iface/netstack/tun.go @@ -13,7 +13,7 @@ import ( const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY" type NetStackTun struct { //nolint:revive - address netip.Addr + addresses []netip.Addr dnsAddress netip.Addr mtu int listenAddress string @@ -22,9 +22,9 @@ type NetStackTun struct { //nolint:revive tundev tun.Device } -func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun { +func NewNetStackTun(listenAddress string, addresses []netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun { return &NetStackTun{ - address: address, + 
addresses: addresses, dnsAddress: dnsAddress, mtu: mtu, listenAddress: listenAddress, @@ -33,7 +33,7 @@ func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.A func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { nsTunDev, tunNet, err := netstack.CreateNetTUN( - []netip.Addr{t.address}, + t.addresses, []netip.Addr{t.dnsAddress}, t.mtu) if err != nil { diff --git a/client/iface/wgaddr/address.go b/client/iface/wgaddr/address.go index 078f8be95..43d1ec9aa 100644 --- a/client/iface/wgaddr/address.go +++ b/client/iface/wgaddr/address.go @@ -3,12 +3,18 @@ package wgaddr import ( "fmt" "net/netip" + + "github.com/netbirdio/netbird/shared/netiputil" ) // Address WireGuard parsed address type Address struct { IP netip.Addr Network netip.Prefix + + // IPv6 overlay address, if assigned. + IPv6 netip.Addr + IPv6Net netip.Prefix } // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address @@ -23,6 +29,60 @@ func ParseWGAddress(address string) (Address, error) { }, nil } -func (addr Address) String() string { - return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits()) +// HasIPv6 reports whether a v6 overlay address is assigned. +func (addr Address) HasIPv6() bool { + return addr.IPv6.IsValid() +} + +func (addr Address) String() string { + return addr.Prefix().String() +} + +// IPv6String returns the v6 address in CIDR notation, or empty string if none. +func (addr Address) IPv6String() string { + if !addr.HasIPv6() { + return "" + } + return addr.IPv6Prefix().String() +} + +// Prefix returns the v4 host address with its network prefix length (e.g. 100.64.0.1/16). +func (addr Address) Prefix() netip.Prefix { + return netip.PrefixFrom(addr.IP, addr.Network.Bits()) +} + +// IPv6Prefix returns the v6 host address with its network prefix length, or a zero prefix if none. 
+func (addr Address) IPv6Prefix() netip.Prefix { + if !addr.HasIPv6() { + return netip.Prefix{} + } + return netip.PrefixFrom(addr.IPv6, addr.IPv6Net.Bits()) +} + +// SetIPv6FromCompact decodes a compact prefix (5 or 17 bytes) and sets the IPv6 fields. +// Returns an error if the bytes are invalid. A nil or empty input is a no-op. +// +//nolint:recvcheck +func (addr *Address) SetIPv6FromCompact(raw []byte) error { + if len(raw) == 0 { + return nil + } + prefix, err := netiputil.DecodePrefix(raw) + if err != nil { + return fmt.Errorf("decode v6 overlay address: %w", err) + } + if !prefix.Addr().Is6() { + return fmt.Errorf("expected IPv6 address, got %s", prefix.Addr()) + } + addr.IPv6 = prefix.Addr() + addr.IPv6Net = prefix.Masked() + return nil +} + +// ClearIPv6 removes the IPv6 overlay address, leaving only v4. +// +//nolint:recvcheck +func (addr *Address) ClearIPv6() { + addr.IPv6 = netip.Addr{} + addr.IPv6Net = netip.Prefix{} } diff --git a/client/iface/wgaddr/address_test_helpers.go b/client/iface/wgaddr/address_test_helpers.go new file mode 100644 index 000000000..87403e789 --- /dev/null +++ b/client/iface/wgaddr/address_test_helpers.go @@ -0,0 +1,10 @@ +package wgaddr + +// MustParseWGAddress parses and returns a WG Address, panicking on error. +func MustParseWGAddress(address string) Address { + a, err := ParseWGAddress(address) + if err != nil { + panic(err) + } + return a +} diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index 9ac3ea6df..be6f3806e 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -6,7 +6,7 @@ import ( "fmt" "net" "net/netip" - "strings" + "sync" log "github.com/sirupsen/logrus" @@ -196,18 +196,25 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { } } -// fakeAddress returns a fake address that is used to as an identifier for the peer. -// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address. 
+// fakeAddress returns a fake address that is used as an identifier for the peer. +// The fake address is in the format of 127.1.x.x where x.x is derived from the +// last two bytes of the peer address (works for both IPv4 and IPv6). func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) { - octets := strings.Split(peerAddress.IP.String(), ".") - if len(octets) != 4 { - return nil, fmt.Errorf("invalid IP format") + if peerAddress == nil { + return nil, fmt.Errorf("nil peer address") + } + if peerAddress.Port < 0 || peerAddress.Port > 65535 { + return nil, fmt.Errorf("invalid UDP port: %d", peerAddress.Port) } - fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])) - if err != nil { - return nil, fmt.Errorf("parse new IP: %w", err) + addr, ok := netip.AddrFromSlice(peerAddress.IP) + if !ok { + return nil, fmt.Errorf("invalid IP format") } + addr = addr.Unmap() + + raw := addr.As16() + fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]}) netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) return &netipAddr, nil diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index dd6f9479a..c54a3e897 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -5,7 +5,6 @@ import ( "encoding/hex" "errors" "fmt" - "net" "net/netip" "strconv" "sync" @@ -19,6 +18,7 @@ import ( "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" ) var ErrSourceRangesEmpty = errors.New("sources range is empty") @@ -105,6 +105,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { newRulePairs := make(map[id.RuleID][]firewall.Rule) ipsetByRuleSelectors := make(map[string]string) + // TODO: deny rules should be fatal: if a deny rule fails to apply, we must + // roll back all allow rules to avoid a 
fail-open where allowed traffic bypasses + // the missing deny. Currently we accumulate errors and continue. + var merr *multierror.Error for _, r := range rules { // if this rule is member of rule selection with more than DefaultIPsCountForSet // it's IP address can be used in the ipset for firewall manager which supports it @@ -117,9 +121,8 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { } pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName) if err != nil { - log.Errorf("failed to apply firewall rule: %+v, %v", r, err) - d.rollBack(newRulePairs) - break + merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err)) + continue } if len(rulePair) > 0 { d.peerRulesPairs[pairID] = rulePair @@ -127,6 +130,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { } } + if merr != nil { + log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr)) + } + for pairID, rules := range d.peerRulesPairs { if _, ok := newRulePairs[pairID]; !ok { for _, rule := range rules { @@ -216,9 +223,9 @@ func (d *DefaultManager) protoRuleToFirewallRule( r *mgmProto.FirewallRule, ipsetName string, ) (id.RuleID, []firewall.Rule, error) { - ip := net.ParseIP(r.PeerIP) - if ip == nil { - return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") + ip, err := extractRuleIP(r) + if err != nil { + return "", nil, err } protocol, err := convertToFirewallProtocol(r.Protocol) @@ -289,13 +296,13 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool { func (d *DefaultManager) addInRules( id []byte, - ip net.IP, + ip netip.Addr, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, ipsetName string, ) ([]firewall.Rule, error) { - rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName) + rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName) if err != nil { return nil, 
fmt.Errorf("add firewall rule: %w", err) } @@ -305,7 +312,7 @@ func (d *DefaultManager) addInRules( func (d *DefaultManager) addOutRules( id []byte, - ip net.IP, + ip netip.Addr, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, @@ -315,7 +322,7 @@ func (d *DefaultManager) addOutRules( return nil, nil } - rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName) + rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName) if err != nil { return nil, fmt.Errorf("add firewall rule: %w", err) } @@ -323,9 +330,9 @@ func (d *DefaultManager) addOutRules( return rule, nil } -// getPeerRuleID() returns unique ID for the rule based on its parameters. +// getPeerRuleID returns unique ID for the rule based on its parameters. func (d *DefaultManager) getPeerRuleID( - ip net.IP, + ip netip.Addr, proto firewall.Protocol, direction int, port *firewall.Port, @@ -344,15 +351,25 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) } -func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { - log.Debugf("rollback ACL to previous state") - for _, rules := range newRulePairs { - for _, rule := range rules { - if err := d.firewall.DeletePeerRule(rule); err != nil { - log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err) - } + +// extractRuleIP extracts the peer IP from a firewall rule. +// If sourcePrefixes is populated (new management), decode the first entry and use its address. +// Otherwise fall back to the deprecated PeerIP string field (old management). 
+func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) { + if len(r.SourcePrefixes) > 0 { + addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0]) + if err != nil { + return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err) } + return addr.Unmap(), nil } + + //nolint:staticcheck // PeerIP used for backward compatibility with old management + addr, err := netip.ParseAddr(r.PeerIP) + if err != nil { + return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule") + } + return addr.Unmap(), nil } func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) { diff --git a/client/internal/auth/auth.go b/client/internal/auth/auth.go index bdfd07430..afc8ee77f 100644 --- a/client/internal/auth/auth.go +++ b/client/internal/auth/auth.go @@ -321,6 +321,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) { a.config.DisableFirewall, a.config.BlockLANAccess, a.config.BlockInbound, + a.config.DisableIPv6, a.config.LazyConnectionEnabled, a.config.EnableSSHRoot, a.config.EnableSSHSFTP, diff --git a/client/internal/connect.go b/client/internal/connect.go index 72e096a80..8c0e9b1ba 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -14,10 +14,13 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" @@ -536,9 +539,20 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf if config.NetworkMonitor != nil { nm = *config.NetworkMonitor } + wgAddr, err := wgaddr.ParseWGAddress(peerConfig.Address) + if err != nil { + return nil, fmt.Errorf("parse overlay address %q: %w", peerConfig.Address, err) + } + + if !config.DisableIPv6 { + 
if err := wgAddr.SetIPv6FromCompact(peerConfig.GetAddressV6()); err != nil { + log.Warn(err) + } + } + engineConf := &EngineConfig{ WgIfaceName: config.WgIface, - WgAddr: peerConfig.Address, + WgAddr: wgAddr, IFaceBlackList: config.IFaceBlackList, DisableIPv6Discovery: config.DisableIPv6Discovery, WgPrivateKey: key, @@ -563,6 +577,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf DisableFirewall: config.DisableFirewall, BlockLANAccess: config.BlockLANAccess, BlockInbound: config.BlockInbound, + DisableIPv6: config.DisableIPv6, LazyConnectionEnabled: config.LazyConnectionEnabled, @@ -637,6 +652,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config.DisableFirewall, config.BlockLANAccess, config.BlockInbound, + config.DisableIPv6, config.LazyConnectionEnabled, config.EnableSSHRoot, config.EnableSSHSFTP, diff --git a/client/internal/connect_android_default.go b/client/internal/connect_android_default.go index 190341c4a..b05e91fec 100644 --- a/client/internal/connect_android_default.go +++ b/client/internal/connect_android_default.go @@ -40,6 +40,10 @@ func (noopNetworkChangeListener) SetInterfaceIP(string) { // network stack, not by OS-level interface configuration. } +func (noopNetworkChangeListener) SetInterfaceIPv6(string) { + // No-op: same as SetInterfaceIP, IPv6 overlay is managed by userspace stack. +} + // noopDnsReadyListener is a stub for embed.Client on Android. // DNS readiness notifications are not needed in netstack/embed mode // since system DNS is disabled and DNS resolution happens externally. 
diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 0a12a5326..9c50f02b3 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -31,6 +31,7 @@ import ( "github.com/netbirdio/netbird/client/internal/updater/installer" nbstatus "github.com/netbirdio/netbird/client/status" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" ) const readmeContent = `Netbird debug bundle @@ -624,6 +625,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall)) configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess)) configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound)) + configContent.WriteString(fmt.Sprintf("DisableIPv6: %v\n", g.internalConfig.DisableIPv6)) if g.internalConfig.DisableNotifications != nil { configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications)) @@ -1294,6 +1296,21 @@ func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anon config.Address = anonymizer.AnonymizeIP(addr).String() } + if len(config.GetAddressV6()) > 0 { + v6Prefix, err := netiputil.DecodePrefix(config.GetAddressV6()) + if err != nil { + config.AddressV6 = nil + } else { + anonV6 := anonymizer.AnonymizeIP(v6Prefix.Addr()) + b, err := netiputil.EncodePrefix(netip.PrefixFrom(anonV6, v6Prefix.Bits())) + if err != nil { + config.AddressV6 = nil + } else { + config.AddressV6 = b + } + } + } + anonymizeSSHConfig(config.SshConfig) config.Dns = anonymizer.AnonymizeString(config.Dns) @@ -1396,8 +1413,20 @@ func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.An return } + //nolint:staticcheck // PeerIP used for backward compatibility if addr, err := netip.ParseAddr(rule.PeerIP); err == nil { - 
rule.PeerIP = anonymizer.AnonymizeIP(addr).String() + rule.PeerIP = anonymizer.AnonymizeIP(addr).String() //nolint:staticcheck + } + + for i, raw := range rule.GetSourcePrefixes() { + p, err := netiputil.DecodePrefix(raw) + if err != nil { + continue + } + anonAddr := anonymizer.AnonymizeIP(p.Addr()) + if b, err := netiputil.EncodePrefix(netip.PrefixFrom(anonAddr, p.Bits())); err == nil { + rule.SourcePrefixes[i] = b + } } } diff --git a/client/internal/debug/debug_test.go b/client/internal/debug/debug_test.go index 05d51e593..39b972244 100644 --- a/client/internal/debug/debug_test.go +++ b/client/internal/debug/debug_test.go @@ -5,6 +5,7 @@ import ( "bytes" "encoding/json" "net" + "net/netip" "net/url" "os" "path/filepath" @@ -21,8 +22,16 @@ import ( "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" ) +func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte { + t.Helper() + b, err := netiputil.EncodePrefix(p) + require.NoError(t, err) + return b +} + func TestAnonymizeStateFile(t *testing.T) { testState := map[string]json.RawMessage{ "null_state": json.RawMessage("null"), @@ -173,7 +182,7 @@ func TestAnonymizeStateFile(t *testing.T) { assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"]) - assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged + assert.NotEqual(t, "fd00::1", state["private_ipv6"]) // ULA IPv6 anonymized (global ID is a fingerprint) assert.NotEqual(t, "test.example.com", state["domain"]) assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain")) assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged @@ -277,11 +286,13 @@ func mustMarshal(v 
any) json.RawMessage { } func TestAnonymizeNetworkMap(t *testing.T) { + origV6Prefix := netip.MustParsePrefix("2001:db8:abcd::5/64") networkMap := &mgmProto.NetworkMap{ PeerConfig: &mgmProto.PeerConfig{ - Address: "203.0.113.5", - Dns: "1.2.3.4", - Fqdn: "peer1.corp.example.com", + Address: "203.0.113.5", + AddressV6: mustEncodePrefix(t, origV6Prefix), + Dns: "1.2.3.4", + Fqdn: "peer1.corp.example.com", SshConfig: &mgmProto.SSHConfig{ SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."), }, @@ -355,6 +366,12 @@ func TestAnonymizeNetworkMap(t *testing.T) { require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn) require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain")) + // Verify AddressV6 is anonymized but preserves prefix length + anonV6Prefix, err := netiputil.DecodePrefix(peerCfg.AddressV6) + require.NoError(t, err) + assert.Equal(t, origV6Prefix.Bits(), anonV6Prefix.Bits(), "prefix length must be preserved") + assert.NotEqual(t, origV6Prefix.Addr(), anonV6Prefix.Addr(), "IPv6 address must be anonymized") + // Verify SSH key is replaced require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey) @@ -660,8 +677,6 @@ func isInCGNATRange(ip net.IP) bool { } func TestAnonymizeFirewallRules(t *testing.T) { - // TODO: Add ipv6 - // Example iptables-save output iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024 *filter @@ -697,17 +712,31 @@ Chain FORWARD (policy ACCEPT 0 packets, 0 bytes) Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) pkts bytes target prot opt in out source destination` - // Example nftables output + // Example ip6tables-save output + ip6tablesSave := `# Generated by ip6tables-save v1.8.7 on Thu Dec 19 10:00:00 2024 +*filter +:INPUT ACCEPT [0:0] +:FORWARD ACCEPT [0:0] +:OUTPUT ACCEPT [0:0] +-A INPUT -s fd00:1234::1/128 -j ACCEPT +-A INPUT -s 2607:f8b0:4005::1/128 -j DROP +-A FORWARD -s 2001:db8::/32 -d 2607:f8b0:4005::200e/128 -j ACCEPT +COMMIT` + + // Example nftables output with IPv6 nftablesRules := 
`table inet filter { chain input { type filter hook input priority filter; policy accept; ip saddr 192.168.1.1 accept ip saddr 44.192.140.1 drop + ip6 saddr 2607:f8b0:4005::1 drop + ip6 saddr fd00:1234::1 accept } chain forward { type filter hook forward priority filter; policy accept; ip saddr 10.0.0.0/8 drop ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept + ip6 saddr 2001:db8::/32 ip6 daddr 2607:f8b0:4005::200e/128 accept } }` @@ -770,6 +799,37 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) assert.Contains(t, anonNftables, "table inet filter {") assert.Contains(t, anonNftables, "chain input {") assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") + + // IPv6 public addresses in nftables should be anonymized + assert.NotContains(t, anonNftables, "2607:f8b0:4005::1") + assert.NotContains(t, anonNftables, "2607:f8b0:4005::200e") + assert.NotContains(t, anonNftables, "2001:db8::") + assert.Contains(t, anonNftables, "2001:db8:ffff::") // Default anonymous v6 range + + // ULA addresses in nftables should be anonymized (global ID is a fingerprint) + assert.NotContains(t, anonNftables, "fd00:1234::1") + + // IPv6 nftables structure preserved + assert.Contains(t, anonNftables, "ip6 saddr") + assert.Contains(t, anonNftables, "ip6 daddr") + + // Test ip6tables-save anonymization + anonIp6tablesSave := anonymizer.AnonymizeString(ip6tablesSave) + + // ULA IPv6 should be anonymized (global ID is a fingerprint) + assert.NotContains(t, anonIp6tablesSave, "fd00:1234::1/128") + + // Public IPv6 addresses should be anonymized + assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::1") + assert.NotContains(t, anonIp6tablesSave, "2607:f8b0:4005::200e") + assert.NotContains(t, anonIp6tablesSave, "2001:db8::") + assert.Contains(t, anonIp6tablesSave, "2001:db8:ffff::") // Default anonymous v6 range + + // Structure should be preserved + assert.Contains(t, anonIp6tablesSave, "*filter") + assert.Contains(t, anonIp6tablesSave, "COMMIT") 
+ assert.Contains(t, anonIp6tablesSave, "-j DROP") + assert.Contains(t, anonIp6tablesSave, "-j ACCEPT") } // TestAddConfig_AllFieldsCovered uses reflection to ensure every field in diff --git a/client/internal/dns.go b/client/internal/dns.go index f5040ee49..a6604810f 100644 --- a/client/internal/dns.go +++ b/client/internal/dns.go @@ -12,52 +12,83 @@ import ( nbdns "github.com/netbirdio/netbird/dns" ) -func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) { - ip, err := netip.ParseAddr(aRecord.RData) +func createPTRRecord(record nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) { + ip, err := netip.ParseAddr(record.RData) if err != nil { - log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err) + log.Warnf("failed to parse IP address %s: %v", record.RData, err) return nbdns.SimpleRecord{}, false } + ip = ip.Unmap() if !prefix.Contains(ip) { return nbdns.SimpleRecord{}, false } - ipOctets := strings.Split(ip.String(), ".") - slices.Reverse(ipOctets) - rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa") + var rdnsName string + if ip.Is4() { + octets := strings.Split(ip.String(), ".") + slices.Reverse(octets) + rdnsName = dns.Fqdn(strings.Join(octets, ".") + ".in-addr.arpa") + } else { + // Expand to full 32 nibbles in reverse order (LSB first) per RFC 3596. 
+ raw := ip.As16() + nibbles := make([]string, 32) + for i := 0; i < 16; i++ { + nibbles[31-i*2] = fmt.Sprintf("%x", raw[i]>>4) + nibbles[31-i*2-1] = fmt.Sprintf("%x", raw[i]&0x0f) + } + rdnsName = dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa") + } return nbdns.SimpleRecord{ Name: rdnsName, Type: int(dns.TypePTR), - Class: aRecord.Class, - TTL: aRecord.TTL, - RData: dns.Fqdn(aRecord.Name), + Class: record.Class, + TTL: record.TTL, + RData: dns.Fqdn(record.Name), }, true } -// generateReverseZoneName creates the reverse DNS zone name for a given network +// generateReverseZoneName creates the reverse DNS zone name for a given network. +// For IPv4 it produces an in-addr.arpa name, for IPv6 an ip6.arpa name. func generateReverseZoneName(network netip.Prefix) (string, error) { - networkIP := network.Masked().Addr() + networkIP := network.Masked().Addr().Unmap() + bits := network.Bits() - if !networkIP.Is4() { - return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP) + if networkIP.Is4() { + // Round up to nearest byte. + octetsToUse := (bits + 7) / 8 + + octets := strings.Split(networkIP.String(), ".") + if octetsToUse > len(octets) { + return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", bits) + } + + reverseOctets := make([]string, octetsToUse) + for i := 0; i < octetsToUse; i++ { + reverseOctets[octetsToUse-1-i] = octets[i] + } + + return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil } - // round up to nearest byte - octetsToUse := (network.Bits() + 7) / 8 + // IPv6: round up to nearest nibble (4-bit boundary). 
+ nibblesToUse := (bits + 3) / 4 - octets := strings.Split(networkIP.String(), ".") - if octetsToUse > len(octets) { - return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits()) + raw := networkIP.As16() + allNibbles := make([]string, 32) + for i := 0; i < 16; i++ { + allNibbles[i*2] = fmt.Sprintf("%x", raw[i]>>4) + allNibbles[i*2+1] = fmt.Sprintf("%x", raw[i]&0x0f) } - reverseOctets := make([]string, octetsToUse) - for i := 0; i < octetsToUse; i++ { - reverseOctets[octetsToUse-1-i] = octets[i] + // Take the first nibblesToUse nibbles (network portion), reverse them. + used := make([]string, nibblesToUse) + for i := 0; i < nibblesToUse; i++ { + used[nibblesToUse-1-i] = allNibbles[i] } - return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil + return dns.Fqdn(strings.Join(used, ".") + ".ip6.arpa"), nil } // zoneExists checks if a zone with the given name already exists in the configuration @@ -71,7 +102,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool { return false } -// collectPTRRecords gathers all PTR records for the given network from A records +// collectPTRRecords gathers all PTR records for the given network from A and AAAA records. 
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord { var records []nbdns.SimpleRecord @@ -80,7 +111,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple continue } for _, record := range zone.Records { - if record.Type != int(dns.TypeA) { + if record.Type != int(dns.TypeA) && record.Type != int(dns.TypeAAAA) { continue } diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b3908f163..0f4eb6bf8 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -298,6 +298,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() { ip = ip.Unmap() serverAddresses = append(serverAddresses, ip) + // Prefer the first IPv4 server as ServerIP since our DNS listener is IPv4. if !dnsSettings.ServerIP.IsValid() && ip.Is4() { dnsSettings.ServerIP = ip } diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index a67a23945..e9d310f00 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -13,7 +13,6 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/types" @@ -67,9 +66,9 @@ func (d *Resolver) Stop() { d.mu.Lock() defer d.mu.Unlock() - maps.Clear(d.records) - maps.Clear(d.domains) - maps.Clear(d.zones) + clear(d.records) + clear(d.domains) + clear(d.zones) } // ID returns the unique handler ID @@ -444,9 +443,9 @@ func (d *Resolver) Update(customZones []nbdns.CustomZone) { d.mu.Lock() defer d.mu.Unlock() - maps.Clear(d.records) - maps.Clear(d.domains) - maps.Clear(d.zones) + clear(d.records) + clear(d.domains) + clear(d.zones) for _, zone := range customZones { zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain))) diff 
--git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index e4ccc8cbd..66d82dcd7 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -110,8 +110,25 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st connSettings.cleanDeprecatedSettings() - convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice()) - connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP}) + ipKey := networkManagerDbusIPv4Key + staleKey := networkManagerDbusIPv6Key + if config.ServerIP.Is6() { + ipKey = networkManagerDbusIPv6Key + staleKey = networkManagerDbusIPv4Key + raw := config.ServerIP.As16() + connSettings[ipKey][networkManagerDbusDNSKey] = dbus.MakeVariant([][]byte{raw[:]}) + } else { + convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice()) + connSettings[ipKey][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP}) + } + + // Clear stale DNS settings from the opposite address family to avoid + // leftover entries if the server IP family changed. 
+ if staleSettings, ok := connSettings[staleKey]; ok { + delete(staleSettings, networkManagerDbusDNSKey) + delete(staleSettings, networkManagerDbusDNSPriorityKey) + delete(staleSettings, networkManagerDbusDNSSearchKey) + } var ( searchDomains []string matchDomains []string @@ -146,8 +163,8 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st n.routingAll = false } - connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) - connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) + connSettings[ipKey][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) + connSettings[ipKey][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) state := &ShutdownState{ ManagerType: networkManager, diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index d4f54dec5..6fe2e21b6 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -410,7 +410,7 @@ func (s *DefaultServer) Stop() { log.Errorf("failed to disable DNS: %v", err) } - maps.Clear(s.extraDomains) + clear(s.extraDomains) } func (s *DefaultServer) disableDNS() (retErr error) { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index f77f6e898..1026a29fc 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -347,7 +347,7 @@ func TestUpdateDNSServer(t *testing.T) { opts := iface.WGIFaceOpts{ IFaceName: fmt.Sprintf("utun230%d", n), - Address: fmt.Sprintf("100.66.100.%d/32", n+1), + Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)), WGPort: 33100, WGPrivKey: privKey.String(), MTU: iface.DefaultMTU, @@ -448,7 +448,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { privKey, _ := wgtypes.GeneratePrivateKey() opts := iface.WGIFaceOpts{ IFaceName: "utun2301", - Address: "100.66.100.1/32", + Address: 
wgaddr.MustParseWGAddress("100.66.100.1/32"), WGPort: 33100, WGPrivKey: privKey.String(), MTU: iface.DefaultMTU, @@ -929,7 +929,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { opts := iface.WGIFaceOpts{ IFaceName: "utun2301", - Address: "100.66.100.2/24", + Address: wgaddr.MustParseWGAddress("100.66.100.2/24"), WGPort: 33100, WGPrivKey: privKey.String(), MTU: iface.DefaultMTU, diff --git a/client/internal/dns/service.go b/client/internal/dns/service.go index 1c6ce7849..04bcd5985 100644 --- a/client/internal/dns/service.go +++ b/client/internal/dns/service.go @@ -16,8 +16,8 @@ const ( // This is used when the DNS server cannot bind port 53 directly // and needs firewall rules to redirect traffic. type Firewall interface { - AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error - RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error + AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error + RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error } type service interface { diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index 4e09f1b7f..9c0e52af8 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -188,11 +188,10 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr { return s.listenIP } - -// evalListenAddress figure out the listen address for the DNS server -// first check the 53 port availability on WG interface or lo, if not success -// pick a random port on WG interface for eBPF, if not success -// check the 5053 port availability on WG interface or lo without eBPF usage, +// evalListenAddress figures out the listen address for the DNS server. +// IPv4-only: all peers have a v4 overlay address, and DNS config points to v4. 
+// First checks port 53 on WG interface or lo, then tries eBPF on a random port, +// then falls back to port 5053. func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) { if s.customAddr != nil { return s.customAddr.Addr(), s.customAddr.Port(), nil @@ -278,7 +277,7 @@ func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) { } ebpfSrv := ebpf.GetEbpfManagerInstance() - err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP.String(), int(port)) + err = ebpfSrv.LoadDNSFwd(s.wgInterface.Address().IP, int(port)) if err != nil { log.Warnf("failed to load DNS forwarder eBPF program, error: %s", err) return nil, 0, false diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index d9854c033..573dff540 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -90,8 +90,12 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool { } func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { + family := int32(unix.AF_INET) + if config.ServerIP.Is6() { + family = unix.AF_INET6 + } defaultLinkInput := systemdDbusDNSInput{ - Family: unix.AF_INET, + Family: family, Address: config.ServerIP.AsSlice(), } if err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil { diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 746b73ca7..a26536f6e 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -21,6 +21,7 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" @@ -29,6 +30,12 @@ import ( var currentMTU uint16 = iface.DefaultMTU +// 
privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate. +type privateClientIface interface { + Name() string + Address() wgaddr.Address +} + func SetCurrentMTU(mtu uint16) { currentMTU = mtu } diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index ee1ca42fe..988adb7d2 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -86,7 +86,7 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool { return false } -func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { +func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) { return &dns.Client{ Timeout: dialTimeout, Net: "udp", diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 1143b6c51..910c3779e 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -52,7 +52,7 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns return ExchangeWithFallback(ctx, client, r, upstream) } -func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { +func GetClientPrivate(_ privateClientIface, _ netip.Addr, dialTimeout time.Duration) (*dns.Client, error) { return &dns.Client{ Timeout: dialTimeout, Net: "udp", diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 02c11173b..0e04742a0 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -19,9 +19,7 @@ import ( type upstreamResolverIOS struct { *upstreamResolverBase - lIP netip.Addr - lNet netip.Prefix - interfaceName string + wgIface WGIface } func newUpstreamResolver( @@ -35,9 +33,7 @@ func newUpstreamResolver( ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, - lIP: 
wgIface.Address().IP, - lNet: wgIface.Address().Network, - interfaceName: wgIface.Name(), + wgIface: wgIface, } ios.upstreamClient = ios @@ -65,11 +61,13 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } else { upstreamIP = upstreamIP.Unmap() } - needsPrivate := u.lNet.Contains(upstreamIP) || + addr := u.wgIface.Address() + needsPrivate := addr.Network.Contains(upstreamIP) || + addr.IPv6Net.Contains(upstreamIP) || (u.routeMatch != nil && u.routeMatch(upstreamIP)) if needsPrivate { log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream) - client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) + client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout) if err != nil { return nil, 0, fmt.Errorf("create private client: %s", err) } @@ -79,25 +77,33 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * return ExchangeWithFallback(nil, client, r, upstream) } -// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface -// This method is needed for iOS -func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { - index, err := getInterfaceIndex(interfaceName) +// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface. +// It selects the v6 bind address when the upstream is IPv6 and the interface has one, otherwise v4. 
+func GetClientPrivate(iface privateClientIface, upstreamIP netip.Addr, dialTimeout time.Duration) (*dns.Client, error) { + index, err := getInterfaceIndex(iface.Name()) if err != nil { - log.Debugf("unable to get interface index for %s: %s", interfaceName, err) + log.Debugf("unable to get interface index for %s: %s", iface.Name(), err) return nil, err } + addr := iface.Address() + bindIP := addr.IP + if upstreamIP.Is6() && addr.HasIPv6() { + bindIP = addr.IPv6 + } + + proto, opt := unix.IPPROTO_IP, unix.IP_BOUND_IF + if bindIP.Is6() { + proto, opt = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF + } + dialer := &net.Dialer{ - LocalAddr: &net.UDPAddr{ - IP: ip.AsSlice(), - Port: 0, // Let the OS pick a free port - }, - Timeout: dialTimeout, + LocalAddr: net.UDPAddrFromAddrPort(netip.AddrPortFrom(bindIP, 0)), + Timeout: dialTimeout, Control: func(network, address string, c syscall.RawConn) error { var operr error fn := func(s uintptr) { - operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index) + operr = unix.SetsockoptInt(int(s), proto, opt, index) } if err := c.Control(fn); err != nil { diff --git a/client/internal/dns_test.go b/client/internal/dns_test.go new file mode 100644 index 000000000..e15cc8fb7 --- /dev/null +++ b/client/internal/dns_test.go @@ -0,0 +1,138 @@ +package internal + +import ( + "net/netip" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" +) + +func TestCreatePTRRecord_IPv4(t *testing.T) { + record := nbdns.SimpleRecord{ + Name: "peer1.netbird.cloud.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "100.64.0.5", + } + prefix := netip.MustParsePrefix("100.64.0.0/16") + + ptr, ok := createPTRRecord(record, prefix) + require.True(t, ok) + assert.Equal(t, "5.0.64.100.in-addr.arpa.", ptr.Name) + assert.Equal(t, int(dns.TypePTR), ptr.Type) + assert.Equal(t, "peer1.netbird.cloud.", ptr.RData) +} + 
+func TestCreatePTRRecord_IPv6(t *testing.T) { + record := nbdns.SimpleRecord{ + Name: "peer1.netbird.cloud.", + Type: int(dns.TypeAAAA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "fd00:1234:5678::1", + } + prefix := netip.MustParsePrefix("fd00:1234:5678::/48") + + ptr, ok := createPTRRecord(record, prefix) + require.True(t, ok) + assert.Equal(t, "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa.", ptr.Name) + assert.Equal(t, int(dns.TypePTR), ptr.Type) + assert.Equal(t, "peer1.netbird.cloud.", ptr.RData) +} + +func TestCreatePTRRecord_OutOfRange(t *testing.T) { + record := nbdns.SimpleRecord{ + Name: "peer1.netbird.cloud.", + Type: int(dns.TypeA), + RData: "10.0.0.1", + } + prefix := netip.MustParsePrefix("100.64.0.0/16") + + _, ok := createPTRRecord(record, prefix) + assert.False(t, ok) +} + +func TestGenerateReverseZoneName_IPv4(t *testing.T) { + tests := []struct { + prefix string + expected string + }{ + {"100.64.0.0/16", "64.100.in-addr.arpa."}, + {"10.0.0.0/8", "10.in-addr.arpa."}, + {"192.168.1.0/24", "1.168.192.in-addr.arpa."}, + } + + for _, tt := range tests { + t.Run(tt.prefix, func(t *testing.T) { + zone, err := generateReverseZoneName(netip.MustParsePrefix(tt.prefix)) + require.NoError(t, err) + assert.Equal(t, tt.expected, zone) + }) + } +} + +func TestGenerateReverseZoneName_IPv6(t *testing.T) { + tests := []struct { + prefix string + expected string + }{ + {"fd00:1234:5678::/48", "8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa."}, + {"fd00::/16", "0.0.d.f.ip6.arpa."}, + {"fd12:3456:789a:bcde::/64", "e.d.c.b.a.9.8.7.6.5.4.3.2.1.d.f.ip6.arpa."}, + } + + for _, tt := range tests { + t.Run(tt.prefix, func(t *testing.T) { + zone, err := generateReverseZoneName(netip.MustParsePrefix(tt.prefix)) + require.NoError(t, err) + assert.Equal(t, tt.expected, zone) + }) + } +} + +func TestCollectPTRRecords_BothFamilies(t *testing.T) { + config := &nbdns.Config{ + CustomZones: []nbdns.CustomZone{ + { + Domain: "netbird.cloud.", + Records: 
[]nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud.", Type: int(dns.TypeA), RData: "100.64.0.1"}, + {Name: "peer1.netbird.cloud.", Type: int(dns.TypeAAAA), RData: "fd00::1"}, + {Name: "peer2.netbird.cloud.", Type: int(dns.TypeA), RData: "100.64.0.2"}, + }, + }, + }, + } + + v4Records := collectPTRRecords(config, netip.MustParsePrefix("100.64.0.0/16")) + assert.Len(t, v4Records, 2, "should collect 2 A record PTRs for the v4 prefix") + + v6Records := collectPTRRecords(config, netip.MustParsePrefix("fd00::/64")) + assert.Len(t, v6Records, 1, "should collect 1 AAAA record PTR for the v6 prefix") +} + +func TestAddReverseZone_IPv6(t *testing.T) { + config := &nbdns.Config{ + CustomZones: []nbdns.CustomZone{ + { + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud.", Type: int(dns.TypeAAAA), RData: "fd00:1234:5678::1"}, + }, + }, + }, + } + + addReverseZone(config, netip.MustParsePrefix("fd00:1234:5678::/48")) + + require.Len(t, config.CustomZones, 2) + reverseZone := config.CustomZones[1] + assert.Equal(t, "8.7.6.5.4.3.2.1.0.0.d.f.ip6.arpa.", reverseZone.Domain) + assert.Len(t, reverseZone.Records, 1) + assert.Equal(t, int(dns.TypePTR), reverseZone.Records[0].Type) +} diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 58b88d9ef..c4c16cd3f 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -80,6 +80,7 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } + // IPv4-only: peers reach the forwarder via its v4 overlay address. 
localAddr := m.wgIface.Address().IP if localAddr.IsValid() && m.firewall != nil { diff --git a/client/internal/ebpf/ebpf/dns_fwd_linux.go b/client/internal/ebpf/ebpf/dns_fwd_linux.go index 93797da76..1e7774573 100644 --- a/client/internal/ebpf/ebpf/dns_fwd_linux.go +++ b/client/internal/ebpf/ebpf/dns_fwd_linux.go @@ -2,7 +2,8 @@ package ebpf import ( "encoding/binary" - "net" + "fmt" + "net/netip" log "github.com/sirupsen/logrus" ) @@ -12,7 +13,7 @@ const ( mapKeyDNSPort uint32 = 1 ) -func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error { +func (tf *GeneralManager) LoadDNSFwd(ip netip.Addr, dnsPort int) error { log.Debugf("load eBPF DNS forwarder, watching addr: %s:53, redirect to port: %d", ip, dnsPort) tf.lock.Lock() defer tf.lock.Unlock() @@ -22,7 +23,11 @@ func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error { return err } - err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip)) + if !ip.Is4() { + return fmt.Errorf("eBPF DNS forwarder only supports IPv4, got %s", ip) + } + ip4 := ip.As4() + err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, binary.BigEndian.Uint32(ip4[:])) if err != nil { return err } @@ -45,7 +50,3 @@ func (tf *GeneralManager) FreeDNSFwd() error { return tf.unsetFeatureFlag(featureFlagDnsForwarder) } -func ip2int(ipString string) uint32 { - ip := net.ParseIP(ipString) - return binary.BigEndian.Uint32(ip.To4()) -} diff --git a/client/internal/ebpf/manager/manager.go b/client/internal/ebpf/manager/manager.go index af10142d5..25a767090 100644 --- a/client/internal/ebpf/manager/manager.go +++ b/client/internal/ebpf/manager/manager.go @@ -1,8 +1,10 @@ package manager +import "net/netip" + // Manager is used to load multiple eBPF programs. 
E.g., current DNS programs and WireGuard proxy type Manager interface { - LoadDNSFwd(ip string, dnsPort int) error + LoadDNSFwd(ip netip.Addr, dnsPort int) error FreeDNSFwd() error LoadWgProxy(proxyPort, wgPort int) error FreeWGProxy() error diff --git a/client/internal/engine.go b/client/internal/engine.go index 7f19e2d28..66fe6056b 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -33,6 +33,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/udpmux" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/dns" @@ -64,6 +65,7 @@ import ( mgm "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" @@ -88,8 +90,9 @@ type EngineConfig struct { WgPort int WgIfaceName string - // WgAddr is a Wireguard local address (Netbird Network IP) - WgAddr string + // WgAddr is the Wireguard local address (Netbird Network IP). + // Contains both v4 and optional v6 overlay addresses. 
+ WgAddr wgaddr.Address // WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine) WgPrivateKey wgtypes.Key @@ -134,6 +137,7 @@ type EngineConfig struct { DisableFirewall bool BlockLANAccess bool BlockInbound bool + DisableIPv6 bool LazyConnectionEnabled bool @@ -644,7 +648,7 @@ func (e *Engine) initFirewall() error { rosenpassPort := e.rpManager.GetAddress().Port port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}} - // this rule is static and will be torn down on engine down by the firewall manager + // IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4. if _, err := e.firewall.AddPeerFiltering( nil, net.IP{0, 0, 0, 0}, @@ -696,10 +700,15 @@ func (e *Engine) blockLanAccess() { log.Infof("blocking route LAN access for networks: %v", toBlock) v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0) + v6 := netip.PrefixFrom(netip.IPv6Unspecified(), 0) for _, network := range toBlock { + source := v4 + if network.Addr().Is6() { + source = v6 + } if _, err := e.firewall.AddRouteFiltering( nil, - []netip.Prefix{v4}, + []netip.Prefix{source}, firewallManager.Network{Prefix: network}, firewallManager.ProtocolALL, nil, @@ -737,7 +746,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { if !ok { continue } - if !compareNetIPLists(allowedIPs, p.GetAllowedIps()) { + if !compareNetIPLists(allowedIPs, e.filterAllowedIPs(p.GetAllowedIps())) { modified = append(modified, p) continue } @@ -1016,6 +1025,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error { e.config.DisableFirewall, e.config.BlockLANAccess, e.config.BlockInbound, + e.config.DisableIPv6, e.config.LazyConnectionEnabled, e.config.EnableSSHRoot, e.config.EnableSSHSFTP, @@ -1043,6 +1053,13 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { return ErrResetConnection } + if !e.config.DisableIPv6 && e.hasIPv6Changed(conf) { + log.Infof("peer IPv6 address changed, restarting client") + _ = 
CtxGetState(e.ctx).Wrap(ErrResetConnection) + e.clientCancel() + return ErrResetConnection + } + if conf.GetSshConfig() != nil { if err := e.updateSSH(conf.GetSshConfig()); err != nil { log.Warnf("failed handling SSH server setup: %v", err) @@ -1051,6 +1068,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { state := e.statusRecorder.GetLocalPeerState() state.IP = e.wgInterface.Address().String() + state.IPv6 = e.wgInterface.Address().IPv6String() state.PubKey = e.config.WgPrivateKey.PublicKey().String() state.KernelInterface = !e.wgInterface.IsUserspaceBind() state.FQDN = conf.GetFqdn() @@ -1059,6 +1077,28 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { return nil } + +// hasIPv6Changed reports whether the IPv6 overlay address in the peer config +// differs from the configured address (added, removed, or changed). +// Compares against e.config.WgAddr (not the interface address, which may have +// been cleared by ClearIPv6 if OS assignment failed). +func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool { + current := e.config.WgAddr + raw := conf.GetAddressV6() + + if len(raw) == 0 { + return current.HasIPv6() + } + + prefix, err := netiputil.DecodePrefix(raw) + if err != nil { + log.Errorf("decode v6 overlay address: %v", err) + return false + } + + return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked() +} + func (e *Engine) receiveJobEvents() { e.jobExecutorWG.Add(1) go func() { @@ -1157,6 +1197,7 @@ func (e *Engine) receiveManagementEvents() { e.config.DisableFirewall, e.config.BlockLANAccess, e.config.BlockInbound, + e.config.DisableIPv6, e.config.LazyConnectionEnabled, e.config.EnableSSHRoot, e.config.EnableSSHSFTP, @@ -1256,7 +1297,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network) + dnsConfig := toDNSConfig(protoDNSConfig, 
e.wgInterface.Address()) if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil { log.Errorf("failed to update dns server, err: %v", err) @@ -1411,7 +1452,9 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE return entries } -func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { +func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, addr wgaddr.Address) nbdns.Config { + network := addr.Network + networkV6 := addr.IPv6Net //nolint forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) if forwarderPort == 0 { @@ -1468,6 +1511,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns if len(dnsUpdate.CustomZones) > 0 { addReverseZone(&dnsUpdate, network) + if networkV6.IsValid() { + addReverseZone(&dnsUpdate, networkV6) + } } return dnsUpdate @@ -1477,8 +1523,10 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) { replacement := make([]peer.State, len(offlinePeers)) for i, offlinePeer := range offlinePeers { log.Debugf("added offline peer %s", offlinePeer.Fqdn) + v4, v6 := overlayAddrsFromAllowedIPs(offlinePeer.GetAllowedIps(), e.wgInterface.Address().IPv6Net) replacement[i] = peer.State{ - IP: strings.Join(offlinePeer.GetAllowedIps(), ","), + IP: addrToString(v4), + IPv6: addrToString(v6), PubKey: offlinePeer.GetWgPubKey(), FQDN: offlinePeer.GetFqdn(), ConnStatus: peer.StatusIdle, @@ -1489,6 +1537,37 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) { e.statusRecorder.ReplaceOfflinePeers(replacement) } +// overlayAddrsFromAllowedIPs extracts the peer's v4 and v6 overlay addresses +// from AllowedIPs strings. Only host routes (/32, /128) are considered; v6 must +// fall within ourV6Net to distinguish overlay addresses from routed prefixes. 
+func overlayAddrsFromAllowedIPs(allowedIPs []string, ourV6Net netip.Prefix) (v4, v6 netip.Addr) { + for _, cidr := range allowedIPs { + prefix, err := netip.ParsePrefix(cidr) + if err != nil { + log.Warnf("failed to parse AllowedIP %q: %v", cidr, err) + continue + } + addr := prefix.Addr().Unmap() + switch { + case addr.Is4() && prefix.Bits() == 32 && !v4.IsValid(): + v4 = addr + case addr.Is6() && prefix.Bits() == 128 && ourV6Net.Contains(addr) && !v6.IsValid(): + v6 = addr + } + if v4.IsValid() && v6.IsValid() { + break + } + } + return +} + +func addrToString(addr netip.Addr) string { + if !addr.IsValid() { + return "" + } + return addr.String() +} + // addNewPeers adds peers that were not know before but arrived from the Management service with the update func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { for _, p := range peersUpdate { @@ -1514,15 +1593,23 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { log.Errorf("failed to parse allowedIPS: %v", err) return err } + if allowedNetIP.Addr().Is6() && !e.wgInterface.Address().HasIPv6() { + continue + } peerIPs = append(peerIPs, allowedNetIP) } + if len(peerIPs) == 0 { + return fmt.Errorf("peer %s has no usable AllowedIPs", peerKey) + } + conn, err := e.createPeerConn(peerKey, peerIPs, peerConfig.AgentVersion) if err != nil { return fmt.Errorf("create peer connection: %w", err) } - err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerIPs[0].Addr().String()) + peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net) + err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, addrToString(peerV4), addrToString(peerV6)) if err != nil { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) } @@ -1757,6 +1844,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err e.config.DisableFirewall, e.config.BlockLANAccess, e.config.BlockInbound, + 
e.config.DisableIPv6, e.config.LazyConnectionEnabled, e.config.EnableSSHRoot, e.config.EnableSSHSFTP, @@ -1770,7 +1858,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err return nil, nil, false, err } routes := toRoutes(netMap.GetRoutes()) - dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network) + dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address()) dnsFeatureFlag := toDNSFeatureFlag(netMap) return routes, &dnsCfg, dnsFeatureFlag, nil } @@ -1812,7 +1900,10 @@ func (e *Engine) wgInterfaceCreate() (err error) { case "android": err = e.wgInterface.CreateOnAndroid(e.routeManager.InitialRouteRange(), e.dnsServer.DnsIP().String(), e.dnsServer.SearchDomains()) case "ios": - e.mobileDep.NetworkChangeListener.SetInterfaceIP(e.config.WgAddr) + e.mobileDep.NetworkChangeListener.SetInterfaceIP(e.config.WgAddr.String()) + if e.config.WgAddr.HasIPv6() { + e.mobileDep.NetworkChangeListener.SetInterfaceIPv6(e.config.WgAddr.IPv6String()) + } err = e.wgInterface.Create() default: err = e.wgInterface.Create() @@ -2089,6 +2180,14 @@ func (e *Engine) GetWgAddr() netip.Addr { return e.wgInterface.Address().IP } +// GetWgV6Addr returns the IPv6 overlay address of the WireGuard interface. 
+func (e *Engine) GetWgV6Addr() netip.Addr { + if e.wgInterface == nil { + return netip.Addr{} + } + return e.wgInterface.Address().IPv6 +} + func (e *Engine) RenewTun(fd int) error { e.syncMsgMux.Lock() wgInterface := e.wgInterface @@ -2370,8 +2469,7 @@ func getInterfacePrefixes() ([]netip.Prefix, error) { prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked() ip := prefix.Addr() - // TODO: add IPv6 - if !ip.Is4() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + if ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { continue } @@ -2382,6 +2480,24 @@ func getInterfacePrefixes() ([]netip.Prefix, error) { return prefixes, nberrors.FormatErrorOrNil(merr) } +// filterAllowedIPs strips IPv6 entries when the local interface has no v6 address. +// This covers both the explicit --disable-ipv6 flag (v6 never assigned) and the +// case where OS v6 assignment failed (ClearIPv6). Without this, WireGuard would +// accept v6 traffic that the native firewall cannot filter. +func (e *Engine) filterAllowedIPs(ips []string) []string { + if e.wgInterface.Address().HasIPv6() { + return ips + } + filtered := make([]string, 0, len(ips)) + for _, s := range ips { + p, err := netip.ParsePrefix(s) + if err != nil || !p.Addr().Is6() { + filtered = append(filtered, s) + } + } + return filtered +} + // compareNetIPLists compares a list of netip.Prefix with a list of strings. // return true if both lists are equal, false otherwise. 
func compareNetIPLists(list1 []netip.Prefix, list2 []string) bool { diff --git a/client/internal/engine_ssh.go b/client/internal/engine_ssh.go index 1419bc262..53d2c1122 100644 --- a/client/internal/engine_ssh.go +++ b/client/internal/engine_ssh.go @@ -41,6 +41,14 @@ func (e *Engine) setupSSHPortRedirection() error { } log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr) + if v6 := e.wgInterface.Address().IPv6; v6.IsValid() { + if err := e.firewall.AddInboundDNAT(v6, firewallManager.ProtocolTCP, 22, 22022); err != nil { + log.Warnf("failed to add IPv6 SSH port redirection: %v", err) + } else { + log.Infof("SSH port redirection enabled: [%s]:22 -> [%s]:22022", v6, v6) + } + } + return nil } @@ -137,12 +145,13 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) [] continue } - peerIP := e.extractPeerIP(peerConfig) + peerV4, peerV6 := overlayAddrsFromAllowedIPs(peerConfig.GetAllowedIps(), e.wgInterface.Address().IPv6Net) hostname := e.extractHostname(peerConfig) peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{ Hostname: hostname, - IP: peerIP, + IP: peerV4, + IPv6: peerV6, FQDN: peerConfig.GetFqdn(), }) } @@ -150,18 +159,6 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) [] return peerInfo } -// extractPeerIP extracts IP address from peer's allowed IPs -func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string { - if len(peerConfig.GetAllowedIps()) == 0 { - return "" - } - - if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil { - return prefix.Addr().String() - } - return "" -} - // extractHostname extracts short hostname from FQDN func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string { fqdn := peerConfig.GetFqdn() @@ -208,7 +205,7 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) { fullStatus := statusRecorder.GetFullStatus() for _, peerState := range fullStatus.Peers { - if peerState.IP 
== peerAddress || peerState.FQDN == peerAddress { + if peerState.IP == peerAddress || peerState.FQDN == peerAddress || peerState.IPv6 == peerAddress { if len(peerState.SSHHostKey) > 0 { return peerState.SSHHostKey, true } @@ -262,6 +259,13 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error { return fmt.Errorf("start SSH server: %w", err) } + if v6 := wgAddr.IPv6; v6.IsValid() { + v6Addr := netip.AddrPortFrom(v6, sshserver.InternalSSHPort) + if err := server.AddListener(e.ctx, v6Addr); err != nil { + log.Warnf("failed to add IPv6 SSH listener: %v", err) + } + } + e.sshServer = server if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { @@ -330,6 +334,12 @@ func (e *Engine) cleanupSSHPortRedirection() error { } log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr) + if v6 := e.wgInterface.Address().IPv6; v6.IsValid() { + if err := e.firewall.RemoveInboundDNAT(v6, firewallManager.ProtocolTCP, 22, 22022); err != nil { + log.Debugf("failed to remove IPv6 SSH port redirection: %v", err) + } + } + return nil } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index f4c5be70a..834a49a09 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -67,6 +67,7 @@ import ( mgmt "github.com/netbirdio/netbird/shared/management/client" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" relayClient "github.com/netbirdio/netbird/shared/relay/client" + "github.com/netbirdio/netbird/shared/netiputil" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/shared/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" @@ -95,7 +96,7 @@ type MockWGIface struct { AddressFunc func() wgaddr.Address ToInterfaceFunc func() *net.Interface UpFunc func() (*udpmux.UniversalUDPMuxDefault, error) - UpdateAddrFunc func(newAddr string) error + UpdateAddrFunc func(newAddr wgaddr.Address) error UpdatePeerFunc func(peerKey 
string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeerFunc func(peerKey string) error AddAllowedIPFunc func(peerKey string, allowedIP netip.Prefix) error @@ -157,7 +158,7 @@ func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { return m.UpFunc() } -func (m *MockWGIface) UpdateAddr(newAddr string) error { +func (m *MockWGIface) UpdateAddr(newAddr wgaddr.Address) error { return m.UpdateAddrFunc(newAddr) } @@ -254,7 +255,7 @@ func TestEngine_SSH(t *testing.T) { ctx, cancel, &EngineConfig{ WgIfaceName: "utun101", - WgAddr: "100.64.0.1/24", + WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"), WgPrivateKey: key, WgPort: 33100, ServerSSHAllowed: true, @@ -431,7 +432,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: "utun102", - WgAddr: "100.64.0.1/24", + WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"), WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, @@ -655,7 +656,7 @@ func TestEngine_Sync(t *testing.T) { relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: "utun103", - WgAddr: "100.64.0.1/24", + WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"), WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, @@ -825,7 +826,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: wgIfaceName, - WgAddr: wgAddr, + WgAddr: wgaddr.MustParseWGAddress(wgAddr), WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, @@ -843,7 +844,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { opts := iface.WGIFaceOpts{ IFaceName: wgIfaceName, - Address: wgAddr, + Address: 
wgaddr.MustParseWGAddress(wgAddr), WGPort: engine.config.WgPort, WGPrivKey: key.String(), MTU: iface.DefaultMTU, @@ -1032,7 +1033,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine(ctx, cancel, &EngineConfig{ WgIfaceName: wgIfaceName, - WgAddr: wgAddr, + WgAddr: wgaddr.MustParseWGAddress(wgAddr), WgPrivateKey: key, WgPort: 33100, MTU: iface.DefaultMTU, @@ -1050,7 +1051,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { } opts := iface.WGIFaceOpts{ IFaceName: wgIfaceName, - Address: wgAddr, + Address: wgaddr.MustParseWGAddress(wgAddr), WGPort: 33100, WGPrivKey: key.String(), MTU: iface.DefaultMTU, @@ -1555,7 +1556,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin wgPort := 33100 + i conf := &EngineConfig{ WgIfaceName: ifaceName, - WgAddr: resp.PeerConfig.Address, + WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address), WgPrivateKey: key, WgPort: wgPort, MTU: iface.DefaultMTU, @@ -1705,3 +1706,224 @@ func getPeers(e *Engine) int { return len(e.peerStore.PeersPubKey()) } + +func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte { + t.Helper() + b, err := netiputil.EncodePrefix(p) + require.NoError(t, err) + return b +} + +func TestEngine_hasIPv6Changed(t *testing.T) { + v4Only := wgaddr.MustParseWGAddress("100.64.0.1/16") + + v4v6 := wgaddr.MustParseWGAddress("100.64.0.1/16") + v4v6.IPv6 = netip.MustParseAddr("fd00::1") + v4v6.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked() + + tests := []struct { + name string + current wgaddr.Address + confV6 []byte + expected bool + }{ + { + name: "no v6 before, no v6 now", + current: v4Only, + confV6: nil, + expected: false, + }, + { + name: "no v6 before, v6 added", + current: v4Only, + confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")), + expected: true, + }, + { + name: "had v6, now removed", + current: v4v6, + confV6: nil, 
+ expected: true, + }, + { + name: "had v6, same v6", + current: v4v6, + confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")), + expected: false, + }, + { + name: "had v6, different v6", + current: v4v6, + confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::2/64")), + expected: true, + }, + { + name: "same v6 addr, different prefix length", + current: v4v6, + confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/80")), + expected: true, + }, + { + name: "decode error keeps status quo", + current: v4Only, + confV6: []byte{1, 2, 3}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + engine := &Engine{ + config: &EngineConfig{WgAddr: tt.current}, + } + conf := &mgmtProto.PeerConfig{ + AddressV6: tt.confV6, + } + assert.Equal(t, tt.expected, engine.hasIPv6Changed(conf)) + }) + } +} + +func TestFilterAllowedIPs(t *testing.T) { + v4v6Addr := wgaddr.MustParseWGAddress("100.64.0.1/16") + v4v6Addr.IPv6 = netip.MustParseAddr("fd00::1") + v4v6Addr.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked() + + v4OnlyAddr := wgaddr.MustParseWGAddress("100.64.0.1/16") + + tests := []struct { + name string + addr wgaddr.Address + input []string + expected []string + }{ + { + name: "interface has v6, keep all", + addr: v4v6Addr, + input: []string{"100.64.0.1/32", "fd00::1/128"}, + expected: []string{"100.64.0.1/32", "fd00::1/128"}, + }, + { + name: "no v6, strip v6", + addr: v4OnlyAddr, + input: []string{"100.64.0.1/32", "fd00::1/128"}, + expected: []string{"100.64.0.1/32"}, + }, + { + name: "no v6, only v4", + addr: v4OnlyAddr, + input: []string{"100.64.0.1/32", "10.0.0.0/8"}, + expected: []string{"100.64.0.1/32", "10.0.0.0/8"}, + }, + { + name: "no v6, only v6 input", + addr: v4OnlyAddr, + input: []string{"fd00::1/128", "::/0"}, + expected: []string{}, + }, + { + name: "no v6, invalid prefix preserved", + addr: v4OnlyAddr, + input: []string{"100.64.0.1/32", "garbage"}, + expected: []string{"100.64.0.1/32", 
"garbage"}, + }, + { + name: "no v6, empty input", + addr: v4OnlyAddr, + input: []string{}, + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr := tt.addr + engine := &Engine{ + config: &EngineConfig{}, + wgInterface: &MockWGIface{ + AddressFunc: func() wgaddr.Address { return addr }, + }, + } + result := engine.filterAllowedIPs(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestOverlayAddrsFromAllowedIPs(t *testing.T) { + ourV6Net := netip.MustParsePrefix("fd00:1234:5678:abcd::/64") + + tests := []struct { + name string + allowedIPs []string + ourV6Net netip.Prefix + wantV4 string + wantV6 string + }{ + { + name: "v4 only", + allowedIPs: []string{"100.64.0.1/32"}, + ourV6Net: ourV6Net, + wantV4: "100.64.0.1", + wantV6: "", + }, + { + name: "v4 and v6 overlay", + allowedIPs: []string{"100.64.0.1/32", "fd00:1234:5678:abcd::1/128"}, + ourV6Net: ourV6Net, + wantV4: "100.64.0.1", + wantV6: "fd00:1234:5678:abcd::1", + }, + { + name: "v4, routed v6, overlay v6", + allowedIPs: []string{"100.64.0.1/32", "2001:db8::1/128", "fd00:1234:5678:abcd::1/128"}, + ourV6Net: ourV6Net, + wantV4: "100.64.0.1", + wantV6: "fd00:1234:5678:abcd::1", + }, + { + name: "routed v6 /128 outside our subnet is ignored", + allowedIPs: []string{"100.64.0.1/32", "2001:db8::1/128"}, + ourV6Net: ourV6Net, + wantV4: "100.64.0.1", + wantV6: "", + }, + { + name: "routed v6 prefix is ignored", + allowedIPs: []string{"100.64.0.1/32", "fd00:1234:5678:abcd::/64"}, + ourV6Net: ourV6Net, + wantV4: "100.64.0.1", + wantV6: "", + }, + { + name: "no v6 subnet configured", + allowedIPs: []string{"100.64.0.1/32", "fd00:1234:5678:abcd::1/128"}, + ourV6Net: netip.Prefix{}, + wantV4: "100.64.0.1", + wantV6: "", + }, + { + name: "v4 /24 route is ignored", + allowedIPs: []string{"100.64.0.0/24", "100.64.0.1/32"}, + ourV6Net: ourV6Net, + wantV4: "100.64.0.1", + wantV6: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + v4, v6 := overlayAddrsFromAllowedIPs(tt.allowedIPs, tt.ourV6Net) + if tt.wantV4 == "" { + assert.False(t, v4.IsValid(), "expected no v4") + } else { + assert.Equal(t, tt.wantV4, v4.String(), "v4") + } + if tt.wantV6 == "" { + assert.False(t, v6.IsValid(), "expected no v6") + } else { + assert.Equal(t, tt.wantV6, v6.String(), "v6") + } + }) + } +} diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 39e9bacfa..2eeac1954 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -26,7 +26,7 @@ type wgIfaceBase interface { Address() wgaddr.Address ToInterface() *net.Interface Up() (*udpmux.UniversalUDPMuxDefault, error) - UpdateAddr(newAddr string) error + UpdateAddr(newAddr wgaddr.Address) error GetProxy() wgproxy.Proxy GetProxyPort() uint16 UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error diff --git a/client/internal/lazyconn/activity/listener_bind.go b/client/internal/lazyconn/activity/listener_bind.go index 792d04215..60b8baadb 100644 --- a/client/internal/lazyconn/activity/listener_bind.go +++ b/client/internal/lazyconn/activity/listener_bind.go @@ -57,6 +57,7 @@ func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyc // deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. // Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). // It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. +// For IPv6-only peers, the last two bytes of the v6 address are used. 
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { if len(allowedIPs) == 0 { return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") @@ -64,6 +65,7 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e ourNetwork := wgIface.Address().Network + // Try v4 first (preferred: deterministic from overlay IP) var peerIP netip.Addr for _, allowedIP := range allowedIPs { ip := allowedIP.Addr() @@ -76,13 +78,24 @@ func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, e } } - if !peerIP.IsValid() { - return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") + if peerIP.IsValid() { + octets := peerIP.As4() + return netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}), nil } - octets := peerIP.As4() - fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) - return fakeIP, nil + // Fallback: use last two bytes of first v6 overlay IP + addr := wgIface.Address() + if addr.IPv6Net.IsValid() { + for _, allowedIP := range allowedIPs { + ip := allowedIP.Addr() + if ip.Is6() && addr.IPv6Net.Contains(ip) { + raw := ip.As16() + return netip.AddrFrom4([4]byte{127, 2, raw[14], raw[15]}), nil + } + } + } + + return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") } func (d *BindListener) setupLazyConn() error { diff --git a/client/internal/listener/network_change.go b/client/internal/listener/network_change.go index 08bf5fd52..e0aa43abe 100644 --- a/client/internal/listener/network_change.go +++ b/client/internal/listener/network_change.go @@ -5,4 +5,5 @@ type NetworkChangeListener interface { // OnNetworkChanged invoke when network settings has been changed OnNetworkChanged(string) SetInterfaceIP(string) + SetInterfaceIPv6(string) } diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index 2420b1fdf..6f1da5138 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ 
b/client/internal/netflow/conntrack/conntrack.go @@ -316,7 +316,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) { case nftypes.TCP, nftypes.UDP, nftypes.SCTP: srcPort = flow.TupleOrig.Proto.SourcePort dstPort = flow.TupleOrig.Proto.DestinationPort - case nftypes.ICMP: + case nftypes.ICMP, nftypes.ICMPv6: icmpType = flow.TupleOrig.Proto.ICMPType icmpCode = flow.TupleOrig.Proto.ICMPCode } @@ -359,8 +359,14 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool { } // fallback if mark rules are not in place - wgnet := c.iface.Address().Network - return wgnet.Contains(srcIP) || wgnet.Contains(dstIP) + addr := c.iface.Address() + if addr.Network.Contains(srcIP) || addr.Network.Contains(dstIP) { + return true + } + if addr.IPv6Net.IsValid() { + return addr.IPv6Net.Contains(srcIP) || addr.IPv6Net.Contains(dstIP) + } + return false } // mapRxPackets maps packet counts to RX based on flow direction @@ -419,17 +425,16 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes } // fallback if marks are not set - wgaddr := c.iface.Address().IP - wgnetwork := c.iface.Address().Network + addr := c.iface.Address() switch { - case wgaddr == srcIP: + case addr.IP == srcIP || (addr.IPv6.IsValid() && addr.IPv6 == srcIP): return nftypes.Egress - case wgaddr == dstIP: + case addr.IP == dstIP || (addr.IPv6.IsValid() && addr.IPv6 == dstIP): return nftypes.Ingress - case wgnetwork.Contains(srcIP): + case addr.Network.Contains(srcIP) || (addr.IPv6Net.IsValid() && addr.IPv6Net.Contains(srcIP)): // netbird network -> resource network return nftypes.Ingress - case wgnetwork.Contains(dstIP): + case addr.Network.Contains(dstIP) || (addr.IPv6Net.IsValid() && addr.IPv6Net.Contains(dstIP)): // resource network -> netbird network return nftypes.Egress } diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index a033a2a7c..8f8e68784 100644 --- a/client/internal/netflow/logger/logger.go +++ 
b/client/internal/netflow/logger/logger.go @@ -24,15 +24,17 @@ type Logger struct { cancel context.CancelFunc statusRecorder *peer.Status wgIfaceNet netip.Prefix + wgIfaceNetV6 netip.Prefix dnsCollection atomic.Bool exitNodeCollection atomic.Bool Store types.Store } -func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger { +func New(statusRecorder *peer.Status, wgIfaceIPNet, wgIfaceIPNetV6 netip.Prefix) *Logger { return &Logger{ statusRecorder: statusRecorder, wgIfaceNet: wgIfaceIPNet, + wgIfaceNetV6: wgIfaceIPNetV6, Store: store.NewMemoryStore(), } } @@ -88,11 +90,11 @@ func (l *Logger) startReceiver() { var isSrcExitNode bool var isDestExitNode bool - if !l.wgIfaceNet.Contains(event.SourceIP) { + if !l.isOverlayIP(event.SourceIP) { event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP) } - if !l.wgIfaceNet.Contains(event.DestIP) { + if !l.isOverlayIP(event.DestIP) { event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP) } @@ -136,6 +138,10 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { l.exitNodeCollection.Store(exitNodeCollection) } +func (l *Logger) isOverlayIP(ip netip.Addr) bool { + return l.wgIfaceNet.Contains(ip) || (l.wgIfaceNetV6.IsValid() && l.wgIfaceNetV6.Contains(ip)) +} + func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection if !l.dnsCollection.Load() && event.Protocol == types.UDP && diff --git a/client/internal/netflow/logger/logger_test.go b/client/internal/netflow/logger/logger_test.go index 1144544d8..ad2eedef2 100644 --- a/client/internal/netflow/logger/logger_test.go +++ b/client/internal/netflow/logger/logger_test.go @@ -12,7 +12,7 @@ import ( ) func TestStore(t *testing.T) { - logger := logger.New(nil, netip.Prefix{}) + logger := logger.New(nil, netip.Prefix{}, netip.Prefix{}) logger.Enable() event := types.EventFields{ diff --git a/client/internal/netflow/manager.go 
b/client/internal/netflow/manager.go index 7752c97b0..eff083dbf 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -35,11 +35,12 @@ type Manager struct { // NewManager creates a new netflow manager func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager { - var prefix netip.Prefix + var prefix, prefixV6 netip.Prefix if iface != nil { prefix = iface.Address().Network + prefixV6 = iface.Address().IPv6Net } - flowLogger := logger.New(statusRecorder, prefix) + flowLogger := logger.New(statusRecorder, prefix, prefixV6) var ct nftypes.ConnTracker if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() { @@ -269,7 +270,7 @@ func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent { }, } - if event.Protocol == nftypes.ICMP { + if event.Protocol == nftypes.ICMP || event.Protocol == nftypes.ICMPv6 { protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{ IcmpInfo: &proto.ICMPInfo{ IcmpType: uint32(event.ICMPType), diff --git a/client/internal/netflow/types/types.go b/client/internal/netflow/types/types.go index f76146ba3..3f7d0d0ad 100644 --- a/client/internal/netflow/types/types.go +++ b/client/internal/netflow/types/types.go @@ -19,6 +19,7 @@ const ( ICMP = Protocol(1) TCP = Protocol(6) UDP = Protocol(17) + ICMPv6 = Protocol(58) SCTP = Protocol(132) ) @@ -30,6 +31,8 @@ func (p Protocol) String() string { return "TCP" case 17: return "UDP" + case 58: + return "ICMPv6" case 132: return "SCTP" default: diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index e8e61f660..df746fa13 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -53,6 +53,7 @@ type RouterState struct { type State struct { Mux *sync.RWMutex IP string + IPv6 string PubKey string FQDN string ConnStatus ConnStatus @@ -106,6 +107,7 @@ func (s *State) GetRoutes() map[string]struct{} { // LocalPeerState contains the latest state of 
the local peer type LocalPeerState struct { IP string + IPv6 string PubKey string KernelInterface bool FQDN string @@ -259,7 +261,7 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) { } // AddPeer adds peer to Daemon status map -func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string) error { +func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string) error { d.mux.Lock() defer d.mux.Unlock() @@ -270,6 +272,7 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string) error { d.peers[peerPubKey] = State{ PubKey: peerPubKey, IP: ip, + IPv6: ipv6, ConnStatus: StatusIdle, FQDN: fqdn, Mux: new(sync.RWMutex), @@ -710,6 +713,9 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { d.localPeer = localPeerState fqdn := d.localPeer.FQDN ip := d.localPeer.IP + if d.localPeer.IPv6 != "" { + ip = ip + "\n" + d.localPeer.IPv6 + } d.mux.Unlock() d.notifier.localAddressChanged(fqdn, ip) @@ -1316,6 +1322,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus { } pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP + pbFullStatus.LocalPeerState.Ipv6 = fs.LocalPeerState.IPv6 pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN @@ -1331,6 +1338,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus { pbPeerState := &proto.PeerState{ IP: peerState.IP, + Ipv6: peerState.IPv6, PubKey: peerState.PubKey, ConnStatus: peerState.ConnStatus.String(), ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate), diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go index 272638750..9bafca55a 100644 --- a/client/internal/peer/status_test.go +++ b/client/internal/peer/status_test.go @@ -8,19 +8,20 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAddPeer(t *testing.T) { key := "abc" ip := 
"100.108.254.1" status := NewRecorder("https://mgm") - err := status.AddPeer(key, "abc.netbird", ip) + err := status.AddPeer(key, "abc.netbird", ip, "") assert.NoError(t, err, "shouldn't return error") _, exists := status.peers[key] assert.True(t, exists, "value was found") - err = status.AddPeer(key, "abc.netbird", ip) + err = status.AddPeer(key, "abc.netbird", ip, "") assert.Error(t, err, "should return error on duplicate") } @@ -29,7 +30,7 @@ func TestGetPeer(t *testing.T) { key := "abc" ip := "100.108.254.1" status := NewRecorder("https://mgm") - err := status.AddPeer(key, "abc.netbird", ip) + err := status.AddPeer(key, "abc.netbird", ip, "") assert.NoError(t, err, "shouldn't return error") peerStatus, err := status.GetPeer(key) @@ -46,7 +47,7 @@ func TestUpdatePeerState(t *testing.T) { ip := "10.10.10.10" fqdn := "peer-a.netbird.local" status := NewRecorder("https://mgm") - _ = status.AddPeer(key, fqdn, ip) + require.NoError(t, status.AddPeer(key, fqdn, ip, "")) peerState := State{ PubKey: key, @@ -85,7 +86,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) { key := "abc" ip := "10.10.10.10" status := NewRecorder("https://mgm") - _ = status.AddPeer(key, "abc.netbird", ip) + _ = status.AddPeer(key, "abc.netbird", ip, "") sub := status.SubscribeToPeerStateChanges(context.Background(), key) assert.NotNil(t, sub, "channel shouldn't be nil") diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 20c615d57..cd5bc0680 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/json" "fmt" + "net" "net/url" "os" "os/user" @@ -89,6 +90,7 @@ type ConfigInput struct { DisableFirewall *bool BlockLANAccess *bool BlockInbound *bool + DisableIPv6 *bool DisableNotifications *bool @@ -127,6 +129,7 @@ type Config struct { DisableFirewall bool BlockLANAccess bool BlockInbound bool + DisableIPv6 bool DisableNotifications *bool @@ -542,6 
+545,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } + if input.DisableIPv6 != nil && *input.DisableIPv6 != config.DisableIPv6 { + log.Infof("setting IPv6 overlay disabled=%v", *input.DisableIPv6) + config.DisableIPv6 = *input.DisableIPv6 + updated = true + } + if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications { if *input.DisableNotifications { log.Infof("disabling notifications") @@ -751,8 +760,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri return config, nil } - newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d", - config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443)) + newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s", config.ManagementURL.Scheme, net.JoinHostPort(defaultManagementURL.Hostname(), "443"))) if err != nil { return nil, err } diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 59be5b0a7..f00a8d93a 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "strconv" "sync" "time" @@ -257,7 +258,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri } }() - turnServerAddr := fmt.Sprintf("%s:%d", uri.Host, uri.Port) + turnServerAddr := net.JoinHostPort(uri.Host, strconv.Itoa(uri.Port)) var conn net.PacketConn switch uri.Proto { diff --git a/client/internal/rosenpass/manager.go b/client/internal/rosenpass/manager.go index 1faa22dc5..11cda8dbc 100644 --- a/client/internal/rosenpass/manager.go +++ b/client/internal/rosenpass/manager.go @@ -75,7 +75,7 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar if err != nil { return fmt.Errorf("failed to parse rosenpass address: %w", err) } - peerAddr := fmt.Sprintf("%s:%s", wireGuardIP, strPort) + peerAddr := net.JoinHostPort(wireGuardIP, strPort) if pcfg.Endpoint, err = 
net.ResolveUDPAddr("udp", peerAddr); err != nil { return fmt.Errorf("failed to resolve peer endpoint address: %w", err) } @@ -259,6 +259,9 @@ func findRandomAvailableUDPPort() (int, error) { } defer conn.Close() - splitAddress := strings.Split(conn.LocalAddr().String(), ":") - return strconv.Atoi(splitAddress[len(splitAddress)-1]) + _, portStr, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + return 0, fmt.Errorf("parse local address %s: %w", conn.LocalAddr(), err) + } + return strconv.Atoi(portStr) } diff --git a/client/internal/rosenpass/manager_test.go b/client/internal/rosenpass/manager_test.go new file mode 100644 index 000000000..90bbdda59 --- /dev/null +++ b/client/internal/rosenpass/manager_test.go @@ -0,0 +1,14 @@ +package rosenpass + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFindRandomAvailableUDPPort(t *testing.T) { + port, err := findRandomAvailableUDPPort() + require.NoError(t, err) + require.Greater(t, port, 0) + require.LessOrEqual(t, port, 65535) +} diff --git a/client/internal/routemanager/client/client.go b/client/internal/routemanager/client/client.go index e6ef8b876..c691c54f8 100644 --- a/client/internal/routemanager/client/client.go +++ b/client/internal/routemanager/client/client.go @@ -3,9 +3,8 @@ package client import ( "context" "fmt" - "net" + "net/netip" "reflect" - "strconv" "time" log "github.com/sirupsen/logrus" @@ -566,7 +565,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler { return dnsinterceptor.New(params) case handlerTypeDynamic: dns := nbdns.NewServiceViaMemory(params.WgInterface) - dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort())) + dnsAddr := netip.AddrPortFrom(dns.RuntimeIP(), uint16(dns.RuntimePort())) return dynamic.NewRoute(params, dnsAddr) default: return static.NewRoute(params) diff --git a/client/internal/routemanager/client/client_bench_test.go b/client/internal/routemanager/client/client_bench_test.go 
index 591042ac5..668aec427 100644 --- a/client/internal/routemanager/client/client_bench_test.go +++ b/client/internal/routemanager/client/client_bench_test.go @@ -46,7 +46,7 @@ func generateBenchmarkData(tier benchmarkTier) (*peer.Status, map[route.ID]*rout fqdn := fmt.Sprintf("peer-%d.example.com", i) ip := fmt.Sprintf("10.0.%d.%d", i/256, i%256) - err := statusRecorder.AddPeer(peerKey, fqdn, ip) + err := statusRecorder.AddPeer(peerKey, fqdn, ip, "") if err != nil { panic(fmt.Sprintf("failed to add peer: %v", err)) } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 64f2a8789..e25cc2a5c 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -582,7 +582,7 @@ func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWri if nsNet != nil { reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream) } else { - client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) + client, clientErr := nbdns.GetClientPrivate(d.wgInterface, upstreamIP, dnsTimeout) if clientErr != nil { d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr)) return nil diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 8d1398a7a..f0efd7b22 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -50,10 +50,10 @@ type Route struct { cancel context.CancelFunc statusRecorder *peer.Status wgInterface iface.WGIface - resolverAddr string + resolverAddr netip.AddrPort } -func NewRoute(params common.HandlerParams, resolverAddr string) *Route { +func NewRoute(params common.HandlerParams, resolverAddr netip.AddrPort) *Route { return &Route{ route: params.Route, routeRefCounter: params.RouteRefCounter, diff --git 
a/client/internal/routemanager/dynamic/route_ios.go b/client/internal/routemanager/dynamic/route_ios.go index 8fed1c8f9..1ae281d56 100644 --- a/client/internal/routemanager/dynamic/route_ios.go +++ b/client/internal/routemanager/dynamic/route_ios.go @@ -17,37 +17,47 @@ import ( const dialTimeout = 10 * time.Second func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { - privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout) + privateClient, err := nbdns.GetClientPrivate(r.wgInterface, r.resolverAddr.Addr(), dialTimeout) if err != nil { return nil, fmt.Errorf("error while creating private client: %s", err) } - msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(domain.PunycodeString()), dns.TypeA) - + fqdn := dns.Fqdn(domain.PunycodeString()) startTime := time.Now() - response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr) - if err != nil { - return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err) - } + var ips []net.IP + var queryErr error - if response.Rcode != dns.RcodeSuccess { - return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode]) - } + for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} { + msg := new(dns.Msg) + msg.SetQuestion(fqdn, qtype) - ips := make([]net.IP, 0) - - for _, answ := range response.Answer { - if aRecord, ok := answ.(*dns.A); ok { - ips = append(ips, aRecord.A) + response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr.String()) + if err != nil { + if queryErr == nil { + queryErr = fmt.Errorf("DNS query for %s (type %d) after %s: %w", domain.SafeString(), qtype, time.Since(startTime), err) + } + continue } - if aaaaRecord, ok := answ.(*dns.AAAA); ok { - ips = append(ips, aaaaRecord.AAAA) + + if response.Rcode != dns.RcodeSuccess { + continue + } + + for _, answ := range response.Answer { + if aRecord, ok := answ.(*dns.A); 
ok { + ips = append(ips, aRecord.A) + } + if aaaaRecord, ok := answ.(*dns.AAAA); ok { + ips = append(ips, aaaaRecord.AAAA) + } } } if len(ips) == 0 { + if queryErr != nil { + return nil, queryErr + } return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString()) } diff --git a/client/internal/routemanager/fakeip/fakeip.go b/client/internal/routemanager/fakeip/fakeip.go index 1592045d2..5be4ca12e 100644 --- a/client/internal/routemanager/fakeip/fakeip.go +++ b/client/internal/routemanager/fakeip/fakeip.go @@ -1,93 +1,145 @@ package fakeip import ( + "errors" "fmt" "net/netip" "sync" ) -// Manager manages allocation of fake IPs from the 240.0.0.0/8 block -type Manager struct { - mu sync.Mutex - nextIP netip.Addr // Next IP to allocate +var ( + // 240.0.0.1 - 240.255.255.254, block 240.0.0.0/8 (reserved, RFC 1112) + v4Base = netip.AddrFrom4([4]byte{240, 0, 0, 1}) + v4Max = netip.AddrFrom4([4]byte{240, 255, 255, 254}) + v4Block = netip.PrefixFrom(netip.AddrFrom4([4]byte{240, 0, 0, 0}), 8) + + // 0100::1 - 0100::ffff:ffff:ffff:fffe, block 0100::/64 (discard, RFC 6666) + v6Base = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}) + v6Max = netip.AddrFrom16([16]byte{0x01, 0x00, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}) + v6Block = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x01, 0x00}), 64) +) + +// fakeIPPool holds the allocation state for a single address family. 
+type fakeIPPool struct { + nextIP netip.Addr + baseIP netip.Addr + maxIP netip.Addr + block netip.Prefix allocated map[netip.Addr]netip.Addr // real IP -> fake IP fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP - baseIP netip.Addr // First usable IP: 240.0.0.1 - maxIP netip.Addr // Last usable IP: 240.255.255.254 } -// NewManager creates a new fake IP manager using 240.0.0.0/8 block -func NewManager() *Manager { - baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1}) - maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254}) - - return &Manager{ - nextIP: baseIP, +func newPool(base, maxAddr netip.Addr, block netip.Prefix) *fakeIPPool { + return &fakeIPPool{ + nextIP: base, + baseIP: base, + maxIP: maxAddr, + block: block, allocated: make(map[netip.Addr]netip.Addr), fakeToReal: make(map[netip.Addr]netip.Addr), - baseIP: baseIP, - maxIP: maxIP, } } -// AllocateFakeIP allocates a fake IP for the given real IP -// Returns the fake IP, or existing fake IP if already allocated -func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) { - if !realIP.Is4() { - return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported") - } - - m.mu.Lock() - defer m.mu.Unlock() - - if fakeIP, exists := m.allocated[realIP]; exists { +// allocate allocates a fake IP for the given real IP. +// Returns the existing fake IP if already allocated. 
+func (p *fakeIPPool) allocate(realIP netip.Addr) (netip.Addr, error) { + if fakeIP, exists := p.allocated[realIP]; exists { return fakeIP, nil } - startIP := m.nextIP + startIP := p.nextIP for { - currentIP := m.nextIP + currentIP := p.nextIP // Advance to next IP, wrapping at boundary - if m.nextIP.Compare(m.maxIP) >= 0 { - m.nextIP = m.baseIP + if p.nextIP.Compare(p.maxIP) >= 0 { + p.nextIP = p.baseIP } else { - m.nextIP = m.nextIP.Next() + p.nextIP = p.nextIP.Next() } - // Check if current IP is available - if _, inUse := m.fakeToReal[currentIP]; !inUse { - m.allocated[realIP] = currentIP - m.fakeToReal[currentIP] = realIP + if _, inUse := p.fakeToReal[currentIP]; !inUse { + p.allocated[realIP] = currentIP + p.fakeToReal[currentIP] = realIP return currentIP, nil } - // Prevent infinite loop if all IPs exhausted - if m.nextIP.Compare(startIP) == 0 { - return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block") + if p.nextIP.Compare(startIP) == 0 { + return netip.Addr{}, fmt.Errorf("no more fake IPs available in %s block", p.block) } } } -// GetFakeIP returns the fake IP for a real IP if it exists +// Manager manages allocation of fake IPs for dynamic DNS routes. +// IPv4 uses 240.0.0.0/8 (reserved), IPv6 uses 0100::/64 (discard, RFC 6666). +type Manager struct { + mu sync.Mutex + v4 *fakeIPPool + v6 *fakeIPPool +} + +// NewManager creates a new fake IP manager. +func NewManager() *Manager { + return &Manager{ + v4: newPool(v4Base, v4Max, v4Block), + v6: newPool(v6Base, v6Max, v6Block), + } +} + +func (m *Manager) pool(ip netip.Addr) *fakeIPPool { + if ip.Is6() { + return m.v6 + } + return m.v4 +} + +// AllocateFakeIP allocates a fake IP for the given real IP. 
+func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) { + realIP = realIP.Unmap() + if !realIP.IsValid() { + return netip.Addr{}, errors.New("invalid IP address") + } + + m.mu.Lock() + defer m.mu.Unlock() + + return m.pool(realIP).allocate(realIP) +} + +// GetFakeIP returns the fake IP for a real IP if it exists. func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) { + realIP = realIP.Unmap() + if !realIP.IsValid() { + return netip.Addr{}, false + } + m.mu.Lock() defer m.mu.Unlock() - fakeIP, exists := m.allocated[realIP] - return fakeIP, exists + fakeIP, ok := m.pool(realIP).allocated[realIP] + return fakeIP, ok } -// GetRealIP returns the real IP for a fake IP if it exists, otherwise false +// GetRealIP returns the real IP for a fake IP if it exists. func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) { + fakeIP = fakeIP.Unmap() + if !fakeIP.IsValid() { + return netip.Addr{}, false + } + m.mu.Lock() defer m.mu.Unlock() - realIP, exists := m.fakeToReal[fakeIP] - return realIP, exists + realIP, ok := m.pool(fakeIP).fakeToReal[fakeIP] + return realIP, ok } -// GetFakeIPBlock returns the fake IP block used by this manager +// GetFakeIPBlock returns the v4 fake IP block used by this manager. func (m *Manager) GetFakeIPBlock() netip.Prefix { - return netip.MustParsePrefix("240.0.0.0/8") + return m.v4.block +} + +// GetFakeIPv6Block returns the v6 fake IP block used by this manager. 
+func (m *Manager) GetFakeIPv6Block() netip.Prefix { + return m.v6.block } diff --git a/client/internal/routemanager/fakeip/fakeip_test.go b/client/internal/routemanager/fakeip/fakeip_test.go index ad3e4bd4e..f554f970d 100644 --- a/client/internal/routemanager/fakeip/fakeip_test.go +++ b/client/internal/routemanager/fakeip/fakeip_test.go @@ -9,16 +9,16 @@ import ( func TestNewManager(t *testing.T) { manager := NewManager() - if manager.baseIP.String() != "240.0.0.1" { - t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String()) + if manager.v4.baseIP.String() != "240.0.0.1" { + t.Errorf("Expected v4 base IP 240.0.0.1, got %s", manager.v4.baseIP.String()) } - if manager.maxIP.String() != "240.255.255.254" { - t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String()) + if manager.v4.maxIP.String() != "240.255.255.254" { + t.Errorf("Expected v4 max IP 240.255.255.254, got %s", manager.v4.maxIP.String()) } - if manager.nextIP.Compare(manager.baseIP) != 0 { - t.Errorf("Expected nextIP to start at baseIP") + if manager.v6.baseIP.String() != "100::1" { + t.Errorf("Expected v6 base IP 100::1, got %s", manager.v6.baseIP.String()) } } @@ -35,7 +35,6 @@ func TestAllocateFakeIP(t *testing.T) { t.Error("Fake IP should be IPv4") } - // Check it's in the correct range if fakeIP.As4()[0] != 240 { t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String()) } @@ -51,13 +50,31 @@ func TestAllocateFakeIP(t *testing.T) { } } -func TestAllocateFakeIPIPv6Rejection(t *testing.T) { +func TestAllocateFakeIPv6(t *testing.T) { manager := NewManager() - realIPv6 := netip.MustParseAddr("2001:db8::1") + realIP := netip.MustParseAddr("2001:db8::1") - _, err := manager.AllocateFakeIP(realIPv6) - if err == nil { - t.Error("Expected error for IPv6 address") + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate fake IPv6: %v", err) + } + + if !fakeIP.Is6() { + t.Error("Fake IP should be IPv6") + } + + if 
!netip.MustParsePrefix("100::/64").Contains(fakeIP) { + t.Errorf("Fake IP should be in 100::/64 range, got %s", fakeIP.String()) + } + + // Should return same fake IP for same real IP + fakeIP2, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to get existing fake IPv6: %v", err) + } + + if fakeIP.Compare(fakeIP2) != 0 { + t.Errorf("Expected same fake IP, got %s and %s", fakeIP.String(), fakeIP2.String()) } } @@ -65,13 +82,11 @@ func TestGetFakeIP(t *testing.T) { manager := NewManager() realIP := netip.MustParseAddr("1.1.1.1") - // Should not exist initially _, exists := manager.GetFakeIP(realIP) if exists { t.Error("Fake IP should not exist before allocation") } - // Allocate and check expectedFakeIP, err := manager.AllocateFakeIP(realIP) if err != nil { t.Fatalf("Failed to allocate: %v", err) @@ -87,12 +102,30 @@ func TestGetFakeIP(t *testing.T) { } } +func TestGetRealIPv6(t *testing.T) { + manager := NewManager() + realIP := netip.MustParseAddr("2001:db8::1") + + fakeIP, err := manager.AllocateFakeIP(realIP) + if err != nil { + t.Fatalf("Failed to allocate: %v", err) + } + + gotReal, exists := manager.GetRealIP(fakeIP) + if !exists { + t.Error("Real IP should exist for allocated fake IP") + } + + if gotReal.Compare(realIP) != 0 { + t.Errorf("Expected real IP %s, got %s", realIP, gotReal) + } +} + func TestMultipleAllocations(t *testing.T) { manager := NewManager() allocations := make(map[netip.Addr]netip.Addr) - // Allocate multiple IPs for i := 1; i <= 100; i++ { realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)}) fakeIP, err := manager.AllocateFakeIP(realIP) @@ -100,7 +133,6 @@ func TestMultipleAllocations(t *testing.T) { t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err) } - // Check for duplicates for _, existingFake := range allocations { if fakeIP.Compare(existingFake) == 0 { t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String()) @@ -110,7 +142,6 @@ func TestMultipleAllocations(t 
*testing.T) { allocations[realIP] = fakeIP } - // Verify all allocations can be retrieved for realIP, expectedFake := range allocations { actualFake, exists := manager.GetFakeIP(realIP) if !exists { @@ -124,11 +155,13 @@ func TestMultipleAllocations(t *testing.T) { func TestGetFakeIPBlock(t *testing.T) { manager := NewManager() - block := manager.GetFakeIPBlock() - expected := "240.0.0.0/8" - if block.String() != expected { - t.Errorf("Expected %s, got %s", expected, block.String()) + if block := manager.GetFakeIPBlock(); block.String() != "240.0.0.0/8" { + t.Errorf("Expected 240.0.0.0/8, got %s", block.String()) + } + + if block := manager.GetFakeIPv6Block(); block.String() != "100::/64" { + t.Errorf("Expected 100::/64, got %s", block.String()) } } @@ -141,7 +174,6 @@ func TestConcurrentAccess(t *testing.T) { var wg sync.WaitGroup results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine) - // Concurrent allocations for i := 0; i < numGoroutines; i++ { wg.Add(1) go func(goroutineID int) { @@ -161,7 +193,6 @@ func TestConcurrentAccess(t *testing.T) { wg.Wait() close(results) - // Check for duplicates seen := make(map[netip.Addr]bool) count := 0 for fakeIP := range results { @@ -178,47 +209,61 @@ func TestConcurrentAccess(t *testing.T) { } func TestIPExhaustion(t *testing.T) { - // Create a manager with limited range for testing manager := &Manager{ - nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), - allocated: make(map[netip.Addr]netip.Addr), - fakeToReal: make(map[netip.Addr]netip.Addr), - baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), - maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available + v4: newPool( + netip.AddrFrom4([4]byte{240, 0, 0, 1}), + netip.AddrFrom4([4]byte{240, 0, 0, 3}), + netip.MustParsePrefix("240.0.0.0/8"), + ), + v6: newPool( + netip.MustParseAddr("100::1"), + netip.MustParseAddr("100::3"), + netip.MustParsePrefix("100::/64"), + ), } - // Allocate all available IPs - realIPs := []netip.Addr{ - 
netip.MustParseAddr("1.0.0.1"), - netip.MustParseAddr("1.0.0.2"), - netip.MustParseAddr("1.0.0.3"), - } - - for _, realIP := range realIPs { - _, err := manager.AllocateFakeIP(realIP) + for _, realIP := range []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"} { + _, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP)) if err != nil { t.Fatalf("Failed to allocate fake IP: %v", err) } } - // Try to allocate one more - should fail _, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4")) if err == nil { - t.Error("Expected exhaustion error") + t.Error("Expected v4 exhaustion error") + } + + // Same for v6 + for _, realIP := range []string{"2001:db8::1", "2001:db8::2", "2001:db8::3"} { + _, err := manager.AllocateFakeIP(netip.MustParseAddr(realIP)) + if err != nil { + t.Fatalf("Failed to allocate fake IPv6: %v", err) + } + } + + _, err = manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::4")) + if err == nil { + t.Error("Expected v6 exhaustion error") } } func TestWrapAround(t *testing.T) { - // Create manager starting near the end of range manager := &Manager{ - nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), - allocated: make(map[netip.Addr]netip.Addr), - fakeToReal: make(map[netip.Addr]netip.Addr), - baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}), - maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}), + v4: newPool( + netip.AddrFrom4([4]byte{240, 0, 0, 1}), + netip.AddrFrom4([4]byte{240, 0, 0, 254}), + netip.MustParsePrefix("240.0.0.0/8"), + ), + v6: newPool( + netip.MustParseAddr("100::1"), + netip.MustParseAddr("100::ffff:ffff:ffff:fffe"), + netip.MustParsePrefix("100::/64"), + ), } + // Start near the end + manager.v4.nextIP = netip.AddrFrom4([4]byte{240, 0, 0, 254}) - // Allocate the last IP fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1")) if err != nil { t.Fatalf("Failed to allocate first IP: %v", err) @@ -228,7 +273,6 @@ func TestWrapAround(t *testing.T) { t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String()) } - // Next 
allocation should wrap around to the beginning fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2")) if err != nil { t.Fatalf("Failed to allocate second IP: %v", err) @@ -238,3 +282,32 @@ func TestWrapAround(t *testing.T) { t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String()) } } + +func TestMixedV4V6(t *testing.T) { + manager := NewManager() + + v4Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("8.8.8.8")) + if err != nil { + t.Fatalf("Failed to allocate v4: %v", err) + } + + v6Fake, err := manager.AllocateFakeIP(netip.MustParseAddr("2001:db8::1")) + if err != nil { + t.Fatalf("Failed to allocate v6: %v", err) + } + + if !v4Fake.Is4() || !v6Fake.Is6() { + t.Errorf("Wrong families: v4=%s v6=%s", v4Fake, v6Fake) + } + + // Reverse lookups should work for both + gotV4, ok := manager.GetRealIP(v4Fake) + if !ok || gotV4.String() != "8.8.8.8" { + t.Errorf("v4 reverse lookup failed: got %s, ok=%v", gotV4, ok) + } + + gotV6, ok := manager.GetRealIP(v6Fake) + if !ok || gotV6.String() != "2001:db8::1" { + t.Errorf("v6 reverse lookup failed: got %s, ok=%v", gotV6, ok) + } +} diff --git a/client/internal/routemanager/ipfwdstate/ipfwdstate.go b/client/internal/routemanager/ipfwdstate/ipfwdstate.go index da81c18f9..2be1c2ae7 100644 --- a/client/internal/routemanager/ipfwdstate/ipfwdstate.go +++ b/client/internal/routemanager/ipfwdstate/ipfwdstate.go @@ -9,7 +9,11 @@ import ( ) // IPForwardingState is a struct that keeps track of the IP forwarding state. -// todo: read initial state of the IP forwarding from the system and reset the state based on it +// todo: read initial state of the IP forwarding from the system and reset the state based on it. +// todo: separate v4/v6 forwarding state, since the sysctls are independent +// (net.ipv4.ip_forward vs net.ipv6.conf.all.forwarding). 
Currently the nftables +// manager shares one instance between both routers, which works only because +// EnableIPForwarding enables both sysctls in a single call. type IPForwardingState struct { enabledCounter int } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 3923e153b..e5d9363ca 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -159,16 +159,24 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) { if config.DNSFeatureFlag { m.fakeIPManager = fakeip.NewManager() - id := uuid.NewString() + v4ID := uuid.NewString() fakeIPRoute := &route.Route{ - ID: route.ID(id), + ID: route.ID(v4ID), Network: m.fakeIPManager.GetFakeIPBlock(), - NetID: route.NetID(id), + NetID: route.NetID(v4ID), Peer: m.pubKey, NetworkType: route.IPv4Network, } - cr = append(cr, fakeIPRoute) - m.notifier.SetFakeIPRoute(fakeIPRoute) + v6ID := uuid.NewString() + fakeIPv6Route := &route.Route{ + ID: route.ID(v6ID), + Network: m.fakeIPManager.GetFakeIPv6Block(), + NetID: route.NetID(v6ID), + Peer: m.pubKey, + NetworkType: route.IPv6Network, + } + cr = append(cr, fakeIPRoute, fakeIPv6Route) + m.notifier.SetFakeIPRoutes([]*route.Route{fakeIPRoute, fakeIPv6Route}) } m.notifier.SetInitialClientRoutes(cr, routesForComparison) diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 3697545ae..926f06bc9 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/route" ) @@ -409,7 +410,7 @@ func TestManagerUpdateRoutes(t *testing.T) { } opts := iface.WGIFaceOpts{ IFaceName: fmt.Sprintf("utun43%d", n), - Address: "100.65.65.2/24", + Address: 
wgaddr.MustParseWGAddress("100.65.65.2/24"), WGPort: 33100, WGPrivKey: peerPrivateKey.String(), MTU: iface.DefaultMTU, diff --git a/client/internal/routemanager/notifier/notifier_android.go b/client/internal/routemanager/notifier/notifier_android.go index 55e0b7421..140a583f7 100644 --- a/client/internal/routemanager/notifier/notifier_android.go +++ b/client/internal/routemanager/notifier/notifier_android.go @@ -16,7 +16,7 @@ import ( type Notifier struct { initialRoutes []*route.Route currentRoutes []*route.Route - fakeIPRoute *route.Route + fakeIPRoutes []*route.Route listener listener.NetworkChangeListener listenerMux sync.Mutex @@ -38,9 +38,9 @@ func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesFo n.currentRoutes = filterStatic(routesForComparison) } -// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild. -func (n *Notifier) SetFakeIPRoute(r *route.Route) { - n.fakeIPRoute = r +// SetFakeIPRoutes stores the fake IP routes to be included in every TUN rebuild. +func (n *Notifier) SetFakeIPRoutes(routes []*route.Route) { + n.fakeIPRoutes = routes } func (n *Notifier) OnNewRoutes(idMap route.HAMap) { @@ -74,14 +74,12 @@ func (n *Notifier) notify() { } allRoutes := slices.Clone(n.currentRoutes) - if n.fakeIPRoute != nil { - allRoutes = append(allRoutes, n.fakeIPRoute) - } + allRoutes = append(allRoutes, n.fakeIPRoutes...) 
routeStrings := n.routesToStrings(allRoutes) sort.Strings(routeStrings) go func(l listener.NetworkChangeListener) { - l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, allRoutes), ",")) + l.OnNetworkChanged(strings.Join(routeStrings, ",")) }(n.listener) } @@ -119,14 +117,5 @@ func (n *Notifier) hasRouteDiff(a []*route.Route, b []*route.Route) bool { func (n *Notifier) GetInitialRouteRanges() []string { initialStrings := n.routesToStrings(n.initialRoutes) sort.Strings(initialStrings) - return n.addIPv6RangeIfNeeded(initialStrings, n.initialRoutes) -} - -func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string, routes []*route.Route) []string { - for _, r := range routes { - if r.Network.Addr().Is4() && r.Network.Bits() == 0 { - return append(slices.Clone(inputRanges), "::/0") - } - } - return inputRanges + return initialStrings } diff --git a/client/internal/routemanager/notifier/notifier_ios.go b/client/internal/routemanager/notifier/notifier_ios.go index 68c85067a..27a2a722d 100644 --- a/client/internal/routemanager/notifier/notifier_ios.go +++ b/client/internal/routemanager/notifier/notifier_ios.go @@ -34,7 +34,7 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) { // iOS doesn't care about initial routes } -func (n *Notifier) SetFakeIPRoute(*route.Route) { +func (n *Notifier) SetFakeIPRoutes([]*route.Route) { // Not used on iOS } @@ -65,19 +65,10 @@ func (n *Notifier) notify() { } go func(l listener.NetworkChangeListener) { - l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(n.currentPrefixes), ",")) + l.OnNetworkChanged(strings.Join(n.currentPrefixes, ",")) }(n.listener) } func (n *Notifier) GetInitialRouteRanges() []string { return nil } - -func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string) []string { - for _, r := range inputRanges { - if r == "0.0.0.0/0" { - return append(slices.Clone(inputRanges), "::/0") - } - } - return inputRanges -} diff --git 
a/client/internal/routemanager/notifier/notifier_other.go b/client/internal/routemanager/notifier/notifier_other.go index 97c815cf0..f57cadb0b 100644 --- a/client/internal/routemanager/notifier/notifier_other.go +++ b/client/internal/routemanager/notifier/notifier_other.go @@ -23,7 +23,7 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) { // Not used on non-mobile platforms } -func (n *Notifier) SetFakeIPRoute(*route.Route) { +func (n *Notifier) SetFakeIPRoutes([]*route.Route) { // Not used on non-mobile platforms } diff --git a/client/internal/routemanager/server/server.go b/client/internal/routemanager/server/server.go index e674c80cd..f569c0cac 100644 --- a/client/internal/routemanager/server/server.go +++ b/client/internal/routemanager/server/server.go @@ -21,6 +21,7 @@ type Router struct { firewall firewall.Manager wgInterface iface.WGIface statusRecorder *peer.Status + useNewDNSRoute bool } func NewRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*Router, error) { @@ -37,6 +38,8 @@ func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRout r.mux.Lock() defer r.mux.Unlock() + prevUseNewDNSRoute := r.useNewDNSRoute + serverRoutesToRemove := make([]route.ID, 0) for routeID := range r.routes { @@ -48,7 +51,7 @@ func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRout for _, routeID := range serverRoutesToRemove { oldRoute := r.routes[routeID] - err := r.removeFromServerNetwork(oldRoute) + err := r.removeFromServerNetwork(oldRoute, prevUseNewDNSRoute) if err != nil { log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v", oldRoute.ID, oldRoute.Network, err) @@ -56,6 +59,8 @@ func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRout delete(r.routes, routeID) } + r.useNewDNSRoute = useNewDNSRoute + // If routing is to be disabled, do it after routes have been removed // If routing is to 
be enabled, do it before adding new routes; addToServerNetwork needs routing to be enabled if len(routesMap) > 0 { @@ -85,13 +90,13 @@ func (r *Router) UpdateRoutes(routesMap map[route.ID]*route.Route, useNewDNSRout return nil } -func (r *Router) removeFromServerNetwork(route *route.Route) error { +func (r *Router) removeFromServerNetwork(route *route.Route, useNewDNSRoute bool) error { if r.ctx.Err() != nil { log.Infof("Not removing from server network because context is done") return r.ctx.Err() } - routerPair := routeToRouterPair(route, false) + routerPair := routeToRouterPair(route, useNewDNSRoute) if err := r.firewall.RemoveNatRule(routerPair); err != nil { return fmt.Errorf("remove routing rules: %w", err) } @@ -124,7 +129,7 @@ func (r *Router) CleanUp() { defer r.mux.Unlock() for _, route := range r.routes { - routerPair := routeToRouterPair(route, false) + routerPair := routeToRouterPair(route, r.useNewDNSRoute) if err := r.firewall.RemoveNatRule(routerPair); err != nil { log.Errorf("Failed to remove cleanup route: %v", err) } @@ -146,8 +151,7 @@ func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterP if useNewDNSRoute { destination.Set = firewall.NewDomainSet(route.Domains) } else { - // TODO: add ipv6 additionally - destination = getDefaultPrefix(destination.Prefix) + destination = getDefaultPrefix(route.Network) } } else { destination.Prefix = route.Network.Masked() @@ -158,6 +162,7 @@ func routeToRouterPair(route *route.Route, useNewDNSRoute bool) firewall.RouterP Source: source, Destination: destination, Masquerade: route.Masquerade, + Dynamic: route.IsDynamic(), } } diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index c0ca21d22..165448b60 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -107,8 +107,16 @@ func (r *SysOps) validateRoute(prefix netip.Prefix) error { 
addr.IsInterfaceLocalMulticast(), addr.IsMulticast(), addr.IsUnspecified() && prefix.Bits() != 0, - r.wgInterface.Address().Network.Contains(addr): + r.isOwnAddress(addr): return vars.ErrRouteNotAllowed } return nil } + +func (r *SysOps) isOwnAddress(addr netip.Addr) bool { + if r.wgInterface == nil { + return false + } + wgAddr := r.wgInterface.Address() + return wgAddr.Network.Contains(addr) || (wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr)) +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index bf7b95a28..2b96c14dc 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -221,30 +221,20 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er return err } - // TODO: remove once IPv6 is supported on the interface - if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { - if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) + // When the interface has no v6, add v6 split-default as blackhole so + // unroutable v6 goes to WG (dropped, no AllowedIPs) instead of leaking + // to the system default route. When v6 is active, management sends ::/0 + // as a separate route that the dedicated handler adds. + // Soft-fail: v6 blackhole is best-effort, don't abort v4 routing on failure. 
+ if !r.wgInterface.Address().HasIPv6() { + if err := r.addV6SplitDefault(nextHop); err != nil { + log.Warnf("failed to add v6 split-default blackhole: %s", err) } - return fmt.Errorf("add unreachable route split 2: %w", err) } return nil case vars.Defaultv6: - if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { - if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } - - return nil + return r.addV6SplitDefault(nextHop) } return r.addToRouteTable(prefix, nextHop) @@ -265,30 +255,42 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) result = multierror.Append(result, err) } - // TODO: remove once IPv6 is supported on the interface - if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { - result = multierror.Append(result, err) - } - if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { - result = multierror.Append(result, err) + if !r.wgInterface.Address().HasIPv6() { + result = multierror.Append(result, r.removeV6SplitDefault(nextHop)) } return nberrors.FormatErrorOrNil(result) case vars.Defaultv6: - var result *multierror.Error - if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { - result = multierror.Append(result, err) - } - if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { - result = multierror.Append(result, err) - } - - return nberrors.FormatErrorOrNil(result) + return nberrors.FormatErrorOrNil(r.removeV6SplitDefault(nextHop)) default: return r.removeFromRouteTable(prefix, nextHop) } } +func (r *SysOps) addV6SplitDefault(nextHop Nexthop) error { + if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil { + return fmt.Errorf("add 
split 1: %w", err) + } + if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil { + if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil { + log.Warnf("Failed to rollback v6 split-default: %s", err2) + } + return fmt.Errorf("add split 2: %w", err) + } + return nil +} + +func (r *SysOps) removeV6SplitDefault(nextHop Nexthop) *multierror.Error { + var result *multierror.Error + if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil { + result = multierror.Append(result, err) + } + if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil { + result = multierror.Append(result, err) + } + return result +} + func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 08e354a78..5695c40c3 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -21,6 +21,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/routemanager/vars" nbnet "github.com/netbirdio/netbird/client/net" ) @@ -445,7 +446,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen opts := iface.WGIFaceOpts{ IFaceName: interfaceName, - Address: ipAddressCIDR, + Address: wgaddr.MustParseWGAddress(ipAddressCIDR), WGPrivKey: peerPrivateKey.String(), WGPort: listenPort, MTU: iface.DefaultMTU, diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 
39a9fd978..8c6b7d9a9 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -53,6 +53,8 @@ const ( // ipv4ForwardingPath is the path to the file containing the IP forwarding setting. ipv4ForwardingPath = "net.ipv4.ip_forward" + // ipv6ForwardingPath is the path to the file containing the IPv6 forwarding setting. + ipv6ForwardingPath = "net.ipv6.conf.all.forwarding" ) var ErrTableIDExists = errors.New("ID exists with different name") @@ -185,10 +187,11 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 - // TODO remove this once we have ipv6 support - if prefix == vars.Defaultv4 { + // When the peer has no IPv6, blackhole v6 to prevent leaking. + // When IPv6 is enabled, management sends ::/0 as a separate route. + if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) { if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("add blackhole: %w", err) + return fmt.Errorf("add v6 blackhole: %w", err) } } if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { @@ -206,10 +209,9 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error return r.genericRemoveVPNRoute(prefix, intf) } - // TODO remove this once we have ipv6 support - if prefix == vars.Defaultv4 { + if prefix == vars.Defaultv4 && (r.wgInterface == nil || !r.wgInterface.Address().HasIPv6()) { if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil { - return fmt.Errorf("remove unreachable route: %w", err) + log.Debugf("remove v6 blackhole: %v", err) } } if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil { @@ -762,8 +764,13 @@ func flushRoutes(tableID, family int) error { } func EnableIPForwarding() 
error { - _, err := sysctl.Set(ipv4ForwardingPath, 1, false) - return err + if _, err := sysctl.Set(ipv4ForwardingPath, 1, false); err != nil { + return err + } + if _, err := sysctl.Set(ipv6ForwardingPath, 1, false); err != nil { + log.Warnf("failed to enable IPv6 forwarding: %v", err) + } + return nil } // entryExists checks if the specified ID or name already exists in the rt_tables file diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index a616f9533..33f5ab1b0 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -50,10 +50,11 @@ type CustomLogger interface { } type selectRoute struct { - NetID string - Network netip.Prefix - Domains domain.List - Selected bool + NetID string + Network netip.Prefix + Domains domain.List + Selected bool + extraNetworks []netip.Prefix } func init() { @@ -198,6 +199,7 @@ func (c *Client) GetStatusDetails() *StatusDetails { } pi := PeerInfo{ IP: p.IP, + IPv6: p.IPv6, FQDN: p.FQDN, LocalIceCandidateEndpoint: p.LocalIceCandidateEndpoint, RemoteIceCandidateEndpoint: p.RemoteIceCandidateEndpoint, @@ -216,7 +218,7 @@ func (c *Client) GetStatusDetails() *StatusDetails { } peerInfos[n] = pi } - return &StatusDetails{items: peerInfos, fqdn: fullStatus.LocalPeerState.FQDN, ip: fullStatus.LocalPeerState.IP} + return &StatusDetails{items: peerInfos, fqdn: fullStatus.LocalPeerState.FQDN, ip: fullStatus.LocalPeerState.IP, ipv6: fullStatus.LocalPeerState.IPv6} } // SetConnectionListener set the network connection listener @@ -366,48 +368,60 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { } routeManager := engine.GetRouteManager() - routesMap := routeManager.GetClientRoutesWithNetID() if routeManager == nil { return nil, fmt.Errorf("could not get route manager") } + routesMap := routeManager.GetClientRoutesWithNetID() routeSelector := routeManager.GetRouteSelector() if routeSelector == nil { return nil, fmt.Errorf("could not get route selector") } + 
v6ExitMerged := route.V6ExitMergeSet(routesMap) + routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged) + resolvedDomains := c.recorder.GetResolvedDomainsStates() + + return prepareRouteSelectionDetails(routes, resolvedDomains), nil +} + +func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute { var routes []*selectRoute for id, rt := range routesMap { if len(rt) == 0 { continue } - route := &selectRoute{ + if _, ok := v6Merged[id]; ok { + continue + } + + r := &selectRoute{ NetID: string(id), Network: rt[0].Network, Domains: rt[0].Domains, - Selected: routeSelector.IsSelected(id), + Selected: isSelected(id), } - routes = append(routes, route) + + v6ID := route.NetID(string(id) + route.V6ExitSuffix) + if _, ok := v6Merged[v6ID]; ok { + r.extraNetworks = []netip.Prefix{routesMap[v6ID][0].Network} + } + + routes = append(routes, r) } sort.Slice(routes, func(i, j int) bool { - iPrefix := routes[i].Network.Bits() - jPrefix := routes[j].Network.Bits() - - if iPrefix == jPrefix { - iAddr := routes[i].Network.Addr() - jAddr := routes[j].Network.Addr() - if iAddr == jAddr { - return routes[i].NetID < routes[j].NetID - } - return iAddr.String() < jAddr.String() + iBits, jBits := routes[i].Network.Bits(), routes[j].Network.Bits() + if iBits != jBits { + return iBits < jBits } - return iPrefix < jPrefix + iAddr, jAddr := routes[i].Network.Addr(), routes[j].Network.Addr() + if iAddr != jAddr { + return iAddr.Less(jAddr) + } + return routes[i].NetID < routes[j].NetID }) - resolvedDomains := c.recorder.GetResolvedDomainsStates() - - return prepareRouteSelectionDetails(routes, resolvedDomains), nil - + return routes } func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails { @@ -443,6 +457,9 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom if 
len(r.Domains) > 0 { netStr = r.Domains.SafeString() } + for _, extra := range r.extraNetworks { + netStr += ", " + extra.String() + } routeSelection = append(routeSelection, RoutesSelectionInfo{ ID: r.NetID, @@ -474,7 +491,9 @@ func (c *Client) SelectRoute(id string) error { } else { log.Debugf("select route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routesMap)); err != nil { log.Debugf("error when selecting routes: %s", err) return fmt.Errorf("select routes: %w", err) } @@ -501,7 +520,9 @@ func (c *Client) DeselectRoute(id string) error { } else { log.Debugf("deselect route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + if err := routeSelector.DeselectRoutes(routes, maps.Keys(routesMap)); err != nil { log.Debugf("error when deselecting routes: %s", err) return fmt.Errorf("deselect routes: %w", err) } diff --git a/client/ios/NetBirdSDK/peer_notifier.go b/client/ios/NetBirdSDK/peer_notifier.go index 9b00568be..025cd94cd 100644 --- a/client/ios/NetBirdSDK/peer_notifier.go +++ b/client/ios/NetBirdSDK/peer_notifier.go @@ -5,6 +5,7 @@ package NetBirdSDK // PeerInfo describe information about the peers. 
It designed for the UI usage type PeerInfo struct { IP string + IPv6 string FQDN string LocalIceCandidateEndpoint string RemoteIceCandidateEndpoint string @@ -23,6 +24,11 @@ type PeerInfo struct { Routes RoutesDetails } +// GetIPv6 returns the IPv6 address of the peer +func (p PeerInfo) GetIPv6() string { + return p.IPv6 +} + // GetRoutes return with RouteDetails func (p PeerInfo) GetRouteDetails() *RoutesDetails { return &p.Routes @@ -57,6 +63,7 @@ type StatusDetails struct { items []PeerInfo fqdn string ip string + ipv6 string } // Add new PeerInfo to the collection @@ -100,3 +107,8 @@ func (array StatusDetails) GetFQDN() string { func (array StatusDetails) GetIP() string { return array.ip } + +// GetIPv6 return with the IPv6 of the local peer +func (array StatusDetails) GetIPv6() string { + return array.ipv6 +} diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go index c26a6decd..ed49ccddb 100644 --- a/client/ios/NetBirdSDK/preferences.go +++ b/client/ios/NetBirdSDK/preferences.go @@ -110,6 +110,24 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) { return cfg.RosenpassPermissive, err } +// GetDisableIPv6 reads disable IPv6 setting from config file +func (p *Preferences) GetDisableIPv6() (bool, error) { + if p.configInput.DisableIPv6 != nil { + return *p.configInput.DisableIPv6, nil + } + + cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return false, err + } + return cfg.DisableIPv6, err +} + +// SetDisableIPv6 stores the given value and waits for commit +func (p *Preferences) SetDisableIPv6(disable bool) { + p.configInput.DisableIPv6 = &disable +} + // Commit write out the changes into config file func (p *Preferences) Commit() error { // Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename) diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 11e7877f2..2c054c99a 100644 --- a/client/proto/daemon.pb.go +++ 
b/client/proto/daemon.pb.go @@ -342,6 +342,7 @@ type LoginRequest struct { EnableSSHRemotePortForwarding *bool `protobuf:"varint,37,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"` DisableSSHAuth *bool `protobuf:"varint,38,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"` SshJWTCacheTTL *int32 `protobuf:"varint,39,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"` + DisableIpv6 *bool `protobuf:"varint,40,opt,name=disable_ipv6,json=disableIpv6,proto3,oneof" json:"disable_ipv6,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -650,6 +651,13 @@ func (x *LoginRequest) GetSshJWTCacheTTL() int32 { return 0 } +func (x *LoginRequest) GetDisableIpv6() bool { + if x != nil && x.DisableIpv6 != nil { + return *x.DisableIpv6 + } + return false +} + type LoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` @@ -1182,6 +1190,7 @@ type GetConfigResponse struct { EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"` DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"` SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"` + DisableIpv6 bool `protobuf:"varint,27,opt,name=disable_ipv6,json=disableIpv6,proto3" json:"disable_ipv6,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1398,6 +1407,13 @@ func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 { return 0 } +func (x *GetConfigResponse) GetDisableIpv6() bool { + if x != nil { + return x.DisableIpv6 + } + return false +} + // PeerState contains the latest state of a peer type PeerState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1419,6 
+1435,7 @@ type PeerState struct { Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"` SshHostKey []byte `protobuf:"bytes,19,opt,name=sshHostKey,proto3" json:"sshHostKey,omitempty"` + Ipv6 string `protobuf:"bytes,20,opt,name=ipv6,proto3" json:"ipv6,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1579,6 +1596,13 @@ func (x *PeerState) GetSshHostKey() []byte { return nil } +func (x *PeerState) GetIpv6() string { + if x != nil { + return x.Ipv6 + } + return "" +} + // LocalPeerState contains the latest state of the local peer type LocalPeerState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1589,6 +1613,7 @@ type LocalPeerState struct { RosenpassEnabled bool `protobuf:"varint,5,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` RosenpassPermissive bool `protobuf:"varint,6,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` Networks []string `protobuf:"bytes,7,rep,name=networks,proto3" json:"networks,omitempty"` + Ipv6 string `protobuf:"bytes,8,opt,name=ipv6,proto3" json:"ipv6,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1672,6 +1697,13 @@ func (x *LocalPeerState) GetNetworks() []string { return nil } +func (x *LocalPeerState) GetIpv6() string { + if x != nil { + return x.Ipv6 + } + return "" +} + // SignalState contains the latest state of a signal connection type SignalState struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -4009,6 +4041,7 @@ type SetConfigRequest struct { EnableSSHRemotePortForwarding *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"` DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"` SshJWTCacheTTL *int32 
`protobuf:"varint,34,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"` + DisableIpv6 *bool `protobuf:"varint,35,opt,name=disable_ipv6,json=disableIpv6,proto3,oneof" json:"disable_ipv6,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -4281,6 +4314,13 @@ func (x *SetConfigRequest) GetSshJWTCacheTTL() int32 { return 0 } +func (x *SetConfigRequest) GetDisableIpv6() bool { + if x != nil && x.DisableIpv6 != nil { + return *x.DisableIpv6 + } + return false +} + type SetConfigResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -6186,7 +6226,7 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xb6\x12\n" + + "\fEmptyRequest\"\xef\x12\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -6230,7 +6270,8 @@ const file_daemon_proto_rawDesc = "" + "\x1cenableSSHLocalPortForwarding\x18$ \x01(\bH\x17R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" + "\x1denableSSHRemotePortForwarding\x18% \x01(\bH\x18R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" + "\x0edisableSSHAuth\x18& \x01(\bH\x19R\x0edisableSSHAuth\x88\x01\x01\x12+\n" + - "\x0esshJWTCacheTTL\x18' \x01(\x05H\x1aR\x0esshJWTCacheTTL\x88\x01\x01B\x13\n" + + "\x0esshJWTCacheTTL\x18' \x01(\x05H\x1aR\x0esshJWTCacheTTL\x88\x01\x01\x12&\n" + + "\fdisable_ipv6\x18( \x01(\bH\x1bR\vdisableIpv6\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -6257,7 +6298,8 @@ const file_daemon_proto_rawDesc = "" + "\x1d_enableSSHLocalPortForwardingB \n" + "\x1e_enableSSHRemotePortForwardingB\x11\n" + "\x0f_disableSSHAuthB\x11\n" + - 
"\x0f_sshJWTCacheTTL\"\xb5\x01\n" + + "\x0f_sshJWTCacheTTLB\x0f\n" + + "\r_disable_ipv6\"\xb5\x01\n" + "\rLoginResponse\x12$\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + @@ -6290,7 +6332,7 @@ const file_daemon_proto_rawDesc = "" + "\fDownResponse\"P\n" + "\x10GetConfigRequest\x12 \n" + "\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" + - "\busername\x18\x02 \x01(\tR\busername\"\xdb\b\n" + + "\busername\x18\x02 \x01(\tR\busername\"\xfe\b\n" + "\x11GetConfigResponse\x12$\n" + "\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" + "\n" + @@ -6321,7 +6363,8 @@ const file_daemon_proto_rawDesc = "" + "\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" + "\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" + "\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\x12&\n" + - "\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\"\xfe\x05\n" + + "\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\x12!\n" + + "\fdisable_ipv6\x18\x1b \x01(\bR\vdisableIpv6\"\x92\x06\n" + "\tPeerState\x12\x0e\n" + "\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" + "\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" + @@ -6345,7 +6388,8 @@ const file_daemon_proto_rawDesc = "" + "\frelayAddress\x18\x12 \x01(\tR\frelayAddress\x12\x1e\n" + "\n" + "sshHostKey\x18\x13 \x01(\fR\n" + - "sshHostKey\"\xf0\x01\n" + + "sshHostKey\x12\x12\n" + + "\x04ipv6\x18\x14 \x01(\tR\x04ipv6\"\x84\x02\n" + "\x0eLocalPeerState\x12\x0e\n" + "\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" + "\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12(\n" + @@ -6353,7 +6397,8 @@ const file_daemon_proto_rawDesc = "" + "\x04fqdn\x18\x04 \x01(\tR\x04fqdn\x12*\n" + "\x10rosenpassEnabled\x18\x05 \x01(\bR\x10rosenpassEnabled\x120\n" + "\x13rosenpassPermissive\x18\x06 \x01(\bR\x13rosenpassPermissive\x12\x1a\n" + - "\bnetworks\x18\a \x03(\tR\bnetworks\"S\n" + + "\bnetworks\x18\a 
\x03(\tR\bnetworks\x12\x12\n" + + "\x04ipv6\x18\b \x01(\tR\x04ipv6\"S\n" + "\vSignalState\x12\x10\n" + "\x03URL\x18\x01 \x01(\tR\x03URL\x12\x1c\n" + "\tconnected\x18\x02 \x01(\bR\tconnected\x12\x14\n" + @@ -6534,7 +6579,7 @@ const file_daemon_proto_rawDesc = "" + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + "\t_username\"\x17\n" + - "\x15SwitchProfileResponse\"\xdf\x10\n" + + "\x15SwitchProfileResponse\"\x98\x11\n" + "\x10SetConfigRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + "\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" + @@ -6573,7 +6618,8 @@ const file_daemon_proto_rawDesc = "" + "\x1cenableSSHLocalPortForwarding\x18\x1f \x01(\bH\x14R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" + "\x1denableSSHRemotePortForwarding\x18 \x01(\bH\x15R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" + "\x0edisableSSHAuth\x18! \x01(\bH\x16R\x0edisableSSHAuth\x88\x01\x01\x12+\n" + - "\x0esshJWTCacheTTL\x18\" \x01(\x05H\x17R\x0esshJWTCacheTTL\x88\x01\x01B\x13\n" + + "\x0esshJWTCacheTTL\x18\" \x01(\x05H\x17R\x0esshJWTCacheTTL\x88\x01\x01\x12&\n" + + "\fdisable_ipv6\x18# \x01(\bH\x18R\vdisableIpv6\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -6597,7 +6643,8 @@ const file_daemon_proto_rawDesc = "" + "\x1d_enableSSHLocalPortForwardingB \n" + "\x1e_enableSSHRemotePortForwardingB\x11\n" + "\x0f_disableSSHAuthB\x11\n" + - "\x0f_sshJWTCacheTTL\"\x13\n" + + "\x0f_sshJWTCacheTTLB\x0f\n" + + "\r_disable_ipv6\"\x13\n" + "\x11SetConfigResponse\"Q\n" + "\x11AddProfileRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 3fee9eca8..dedff43e2 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -204,6 +204,7 @@ message LoginRequest { optional bool enableSSHRemotePortForwarding = 37; optional bool disableSSHAuth = 38; optional int32 
sshJWTCacheTTL = 39; + optional bool disable_ipv6 = 40; } message LoginResponse { @@ -311,6 +312,8 @@ message GetConfigResponse { bool disableSSHAuth = 25; int32 sshJWTCacheTTL = 26; + + bool disable_ipv6 = 27; } // PeerState contains the latest state of a peer @@ -333,6 +336,7 @@ message PeerState { google.protobuf.Duration latency = 17; string relayAddress = 18; bytes sshHostKey = 19; + string ipv6 = 20; } // LocalPeerState contains the latest state of the local peer @@ -344,6 +348,7 @@ message LocalPeerState { bool rosenpassEnabled = 5; bool rosenpassPermissive = 6; repeated string networks = 7; + string ipv6 = 8; } // SignalState contains the latest state of a signal connection @@ -672,6 +677,7 @@ message SetConfigRequest { optional bool enableSSHRemotePortForwarding = 32; optional bool disableSSHAuth = 33; optional int32 sshJWTCacheTTL = 34; + optional bool disable_ipv6 = 35; } message SetConfigResponse{} diff --git a/client/server/network.go b/client/server/network.go index 76c5af40e..12cefbd9c 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -18,10 +18,11 @@ import ( ) type selectRoute struct { - NetID route.NetID - Network netip.Prefix - Domains domain.List - Selected bool + NetID route.NetID + Network netip.Prefix + Domains domain.List + Selected bool + extraNetworks []netip.Prefix } // ListNetworks returns a list of all available networks. @@ -50,18 +51,32 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro routesMap := routeMgr.GetClientRoutesWithNetID() routeSelector := routeMgr.GetRouteSelector() + v6ExitMerged := route.V6ExitMergeSet(routesMap) + var routes []*selectRoute for id, rt := range routesMap { if len(rt) == 0 { continue } - route := &selectRoute{ + // Skip v6 exit nodes that are merged into their v4 counterpart. 
+ if _, ok := v6ExitMerged[id]; ok { + continue + } + + r := &selectRoute{ NetID: id, Network: rt[0].Network, Domains: rt[0].Domains, Selected: routeSelector.IsSelected(id), } - routes = append(routes, route) + + // Merge paired v6 exit node prefix into this entry. + v6ID := route.NetID(string(id) + route.V6ExitSuffix) + if _, ok := v6ExitMerged[v6ID]; ok && len(routesMap[v6ID]) > 0 { + r.extraNetworks = []netip.Prefix{routesMap[v6ID][0].Network} + } + + routes = append(routes, r) } sort.Slice(routes, func(i, j int) bool { @@ -82,9 +97,13 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro resolvedDomains := s.statusRecorder.GetResolvedDomainsStates() var pbRoutes []*proto.Network for _, route := range routes { + rangeStr := route.Network.String() + for _, extra := range route.extraNetworks { + rangeStr += ", " + extra.String() + } pbRoute := &proto.Network{ ID: string(route.NetID), - Range: route.Network.String(), + Range: rangeStr, Domains: route.Domains.ToSafeStringList(), ResolvedIPs: map[string]*proto.IPList{}, Selected: route.Selected, @@ -147,7 +166,9 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ routeSelector.SelectAllRoutes() } else { routes := toNetIDs(req.GetNetworkIDs()) - netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + netIdRoutes := maps.Keys(routesMap) if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil { return nil, fmt.Errorf("select routes: %w", err) } @@ -197,7 +218,9 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe routeSelector.DeselectAllRoutes() } else { routes := toNetIDs(req.GetNetworkIDs()) - netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + routesMap := routeManager.GetClientRoutesWithNetID() + routes = route.ExpandV6ExitPairs(routes, routesMap) + 
netIdRoutes := maps.Keys(routesMap) if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) } diff --git a/client/server/server.go b/client/server/server.go index 648ffa8ce..bc8de8f9f 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -385,6 +385,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.DisableNotifications = msg.DisableNotifications config.LazyConnectionEnabled = msg.LazyConnectionEnabled config.BlockInbound = msg.BlockInbound + config.DisableIPv6 = msg.DisableIpv6 config.EnableSSHRoot = msg.EnableSSHRoot config.EnableSSHSFTP = msg.EnableSSHSFTP config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding @@ -1483,6 +1484,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p disableDNS := cfg.DisableDNS disableClientRoutes := cfg.DisableClientRoutes disableServerRoutes := cfg.DisableServerRoutes + disableIPv6 := cfg.DisableIPv6 blockLANAccess := cfg.BlockLANAccess enableSSHRoot := false @@ -1533,6 +1535,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p DisableDns: disableDNS, DisableClientRoutes: disableClientRoutes, DisableServerRoutes: disableServerRoutes, + DisableIpv6: disableIPv6, BlockLanAccess: blockLANAccess, EnableSSHRoot: enableSSHRoot, EnableSSHSFTP: enableSSHSFTP, diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go index b90b5653d..553d4ad71 100644 --- a/client/server/setconfig_test.go +++ b/client/server/setconfig_test.go @@ -71,6 +71,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { disableNotifications := true lazyConnectionEnabled := true blockInbound := true + disableIPv6 := true mtu := int64(1280) sshJWTCacheTTL := int32(300) @@ -95,6 +96,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { DisableNotifications: &disableNotifications, LazyConnectionEnabled: &lazyConnectionEnabled, BlockInbound: 
&blockInbound, + DisableIpv6: &disableIPv6, NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"}, CleanNATExternalIPs: false, CustomDNSAddress: []byte("1.1.1.1:53"), @@ -140,6 +142,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { require.Equal(t, disableNotifications, *cfg.DisableNotifications) require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled) require.Equal(t, blockInbound, cfg.BlockInbound) + require.Equal(t, disableIPv6, cfg.DisableIPv6) require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs) require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress) // IFaceBlackList contains defaults + extras @@ -189,6 +192,7 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { "DisableNotifications": true, "LazyConnectionEnabled": true, "BlockInbound": true, + "DisableIpv6": true, "NatExternalIPs": true, "CustomDNSAddress": true, "ExtraIFaceBlacklist": true, @@ -247,6 +251,7 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) { "disable-firewall": "DisableFirewall", "block-lan-access": "BlockLanAccess", "block-inbound": "BlockInbound", + "disable-ipv6": "DisableIpv6", "enable-lazy-connection": "LazyConnectionEnabled", "external-ip-map": "NatExternalIPs", "dns-resolver-address": "CustomDNSAddress", diff --git a/client/server/trace.go b/client/server/trace.go index e4ac91487..7fea31c49 100644 --- a/client/server/trace.go +++ b/client/server/trace.go @@ -24,14 +24,9 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) ( return nil, err } - srcAddr, err := s.parseAddress(req.GetSourceIp(), engine) + srcAddr, dstAddr, err := s.resolveTraceAddresses(req.GetSourceIp(), req.GetDestinationIp(), engine) if err != nil { - return nil, fmt.Errorf("invalid source IP address: %w", err) - } - - dstAddr, err := s.parseAddress(req.GetDestinationIp(), engine) - if err != nil { - return nil, fmt.Errorf("invalid destination IP address: %w", err) + return nil, err } protocol, err := s.parseProtocol(req.GetProtocol()) @@ 
-89,16 +84,73 @@ func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) { return tracer, engine, nil } -func (s *Server) parseAddress(addr string, engine *internal.Engine) (netip.Addr, error) { - if addr == "self" { - return engine.GetWgAddr(), nil +// resolveTraceAddresses parses src/dst, resolving "self" to the local overlay +// address matching the peer's address family. +func (s *Server) resolveTraceAddresses(src, dst string, engine *internal.Engine) (netip.Addr, netip.Addr, error) { + srcSelf := src == "self" + dstSelf := dst == "self" + + if srcSelf && dstSelf { + return netip.Addr{}, netip.Addr{}, fmt.Errorf("both source and destination cannot be 'self'") } + var srcAddr, dstAddr netip.Addr + var err error + + // Parse the non-self address first so we know the family for self resolution. + if !srcSelf { + if srcAddr, err = parseAddr(src); err != nil { + return netip.Addr{}, netip.Addr{}, fmt.Errorf("invalid source IP: %w", err) + } + } + if !dstSelf { + if dstAddr, err = parseAddr(dst); err != nil { + return netip.Addr{}, netip.Addr{}, fmt.Errorf("invalid destination IP: %w", err) + } + } + + // Determine the peer address to pick the right self address. 
+ peer := srcAddr + if srcSelf { + peer = dstAddr + } + + if srcSelf { + if srcAddr, err = selfAddr(engine, peer); err != nil { + return netip.Addr{}, netip.Addr{}, err + } + } + if dstSelf { + if dstAddr, err = selfAddr(engine, peer); err != nil { + return netip.Addr{}, netip.Addr{}, err + } + } + + return srcAddr, dstAddr, nil +} + +func selfAddr(engine *internal.Engine, peer netip.Addr) (netip.Addr, error) { + var addr netip.Addr + if peer.Is6() { + addr = engine.GetWgV6Addr() + } else { + addr = engine.GetWgAddr() + } + if !addr.IsValid() { + family := "IPv4" + if peer.Is6() { + family = "IPv6" + } + return netip.Addr{}, fmt.Errorf("no local %s overlay address configured", family) + } + return addr, nil +} + +func parseAddr(addr string) (netip.Addr, error) { a, err := netip.ParseAddr(addr) if err != nil { return netip.Addr{}, err } - return a.Unmap(), nil } diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index 5d69fd35c..01822ead6 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -3,6 +3,7 @@ package config import ( "context" "fmt" + "net/netip" "os" "path/filepath" "runtime" @@ -91,7 +92,8 @@ type Manager struct { // PeerSSHInfo represents a peer's SSH configuration information type PeerSSHInfo struct { Hostname string - IP string + IP netip.Addr + IPv6 netip.Addr FQDN string } @@ -210,8 +212,11 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) { func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string { var hostPatterns []string - if peer.IP != "" { - hostPatterns = append(hostPatterns, peer.IP) + if peer.IP.IsValid() { + hostPatterns = append(hostPatterns, peer.IP.String()) + } + if peer.IPv6.IsValid() { + hostPatterns = append(hostPatterns, peer.IPv6.String()) } if peer.FQDN != "" { hostPatterns = append(hostPatterns, peer.FQDN) diff --git a/client/ssh/config/manager_test.go b/client/ssh/config/manager_test.go index e7380c7f2..8e6be40a3 100644 --- 
a/client/ssh/config/manager_test.go +++ b/client/ssh/config/manager_test.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "net/netip" "os" "path/filepath" "runtime" @@ -28,12 +29,12 @@ func TestManager_SetupSSHClientConfig(t *testing.T) { peers := []PeerSSHInfo{ { Hostname: "peer1", - IP: "100.125.1.1", + IP: netip.MustParseAddr("100.125.1.1"), FQDN: "peer1.nb.internal", }, { Hostname: "peer2", - IP: "100.125.1.2", + IP: netip.MustParseAddr("100.125.1.2"), FQDN: "peer2.nb.internal", }, } @@ -101,7 +102,7 @@ func TestManager_PeerLimit(t *testing.T) { for i := 0; i < MaxPeersForSSHConfig+10; i++ { peers = append(peers, PeerSSHInfo{ Hostname: fmt.Sprintf("peer%d", i), - IP: fmt.Sprintf("100.125.1.%d", i%254+1), + IP: netip.MustParseAddr(fmt.Sprintf("100.125.1.%d", i%254+1)), FQDN: fmt.Sprintf("peer%d.nb.internal", i), }) } @@ -127,8 +128,8 @@ func TestManager_MatchHostFormat(t *testing.T) { } peers := []PeerSSHInfo{ - {Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"}, - {Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"}, + {Hostname: "peer1", IP: netip.MustParseAddr("100.125.1.1"), FQDN: "peer1.nb.internal"}, + {Hostname: "peer2", IP: netip.MustParseAddr("100.125.1.2"), FQDN: "peer2.nb.internal"}, } err = manager.SetupSSHClientConfig(peers) @@ -167,7 +168,7 @@ func TestManager_ForcedSSHConfig(t *testing.T) { for i := 0; i < MaxPeersForSSHConfig+10; i++ { peers = append(peers, PeerSSHInfo{ Hostname: fmt.Sprintf("peer%d", i), - IP: fmt.Sprintf("100.125.1.%d", i%254+1), + IP: netip.MustParseAddr(fmt.Sprintf("100.125.1.%d", i%254+1)), FQDN: fmt.Sprintf("peer%d.nb.internal", i), }) } diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index 59007f75c..eb659fe21 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -321,7 +321,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne return } - dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort) + dest := 
net.JoinHostPort(payload.DestAddr, strconv.Itoa(int(payload.DestPort))) log.Debugf("local port forwarding: %s", dest) backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User()) diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go index e16ff5d46..f5ac66fca 100644 --- a/client/ssh/server/port_forwarding.go +++ b/client/ssh/server/port_forwarding.go @@ -56,12 +56,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) { server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool { logger := s.getRequestLogger(ctx) if !allowLocal { - logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort) + logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort)))) return false } if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil { - logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err) + logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(dstHost, strconv.Itoa(int(dstPort))), err) return false } @@ -71,12 +71,12 @@ func (s *Server) configurePortForwarding(server *ssh.Server) { server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool { logger := s.getRequestLogger(ctx) if !allowRemote { - logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort) + logger.Warnf("remote port forwarding denied for %s: disabled", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort)))) return false } if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil { - logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err) + logger.Warnf("remote port forwarding denied for %s: %v", net.JoinHostPort(bindHost, strconv.Itoa(int(bindPort))), err) return false } @@ -183,15 +183,16 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ 
*ssh.Server, req * return false, nil } - key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))) + key := forwardKey(hostPort) if s.removeRemoteForwardListener(key) { - forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port) + forwardAddr := "-R " + hostPort s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr) - logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port) + logger.Infof("remote port forwarding cancelled: %s", hostPort) return true, nil } - logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port) + logger.Warnf("cancel-tcpip-forward failed: no listener found for %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))) return false, nil } @@ -201,7 +202,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h defer func() { if err := ln.Close(); err != nil { - logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err) + logger.Debugf("remote forward listener close error for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err) } }() @@ -230,7 +231,7 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h } go s.handleRemoteForwardConnection(ctx, result.conn, host, port) case <-ctx.Done(): - logger.Debugf("remote forward listener shutting down for %s:%d", host, port) + logger.Debugf("remote forward listener shutting down for %s", net.JoinHostPort(host, strconv.Itoa(int(port)))) return } } @@ -311,17 +312,17 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host) } - key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port)) + key := forwardKey(net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))) s.storeRemoteForwardListener(key, ln) - forwardAddr := 
fmt.Sprintf("-R %s:%d", payload.Host, actualPort) + forwardAddr := "-R " + net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort))) s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr) go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort) response := make([]byte, 4) binary.BigEndian.PutUint32(response, actualPort) - logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort) + logger.Infof("remote port forwarding established: %s", net.JoinHostPort(payload.Host, strconv.Itoa(int(actualPort)))) return true, response } @@ -351,7 +352,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr) if err != nil { - logger.Debugf("open forward channel for %s:%d: %v", host, port, err) + logger.Debugf("open forward channel for %s: %v", net.JoinHostPort(host, strconv.Itoa(int(port))), err) _ = conn.Close() return } diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 82d3b700f..de40d3091 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net" "net/netip" "slices" + "strconv" "strings" @@ -137,10 +138,11 @@ type sessionState struct { } type Server struct { - sshServer *ssh.Server - listener net.Listener - mu sync.RWMutex - hostKeyPEM []byte + sshServer *ssh.Server + listener net.Listener + extraListeners []net.Listener + mu sync.RWMutex + hostKeyPEM []byte // sessions tracks active SSH sessions (shell, command, SFTP). // These are created when a client opens a session channel and requests shell/exec/subsystem. @@ -254,6 +256,35 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { return nil } +// AddListener starts serving SSH on an additional address (e.g. IPv6). +// Must be called after Start.
+func (s *Server) AddListener(ctx context.Context, addr netip.AddrPort) error { + s.mu.Lock() + srv := s.sshServer + if srv == nil { + s.mu.Unlock() + return errors.New("SSH server is not running") + } + + ln, addrDesc, err := s.createListener(ctx, addr) + if err != nil { + s.mu.Unlock() + return fmt.Errorf("create listener: %w", err) + } + + s.extraListeners = append(s.extraListeners, ln) + s.mu.Unlock() + + log.Infof("SSH server also listening on %s", addrDesc) + + go func() { + if err := srv.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + log.Errorf("SSH server error on %s: %v", addrDesc, err) + } + }() + return nil +} + func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) { if s.netstackNet != nil { ln, err := s.netstackNet.ListenTCPAddrPort(addr) @@ -291,6 +322,8 @@ func (s *Server) Stop() error { } s.sshServer = nil s.listener = nil + extraListeners := s.extraListeners + s.extraListeners = nil s.mu.Unlock() // Close outside the lock: session handlers need s.mu for unregisterSession. 
@@ -298,6 +331,12 @@ func (s *Server) Stop() error { log.Debugf("close SSH server: %v", err) } + for _, ln := range extraListeners { + if err := ln.Close(); err != nil { + log.Debugf("close extra SSH listener: %v", err) + } + } + s.mu.Lock() maps.Clear(s.sessions) maps.Clear(s.pendingAuthJWT) @@ -749,11 +788,10 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey { func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { s.mu.RLock() - netbirdNetwork := s.wgAddress.Network - localIP := s.wgAddress.IP + wgAddr := s.wgAddress s.mu.RUnlock() - if !netbirdNetwork.IsValid() || !localIP.IsValid() { + if !wgAddr.Network.IsValid() || !wgAddr.IP.IsValid() { return conn } @@ -769,14 +807,17 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP) return nil } + remoteIP = remoteIP.Unmap() // Block connections from our own IP (prevent local apps from connecting to ourselves) - if remoteIP == localIP { + if remoteIP == wgAddr.IP || wgAddr.IPv6.IsValid() && remoteIP == wgAddr.IPv6 { log.Warnf("SSH connection rejected from own IP %s", remoteIP) return nil } - if !netbirdNetwork.Contains(remoteIP) { + inV4 := wgAddr.Network.Contains(remoteIP) + inV6 := wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(remoteIP) + if !inV4 && !inV6 { log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP) return nil } @@ -876,20 +917,21 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, s.mu.RUnlock() if !allowLocal { - logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port) + logger.Warnf("local port forwarding denied for %s: disabled", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port)))) _ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled") return } if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil { - 
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err) + logger.Warnf("local port forwarding denied for %s: %v", net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))), err) _ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges") return } - forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port) + hostPort := net.JoinHostPort(payload.Host, strconv.Itoa(int(payload.Port))) + forwardAddr := "-L " + hostPort s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr) - logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port) + logger.Infof("local port forwarding: %s", hostPort) ssh.DirectTCPIPHandler(srv, conn, newChan, ctx) } diff --git a/client/status/status.go b/client/status/status.go index 8c932bbab..11ed06c2d 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -60,6 +60,7 @@ type ConvertOptions struct { type PeerStateDetailOutput struct { FQDN string `json:"fqdn" yaml:"fqdn"` IP string `json:"netbirdIp" yaml:"netbirdIp"` + IPv6 string `json:"netbirdIpv6,omitempty" yaml:"netbirdIpv6,omitempty"` PubKey string `json:"publicKey" yaml:"publicKey"` Status string `json:"status" yaml:"status"` LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"` @@ -139,6 +140,7 @@ type OutputOverview struct { SignalState SignalStateOutput `json:"signal" yaml:"signal"` Relays RelayStateOutput `json:"relays" yaml:"relays"` IP string `json:"netbirdIp" yaml:"netbirdIp"` + IPv6 string `json:"netbirdIpv6,omitempty" yaml:"netbirdIpv6,omitempty"` PubKey string `json:"publicKey" yaml:"publicKey"` KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"` FQDN string `json:"fqdn" yaml:"fqdn"` @@ -182,6 +184,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO SignalState: signalOverview, Relays: relayOverview, IP: pbFullStatus.GetLocalPeerState().GetIP(), + IPv6: pbFullStatus.GetLocalPeerState().GetIpv6(), 
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(), KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(), FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(), @@ -317,6 +320,7 @@ func mapPeers( timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local() peerState := PeerStateDetailOutput{ IP: pbPeerState.GetIP(), + IPv6: pbPeerState.GetIpv6(), PubKey: pbPeerState.GetPubKey(), Status: pbPeerState.GetConnStatus(), LastStatusUpdate: timeLocal, @@ -417,6 +421,11 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS interfaceIP = "N/A" } + ipv6Line := "" + if o.IPv6 != "" { + ipv6Line = fmt.Sprintf("NetBird IPv6: %s\n", o.IPv6) + } + var relaysString string if showRelays { for _, relay := range o.Relays.Details { @@ -549,6 +558,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS "Nameservers: %s\n"+ "FQDN: %s\n"+ "NetBird IP: %s\n"+ + "%s"+ "Interface type: %s\n"+ "Quantum resistance: %s\n"+ "Lazy connection: %s\n"+ @@ -566,6 +576,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS dnsServersString, domain.Domain(o.FQDN).SafeString(), interfaceIP, + ipv6Line, interfaceTypeString, rosenpassEnabledStatus, lazyConnectionEnabledStatus, @@ -616,6 +627,7 @@ func ToProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { } pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP + pbFullStatus.LocalPeerState.Ipv6 = fullStatus.LocalPeerState.IPv6 pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN @@ -628,6 +640,7 @@ func ToProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { for _, peerState := range fullStatus.Peers { pbPeerState := &proto.PeerState{ IP: peerState.IP, + Ipv6: peerState.IPv6, PubKey: peerState.PubKey, ConnStatus: peerState.ConnStatus.String(), 
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate), @@ -733,9 +746,15 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo networks = strings.Join(peerState.Networks, ", ") } + ipv6Line := "" + if peerState.IPv6 != "" { + ipv6Line = fmt.Sprintf(" NetBird IPv6: %s\n", peerState.IPv6) + } + peerString := fmt.Sprintf( "\n %s:\n"+ " NetBird IP: %s\n"+ + "%s"+ " Public key: %s\n"+ " Status: %s\n"+ " -- detail --\n"+ @@ -751,6 +770,7 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo " Latency: %s\n", domain.Domain(peerState.FQDN).SafeString(), peerState.IP, + ipv6Line, peerState.PubKey, peerState.Status, peerState.ConnType, @@ -787,6 +807,9 @@ func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFi if len(ipsFilter) > 0 { _, ok := ipsFilter[peerState.IP] + if !ok { + _, ok = ipsFilter[peerState.Ipv6] + } if !ok { ipEval = true } @@ -905,6 +928,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *PeerStateDetailOutput) { peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port) } + peer.IPv6 = a.AnonymizeIPString(peer.IPv6) peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress) for i, route := range peer.Networks { @@ -929,6 +953,7 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) { overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error) overview.IP = a.AnonymizeIPString(overview.IP) + overview.IPv6 = a.AnonymizeIPString(overview.IPv6) for i, detail := range overview.Relays.Details { detail.URI = a.AnonymizeURI(detail.URI) detail.Error = a.AnonymizeString(detail.Error) diff --git a/client/status/status_test.go b/client/status/status_test.go index 7754eebae..0986bf0cd 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -32,6 +32,7 @@ var resp = &proto.StatusResponse{ Peers: []*proto.PeerState{ { IP: "192.168.178.101", + Ipv6: "fd00::1", PubKey: "Pubkey1", 
Fqdn: "peer-1.awesome-domain.com", ConnStatus: "Connected", @@ -90,6 +91,7 @@ var resp = &proto.StatusResponse{ }, LocalPeerState: &proto.LocalPeerState{ IP: "192.168.178.100/16", + Ipv6: "fd00::100", PubKey: "Some-Pub-Key", KernelInterface: true, Fqdn: "some-localhost.awesome-domain.com", @@ -130,6 +132,7 @@ var overview = OutputOverview{ Details: []PeerStateDetailOutput{ { IP: "192.168.178.101", + IPv6: "fd00::1", PubKey: "Pubkey1", FQDN: "peer-1.awesome-domain.com", Status: "Connected", @@ -204,6 +207,7 @@ var overview = OutputOverview{ }, }, IP: "192.168.178.100/16", + IPv6: "fd00::100", PubKey: "Some-Pub-Key", KernelInterface: true, FQDN: "some-localhost.awesome-domain.com", @@ -284,6 +288,7 @@ func TestParsingToJSON(t *testing.T) { { "fqdn": "peer-1.awesome-domain.com", "netbirdIp": "192.168.178.101", + "netbirdIpv6": "fd00::1", "publicKey": "Pubkey1", "status": "Connected", "lastStatusUpdate": "2001-01-01T01:01:01Z", @@ -361,6 +366,7 @@ func TestParsingToJSON(t *testing.T) { ] }, "netbirdIp": "192.168.178.100/16", + "netbirdIpv6": "fd00::100", "publicKey": "Some-Pub-Key", "usesKernelInterface": true, "fqdn": "some-localhost.awesome-domain.com", @@ -418,6 +424,7 @@ func TestParsingToYAML(t *testing.T) { details: - fqdn: peer-1.awesome-domain.com netbirdIp: 192.168.178.101 + netbirdIpv6: fd00::1 publicKey: Pubkey1 status: Connected lastStatusUpdate: 2001-01-01T01:01:01Z @@ -477,6 +484,7 @@ relays: available: false error: 'context: deadline exceeded' netbirdIp: 192.168.178.100/16 +netbirdIpv6: fd00::100 publicKey: Some-Pub-Key usesKernelInterface: true fqdn: some-localhost.awesome-domain.com @@ -523,6 +531,7 @@ func TestParsingToDetail(t *testing.T) { `Peers detail: peer-1.awesome-domain.com: NetBird IP: 192.168.178.101 + NetBird IPv6: fd00::1 Public key: Pubkey1 Status: Connected -- detail -- @@ -568,6 +577,7 @@ Nameservers: [1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout FQDN: some-localhost.awesome-domain.com NetBird 
IP: 192.168.178.100/16 +NetBird IPv6: fd00::100 Interface type: Kernel Quantum resistance: false Lazy connection: false @@ -592,6 +602,7 @@ Relays: 1/2 Available Nameservers: 1/2 Available FQDN: some-localhost.awesome-domain.com NetBird IP: 192.168.178.100/16 +NetBird IPv6: fd00::100 Interface type: Kernel Quantum resistance: false Lazy connection: false diff --git a/client/system/info.go b/client/system/info.go index 175d1f07f..477d5162b 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -69,6 +69,7 @@ type Info struct { DisableFirewall bool BlockLANAccess bool BlockInbound bool + DisableIPv6 bool LazyConnectionEnabled bool @@ -83,7 +84,7 @@ func (i *Info) SetFlags( rosenpassEnabled, rosenpassPermissive bool, serverSSHAllowed *bool, disableClientRoutes, disableServerRoutes, - disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool, + disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool, enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool, disableSSHAuth *bool, ) { @@ -99,6 +100,7 @@ func (i *Info) SetFlags( i.DisableFirewall = disableFirewall i.BlockLANAccess = blockLANAccess i.BlockInbound = blockInbound + i.DisableIPv6 = disableIPv6 i.LazyConnectionEnabled = lazyConnectionEnabled diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 28f98ae59..c2129c7a2 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -279,6 +279,7 @@ type serviceClient struct { sDisableDNS *widget.Check sDisableClientRoutes *widget.Check sDisableServerRoutes *widget.Check + sDisableIPv6 *widget.Check sBlockLANAccess *widget.Check sEnableSSHRoot *widget.Check sEnableSSHSFTP *widget.Check @@ -299,6 +300,7 @@ type serviceClient struct { disableDNS bool disableClientRoutes bool disableServerRoutes bool + disableIPv6 bool blockLANAccess bool enableSSHRoot bool enableSSHSFTP bool @@ -468,6 +470,7 @@ func (s *serviceClient) 
showSettingsUI() { s.sDisableDNS = widget.NewCheck("Keeps system DNS settings unchanged", nil) s.sDisableClientRoutes = widget.NewCheck("This peer won't route traffic to other peers", nil) s.sDisableServerRoutes = widget.NewCheck("This peer won't act as router for others", nil) + s.sDisableIPv6 = widget.NewCheck("Disable IPv6 overlay addressing", nil) s.sBlockLANAccess = widget.NewCheck("Blocks local network access when used as exit node", nil) s.sEnableSSHRoot = widget.NewCheck("Enable SSH Root Login", nil) s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil) @@ -585,6 +588,7 @@ func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool s.disableDNS != s.sDisableDNS.Checked || s.disableClientRoutes != s.sDisableClientRoutes.Checked || s.disableServerRoutes != s.sDisableServerRoutes.Checked || + s.disableIPv6 != s.sDisableIPv6.Checked || s.blockLANAccess != s.sBlockLANAccess.Checked || s.hasSSHChanges() } @@ -637,6 +641,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) ( req.DisableDns = &s.sDisableDNS.Checked req.DisableClientRoutes = &s.sDisableClientRoutes.Checked req.DisableServerRoutes = &s.sDisableServerRoutes.Checked + req.DisableIpv6 = &s.sDisableIPv6.Checked req.BlockLanAccess = &s.sBlockLANAccess.Checked req.EnableSSHRoot = &s.sEnableSSHRoot.Checked @@ -676,24 +681,23 @@ func (s *serviceClient) sendConfigUpdate(req *proto.SetConfigRequest) error { return fmt.Errorf("set config: %w", err) } - // Reconnect if connected to apply the new settings + // Reconnect if connected to apply the new settings. + // Use a background context so the reconnect outlives the settings window. 
go func() { - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) + log.Errorf("failed to get service status: %v", err) return } if status.Status == string(internal.StatusConnected) { - // run down & up - _, err = conn.Down(s.ctx, &proto.DownRequest{}) - if err != nil { - log.Errorf("down service: %v", err) + if _, err = conn.Down(ctx, &proto.DownRequest{}); err != nil { + log.Errorf("failed to stop service: %v", err) } - - _, err = conn.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("up service: %v", err) - return + // TODO: wait for the service to be idle before calling Up, or use a fresh connection + if _, err = conn.Up(ctx, &proto.UpRequest{}); err != nil { + log.Errorf("failed to start service: %v", err) } } }() @@ -730,6 +734,7 @@ func (s *serviceClient) getNetworkForm() *widget.Form { {Text: "Disable DNS", Widget: s.sDisableDNS}, {Text: "Disable Client Routes", Widget: s.sDisableClientRoutes}, {Text: "Disable Server Routes", Widget: s.sDisableServerRoutes}, + {Text: "Disable IPv6", Widget: s.sDisableIPv6}, {Text: "Disable LAN Access", Widget: s.sBlockLANAccess}, }, } @@ -1327,6 +1332,7 @@ func (s *serviceClient) getSrvConfig() { s.disableDNS = cfg.DisableDNS s.disableClientRoutes = cfg.DisableClientRoutes s.disableServerRoutes = cfg.DisableServerRoutes + s.disableIPv6 = cfg.DisableIPv6 s.blockLANAccess = cfg.BlockLANAccess if cfg.EnableSSHRoot != nil { @@ -1367,6 +1373,7 @@ func (s *serviceClient) getSrvConfig() { s.sDisableDNS.SetChecked(cfg.DisableDNS) s.sDisableClientRoutes.SetChecked(cfg.DisableClientRoutes) s.sDisableServerRoutes.SetChecked(cfg.DisableServerRoutes) + s.sDisableIPv6.SetChecked(cfg.DisableIPv6) s.sBlockLANAccess.SetChecked(cfg.BlockLANAccess) if cfg.EnableSSHRoot != nil { 
s.sEnableSSHRoot.SetChecked(*cfg.EnableSSHRoot) @@ -1454,6 +1461,7 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config { config.DisableDNS = cfg.DisableDns config.DisableClientRoutes = cfg.DisableClientRoutes config.DisableServerRoutes = cfg.DisableServerRoutes + config.DisableIPv6 = cfg.DisableIpv6 config.BlockLANAccess = cfg.BlockLanAccess config.EnableSSHRoot = &cfg.EnableSSHRoot diff --git a/client/ui/event/event.go b/client/ui/event/event.go index ea968f60a..3b43fdc7f 100644 --- a/client/ui/event/event.go +++ b/client/ui/event/event.go @@ -112,7 +112,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) { handlers := slices.Clone(e.handlers) e.mu.Unlock() - if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) { + if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) && !isV6DefaultRoutePartner(event) { title := e.getEventTitle(event) body := event.UserMessage id := event.Metadata["id"] @@ -133,6 +133,14 @@ func (e *Manager) AddHandler(handler Handler) { e.handlers = append(e.handlers, handler) } +// isV6DefaultRoutePartner reports whether the event is the IPv6 half of a +// paired v4/v6 default-route event. Management always pairs ::/0 with 0.0.0.0/0 +// for exit nodes, so the v4 partner already drives the user-facing toast and +// the v6 one is suppressed to avoid a duplicate notification. 
+func isV6DefaultRoutePartner(event *proto.SystemEvent) bool { + return event.Category == proto.SystemEvent_NETWORK && event.Metadata["network"] == "::/0" +} + func (e *Manager) getEventTitle(event *proto.SystemEvent) string { var prefix string switch event.Severity { diff --git a/client/ui/network.go b/client/ui/network.go index 571e871bb..1619f78a2 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -192,10 +192,14 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network { return filteredRoutes } +func isDefaultRoute(routeRange string) bool { + return routeRange == "0.0.0.0/0" || routeRange == "::/0" +} + func getExitNodeNetworks(routes []*proto.Network) []*proto.Network { var filteredRoutes []*proto.Network for _, route := range routes { - if route.Range == "0.0.0.0/0" { + if isDefaultRoute(route.Range) { filteredRoutes = append(filteredRoutes, route) } } @@ -499,7 +503,7 @@ func (s *serviceClient) getExitNodes(conn proto.DaemonServiceClient) ([]*proto.N var exitNodes []*proto.Network for _, network := range resp.Routes { - if network.Range == "0.0.0.0/0" { + if isDefaultRoute(network.Range) { exitNodes = append(exitNodes, network) } } diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index cb512f132..066fe043b 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -5,6 +5,8 @@ package main import ( "context" "fmt" + "net" + "strconv" "sync" "syscall/js" "time" @@ -83,6 +85,10 @@ func parseClientOptions(jsOptions js.Value) (netbird.Options, error) { options.DeviceName = deviceName.String() } + if disableIPv6 := jsOptions.Get("disableIPv6"); !disableIPv6.IsNull() && !disableIPv6.IsUndefined() { + options.DisableIPv6 = disableIPv6.Bool() + } + return options, nil } @@ -163,39 +169,58 @@ func createSSHMethod(client *netbird.Client) js.Func { }) } - var jwtToken string - if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { - jwtToken = args[3].String() - } + jwtToken, ipVersion := 
parseSSHOptions(args) return createPromise(func(resolve, reject js.Value) { - sshClient := ssh.NewClient(client) - - if err := sshClient.Connect(host, port, username, jwtToken); err != nil { + jsInterface, err := connectSSH(client, host, port, username, jwtToken, ipVersion) + if err != nil { reject.Invoke(err.Error()) return } - - if err := sshClient.StartSession(80, 24); err != nil { - if closeErr := sshClient.Close(); closeErr != nil { - log.Errorf("Error closing SSH client: %v", closeErr) - } - reject.Invoke(err.Error()) - return - } - - jsInterface := ssh.CreateJSInterface(sshClient) resolve.Invoke(jsInterface) }) }) } -func performPing(client *netbird.Client, hostname string) { +func parseSSHOptions(args []js.Value) (jwtToken string, ipVersion int) { + if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() { + jwtToken = args[3].String() + } + if len(args) > 4 { + ipVersion = jsIPVersion(args[4]) + } + return +} + +func connectSSH(client *netbird.Client, host string, port int, username, jwtToken string, ipVersion int) (js.Value, error) { + sshClient := ssh.NewClient(client) + + if err := sshClient.Connect(host, port, username, jwtToken, ipVersion); err != nil { + return js.Undefined(), err + } + + if err := sshClient.StartSession(80, 24); err != nil { + if closeErr := sshClient.Close(); closeErr != nil { + log.Errorf("Error closing SSH client: %v", closeErr) + } + return js.Undefined(), err + } + + return ssh.CreateJSInterface(sshClient), nil +} + +func performPing(client *netbird.Client, hostname string, ipVersion int) { ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() + // Default to ping4 to avoid dual-stack ICMP endpoint issues in wireguard-go netstack. 
+ network := "ping4" + if ipVersion == 6 { + network = "ping6" + } + start := time.Now() - conn, err := client.Dial(ctx, "ping", hostname) + conn, err := client.Dial(ctx, network, hostname) if err != nil { js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err)) return @@ -222,27 +247,39 @@ func performPing(client *netbird.Client, hostname string) { } latency := time.Since(start) - js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds())) + remote := conn.RemoteAddr().String() + msg := fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()) + if remote != hostname { + msg += fmt.Sprintf(" (via %s)", remote) + } + js.Global().Get("console").Call("log", msg) } -func performPingTCP(client *netbird.Client, hostname string, port int) { +func performPingTCP(client *netbird.Client, hostname string, port, ipVersion int) { ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() - address := fmt.Sprintf("%s:%d", hostname, port) + network := ipVersionNetwork("tcp", ipVersion) + + address := net.JoinHostPort(hostname, fmt.Sprintf("%d", port)) start := time.Now() - conn, err := client.Dial(ctx, "tcp", address) + conn, err := client.Dial(ctx, network, address) if err != nil { js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err)) return } latency := time.Since(start) + remote := conn.RemoteAddr().String() if err := conn.Close(); err != nil { log.Debugf("failed to close TCP connection: %v", err) } - js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds())) + msg := fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()) + if remote != address { + msg += fmt.Sprintf(" (via %s)", remote) + } + js.Global().Get("console").Call("log", msg) } // createPingMethod creates the ping method @@ -259,8 +296,12 @@ func createPingMethod(client 
*netbird.Client) js.Func { } hostname := args[0].String() + var ipVersion int + if len(args) > 1 { + ipVersion = jsIPVersion(args[1]) + } return createPromise(func(resolve, reject js.Value) { - performPing(client, hostname) + performPing(client, hostname, ipVersion) resolve.Invoke(js.Undefined()) }) }) @@ -287,8 +328,12 @@ func createPingTCPMethod(client *netbird.Client) js.Func { hostname := args[0].String() port := args[1].Int() + var ipVersion int + if len(args) > 2 { + ipVersion = jsIPVersion(args[2]) + } return createPromise(func(resolve, reject js.Value) { - performPingTCP(client, hostname, port) + performPingTCP(client, hostname, port, ipVersion) resolve.Invoke(js.Undefined()) }) }) @@ -461,6 +506,31 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func { }) } +// ipVersionNetwork appends "4" or "6" to a base network string (e.g. "tcp" -> "tcp4"). +func ipVersionNetwork(base string, ipVersion int) string { + switch ipVersion { + case 4: + return base + "4" + case 6: + return base + "6" + default: + return base + } +} + +// jsIPVersion extracts an IP version (4 or 6) from a JS string or number. +func jsIPVersion(v js.Value) int { + switch v.Type() { + case js.TypeNumber: + return v.Int() + case js.TypeString: + n, _ := strconv.Atoi(v.String()) + return n + default: + return 0 + } +} + // createStartCaptureMethod creates the programmable packet capture method. // Returns a JS interface with onpacket callback and stop() method. 
// diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go index 16bf63bb9..6c36fdec6 100644 --- a/client/wasm/internal/rdp/rdcleanpath.go +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -82,7 +82,7 @@ func NewRDCleanPathProxy(client interface { // CreateProxy creates a new proxy endpoint for the given destination func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { - destination := fmt.Sprintf("%s:%s", hostname, port) + destination := net.JoinHostPort(hostname, port) return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any { resolve := args[0] diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go index 568437e56..9cfe65266 100644 --- a/client/wasm/internal/ssh/client.go +++ b/client/wasm/internal/ssh/client.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "io" + "net" "sync" "time" @@ -45,9 +46,10 @@ func NewClient(nbClient *netbird.Client) *Client { } } -// Connect establishes an SSH connection through NetBird network -func (c *Client) Connect(host string, port int, username, jwtToken string) error { - addr := fmt.Sprintf("%s:%d", host, port) +// Connect establishes an SSH connection through NetBird network. +// ipVersion may be 4, 6, or 0 for automatic selection. 
+func (c *Client) Connect(host string, port int, username, jwtToken string, ipVersion int) error { + addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) logrus.Infof("SSH: Connecting to %s as %s", addr, username) authMethods, err := c.getAuthMethods(jwtToken) @@ -62,10 +64,18 @@ func (c *Client) Connect(host string, port int, username, jwtToken string) error Timeout: sshDialTimeout, } + network := "tcp" + switch ipVersion { + case 4: + network = "tcp4" + case 6: + network = "tcp6" + } + ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout) defer cancel() - conn, err := c.nbClient.Dial(ctx, "tcp", addr) + conn, err := c.nbClient.Dial(ctx, network, addr) if err != nil { return fmt.Errorf("dial %s: %w", addr, err) } diff --git a/combined/cmd/config.go b/combined/cmd/config.go index ce4df8394..9959f7a56 100644 --- a/combined/cmd/config.go +++ b/combined/cmd/config.go @@ -380,7 +380,7 @@ func (c *CombinedConfig) autoConfigureClientSettings(exposedProto, exposedHost, // Auto-configure local STUN servers for all ports for _, port := range c.Server.StunPorts { c.Management.Stuns = append(c.Management.Stuns, HostConfig{ - URI: fmt.Sprintf("stun:%s:%d", exposedHost, port), + URI: "stun:" + net.JoinHostPort(strings.Trim(exposedHost, "[]"), fmt.Sprintf("%d", port)), }) } } diff --git a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go index fc91b8616..3485d51fe 100644 --- a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go +++ b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go @@ -2,7 +2,7 @@ package manager import ( "context" - "net" + "net/netip" "testing" "time" @@ -56,7 +56,8 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor Key: "test-key", DNSLabel: "test-peer", Name: "test-peer", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: 
netip.MustParseAddr("fd00::1"), Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, }, diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 0fb5f46ff..d03a8dc82 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math/rand/v2" + "net" "net/http" "os" "slices" @@ -1103,7 +1104,7 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s serviceURL := "https://" + svc.Domain if service.IsL4Protocol(svc.Mode) { - serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort) + serviceURL = fmt.Sprintf("%s://%s", svc.Mode, net.JoinHostPort(svc.Domain, strconv.Itoa(int(svc.ListenPort)))) } return &service.ExposeServiceResponse{ @@ -1272,7 +1273,7 @@ func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]a return meta } meta["peer_name"] = peer.Name - if peer.IP != nil { + if peer.IP.IsValid() { meta["peer_ip"] = peer.IP.String() } return meta diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index e9403849c..46e79f1e5 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -3,7 +3,7 @@ package manager import ( "context" "errors" - "net" + "net/netip" "testing" "time" @@ -405,7 +405,8 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { testPeer := &nbpeer.Peer{ ID: ownerPeerID, Name: "test-peer", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), } newEphemeralService := func() *rpservice.Service 
{ @@ -682,7 +683,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { Key: "test-key", DNSLabel: "test-peer", Name: "test-peer", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, }, @@ -751,7 +753,8 @@ func Test_validateExposePermission(t *testing.T) { Key: "other-key", DNSLabel: "other-peer", Name: "other-peer", - IP: net.ParseIP("100.64.0.2"), + IP: netip.MustParseAddr("100.64.0.2"), + IPv6: netip.MustParseAddr("fd00::2"), Status: &nbpeer.PeerStatus{LastSeen: time.Now()}, Meta: nbpeer.PeerSystemMeta{Hostname: "other-peer"}, }) diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index ef417d3cf..12402b420 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -3,12 +3,15 @@ package grpc import ( "context" "fmt" + "net/netip" "net/url" "strings" log "github.com/sirupsen/logrus" + goproto "google.golang.org/protobuf/proto" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/client/ssh/auth" nbdns "github.com/netbirdio/netbird/dns" @@ -17,8 +20,9 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/route" + nbroute "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/netiputil" "github.com/netbirdio/netbird/shared/sshauth" ) @@ -100,7 +104,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig) } - return &proto.PeerConfig{ + peerConfig := 
&proto.PeerConfig{ Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), SshConfig: sshConfig, Fqdn: fqdn, @@ -111,9 +115,25 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set AlwaysUpdate: settings.AutoUpdateAlways, }, } + + if peer.SupportsIPv6() && peer.IPv6.IsValid() && network.NetV6.IP != nil { + ones, _ := network.NetV6.Mask.Size() + v6Prefix := netip.PrefixFrom(peer.IPv6.Unmap(), ones) + if b, err := netiputil.EncodePrefix(v6Prefix); err == nil { + peerConfig.AddressV6 = b + } + } + + return peerConfig } func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { + // IPv6 data in AllowedIPs and SourcePrefixes wildcard expansion depends on + // whether the target peer supports IPv6. Routes and firewall rules are already + // filtered at the source (network map builder). 
+ includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid() + useSourcePrefixes := peer.SupportsSourcePrefixes() + response := &proto.SyncResponse{ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH), NetworkMap: &proto.NetworkMap{ @@ -132,15 +152,15 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb response.NetworkMap.PeerConfig = response.PeerConfig remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) - remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) + remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6) response.RemotePeers = remotePeers response.NetworkMap.RemotePeers = remotePeers response.RemotePeersIsEmpty = len(remotePeers) == 0 response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty - response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) + response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6) - firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) + firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes) response.NetworkMap.FirewallRules = firewallRules response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 @@ -195,11 +215,15 @@ func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]m return hashedUsers, machineUsers } -func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { +func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig { for _, rPeer := range peers { + allowedIPs := []string{rPeer.IP.String() + "/32"} + if includeIPv6 && rPeer.IPv6.IsValid() { + allowedIPs = append(allowedIPs, 
rPeer.IPv6.String()+"/128") + } dst = append(dst, &proto.RemotePeerConfig{ WgPubKey: rPeer.Key, - AllowedIps: []string{rPeer.IP.String() + "/32"}, + AllowedIps: allowedIPs, SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, Fqdn: rPeer.FQDN(dnsName), AgentVersion: rPeer.Meta.WtVersion, @@ -253,7 +277,7 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol { } } -func toProtocolRoutes(routes []*route.Route) []*proto.Route { +func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route { protoRoutes := make([]*proto.Route, 0, len(routes)) for _, r := range routes { protoRoutes = append(protoRoutes, toProtocolRoute(r)) @@ -261,7 +285,7 @@ func toProtocolRoutes(routes []*route.Route) []*proto.Route { return protoRoutes } -func toProtocolRoute(route *route.Route) *proto.Route { +func toProtocolRoute(route *nbroute.Route) *proto.Route { return &proto.Route{ ID: string(route.ID), NetID: string(route.NetID), @@ -277,29 +301,70 @@ func toProtocolRoute(route *route.Route) *proto.Route { } // toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. -func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) +// When useSourcePrefixes is true, the compact SourcePrefixes field is populated +// alongside the deprecated PeerIP for forward compatibility. +// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes +// when includeIPv6 is true. 
+func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, 0, len(rules)) for i := range rules { rule := rules[i] fwRule := &proto.FirewallRule{ PolicyID: []byte(rule.PolicyID), - PeerIP: rule.PeerIP, + PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility Direction: getProtoDirection(rule.Direction), Action: getProtoAction(rule.Action), Protocol: getProtoProtocol(rule.Protocol), Port: rule.Port, } + if useSourcePrefixes && rule.PeerIP != "" { + result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...) + } + if shouldUsePortRange(fwRule) { fwRule.PortInfo = rule.PortRange.ToProto() } - result[i] = fwRule + result = append(result, fwRule) } return result } + +// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any +// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified). +func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule { + addr, err := netip.ParseAddr(rule.PeerIP) + if err != nil { + return nil + } + + if !addr.IsUnspecified() { + fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())} + return nil + } + + // IPv4Unspecified/0 is always valid, error is impossible. + v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0)) + fwRule.SourcePrefixes = [][]byte{v4Wildcard} + + if !includeIPv6 { + return nil + } + + v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule) + v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility + // IPv6Unspecified/0 is always valid, error is impossible. 
+ v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0)) + v6Rule.SourcePrefixes = [][]byte{v6Wildcard} + if shouldUsePortRange(v6Rule) { + v6Rule.PortInfo = rule.PortRange.ToProto() + } + return []*proto.FirewallRule{v6Rule} +} + // getProtoDirection converts the direction to proto.RuleDirection. func getProtoDirection(direction int) proto.RuleDirection { if direction == types.FirewallRuleDirectionOUT { diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 0c1611e7f..70024bac6 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -680,11 +680,21 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee BlockLANAccess: meta.GetFlags().GetBlockLANAccess(), BlockInbound: meta.GetFlags().GetBlockInbound(), LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(), + DisableIPv6: meta.GetFlags().GetDisableIPv6(), }, - Files: files, + Files: files, + Capabilities: capabilitiesToInt32(meta.GetCapabilities()), } } +func capabilitiesToInt32(caps []proto.PeerCapability) []int32 { + result := make([]int32, len(caps)) + for i, c := range caps { + result[i] = int32(c) + } + return result +} + func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { diff --git a/management/server/account.go b/management/server/account.go index 4b71ab486..45b99839f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -329,6 +329,13 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco updateAccountPeers = true } + if ipv6SettingsChanged(oldSettings, newSettings) { + if err = am.updatePeerIPv6Addresses(ctx, transaction, accountID, newSettings); err != nil { + return err + } + updateAccountPeers = true + } + if 
oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled || oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled || oldSettings.DNSDomain != newSettings.DNSDomain || @@ -338,7 +345,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled && newSettings.GroupsPropagationEnabled { - groupsUpdated, groupChangesAffectPeers, err = propagateUserGroupMemberships(ctx, transaction, accountID) + groupsUpdated, groupChangesAffectPeers, err = am.propagateUserGroupMemberships(ctx, transaction, accountID) if err != nil { return err } @@ -393,6 +400,22 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta) } + oldIPv6On := len(oldSettings.IPv6EnabledGroups) > 0 + newIPv6On := len(newSettings.IPv6EnabledGroups) > 0 + if oldIPv6On != newIPv6On { + if newIPv6On { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountIPv6Enabled, nil) + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountIPv6Disabled, nil) + } + } + if oldSettings.NetworkRangeV6 != newSettings.NetworkRangeV6 { + eventMeta := map[string]any{ + "old_network_range_v6": oldSettings.NetworkRangeV6.String(), + "new_network_range_v6": newSettings.NetworkRangeV6.String(), + } + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta) + } if reloadReverseProxy { if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil { log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err) @@ -406,6 +429,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return newSettings, nil } +func ipv6SettingsChanged(old, updated *types.Settings) bool { + if old.NetworkRangeV6 != 
updated.NetworkRangeV6 { + return true + } + oldGroups := slices.Clone(old.IPv6EnabledGroups) + newGroups := slices.Clone(updated.IPv6EnabledGroups) + slices.Sort(oldGroups) + slices.Sort(newGroups) + return !slices.Equal(oldGroups, newGroups) +} + func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { @@ -432,9 +466,38 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra } } + if err := validateIPv6EnabledGroups(ctx, transaction, accountID, newSettings.IPv6EnabledGroups); err != nil { + return err + } + return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID) } +// validateIPv6EnabledGroups checks that all referenced IPv6-enabled group IDs exist in the account. +func validateIPv6EnabledGroups(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) error { + if len(groupIDs) == 0 { + return nil + } + + groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("get groups for IPv6 validation: %w", err) + } + + existing := make(map[string]struct{}, len(groups)) + for _, g := range groups { + existing[g.ID] = struct{}{} + } + + for _, gid := range groupIDs { + if _, ok := existing[gid]; !ok { + return status.Errorf(status.InvalidArgument, "IPv6 enabled group %s does not exist", gid) + } + } + + return nil +} + func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled { if newSettings.RoutingPeerDNSResolutionEnabled { @@ -739,37 +802,8 @@ func (am *DefaultAccountManager) 
DeleteAccount(ctx context.Context, accountID, u return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err) } - for _, otherUser := range account.Users { - if otherUser.Id == userID { - continue - } - - if otherUser.IsServiceUser { - err = am.deleteServiceUser(ctx, accountID, userID, otherUser) - if err != nil { - return err - } - continue - } - - userInfo, ok := userInfosMap[otherUser.Id] - if !ok { - return status.Errorf(status.NotFound, "user info not found for user %s", otherUser.Id) - } - - _, deleteUserErr := am.deleteRegularUser(ctx, accountID, userID, userInfo) - if deleteUserErr != nil { - return deleteUserErr - } - } - - userInfo, ok := userInfosMap[userID] - if ok { - _, err = am.deleteRegularUser(ctx, accountID, userID, userInfo) - if err != nil { - log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err) - return err - } + if err = am.deleteAccountUsers(ctx, accountID, userID, account.Users, userInfosMap); err != nil { + return err } err = am.Store.DeleteAccount(ctx, account) @@ -787,6 +821,40 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } +func (am *DefaultAccountManager) deleteAccountUsers(ctx context.Context, accountID, initiatorUserID string, users map[string]*types.User, userInfosMap map[string]*types.UserInfo) error { + for _, otherUser := range users { + if otherUser.Id == initiatorUserID { + continue + } + + if otherUser.IsServiceUser { + if err := am.deleteServiceUser(ctx, accountID, initiatorUserID, otherUser); err != nil { + return err + } + continue + } + + userInfo, ok := userInfosMap[otherUser.Id] + if !ok { + return status.Errorf(status.NotFound, "user info not found for user %s", otherUser.Id) + } + + if _, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo); err != nil { + return err + } + } + + userInfo, ok := userInfosMap[initiatorUserID] + if ok { + if _, err := am.deleteRegularUser(ctx, accountID, 
initiatorUserID, userInfo); err != nil { + log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", initiatorUserID, err) + return err + } + } + + return nil +} + // AccountExists checks if an account exists. func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { return am.Store.AccountExists(ctx, store.LockingStrengthNone, accountID) @@ -1528,6 +1596,11 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } } + allGroupChanges := slices.Concat(addNewGroups, removeOldGroups) + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, userAuth.AccountId, allGroupChanges); err != nil { + return fmt.Errorf("reconcile IPv6 for group changes: %w", err) + } + if err = transaction.IncrementNetworkSerial(ctx, userAuth.AccountId); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } @@ -1913,6 +1986,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain, email, nam if err := acc.AddAllGroup(disableDefaultPolicy); err != nil { log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) } + + if allGroup, err := acc.GetGroupAll(); err == nil { + acc.Settings.IPv6EnabledGroups = []string{allGroup.ID} + } + return acc } @@ -2019,6 +2097,10 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C return nil, false, status.Errorf(status.Internal, "failed to add all group to new account by private domain") } + if allGroup, err := newAccount.GetGroupAll(); err == nil { + newAccount.Settings.IPv6EnabledGroups = []string{allGroup.ID} + } + if err := am.Store.SaveAccount(ctx, newAccount); err != nil { log.WithContext(ctx).WithFields(log.Fields{ "accountId": newAccount.Id, @@ -2080,7 +2162,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc // propagateUserGroupMemberships propagates all account users' group memberships to their peers. 
// Returns true if any groups were modified, true if those updates affect peers and an error. -func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { +func (am *DefaultAccountManager) propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) if err != nil { return false, false, err @@ -2102,29 +2184,13 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, } } - updatedGroups := []string{} - for _, user := range users { - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id) - if err != nil { - return false, false, err - } + updatedGroups, err := propagateAutoGroupsForUsers(ctx, transaction, accountID, users, accountGroupPeers) + if err != nil { + return false, false, err + } - for _, peer := range userPeers { - for _, groupID := range user.AutoGroups { - if _, exists := accountGroupPeers[groupID]; !exists { - // we do not wanna create the groups here - log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID) - continue - } - if _, exists := accountGroupPeers[groupID][peer.ID]; exists { - continue - } - if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil { - return false, false, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err) - } - updatedGroups = append(updatedGroups, groupID) - } - } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, updatedGroups); err != nil { + return false, false, fmt.Errorf("reconcile IPv6 for group changes: %w", err) } peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, updatedGroups) @@ -2135,6 +2201,35 @@ func propagateUserGroupMemberships(ctx context.Context, 
transaction store.Store, return len(updatedGroups) > 0, peersAffected, nil } +// propagateAutoGroupsForUsers adds each user's peers to their AutoGroups where not already present. +// Returns the list of group IDs that were modified. +func propagateAutoGroupsForUsers(ctx context.Context, transaction store.Store, accountID string, users []*types.User, accountGroupPeers map[string]map[string]struct{}) ([]string, error) { + var updatedGroups []string + for _, user := range users { + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id) + if err != nil { + return nil, err + } + + for _, peer := range userPeers { + for _, groupID := range user.AutoGroups { + if _, exists := accountGroupPeers[groupID]; !exists { + log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID) + continue + } + if _, exists := accountGroupPeers[groupID][peer.ID]; exists { + continue + } + if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil { + return nil, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err) + } + updatedGroups = append(updatedGroups, groupID) + } + } + } + return updatedGroups, nil +} + // reallocateAccountPeerIPs re-allocates all peer IPs when the network range changes func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, transaction store.Store, accountID string, newNetworkRange netip.Prefix) error { if !newNetworkRange.IsValid() { @@ -2156,10 +2251,10 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t return err } - var takenIPs []net.IP + var takenIPs []netip.Addr for _, peer := range peers { - newIP, err := types.AllocatePeerIP(newIPNet, takenIPs) + newIP, err := types.AllocatePeerIP(newNetworkRange, takenIPs) if err != nil { return status.Errorf(status.Internal, "allocate IP for peer %s: %v", peer.ID, err) } @@ -2183,13 +2278,199 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx 
context.Context, t
 	return nil
 }
 
+// checkIPv6Collision verifies that newIPv6 is not already assigned to another
+// peer in the account. It scans all account peers and returns an
+// InvalidArgument error naming the conflicting peer on a collision, or nil
+// when the address is free for the peer identified by peerID.
+func (am *DefaultAccountManager) checkIPv6Collision(ctx context.Context, transaction store.Store, accountID, peerID string, newIPv6 netip.Addr) error {
+	peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
+	if err != nil {
+		return fmt.Errorf("get peers: %w", err)
+	}
+	for _, p := range peers {
+		if p.ID != peerID && p.IPv6.IsValid() && p.IPv6 == newIPv6 {
+			return status.Errorf(status.InvalidArgument, "IPv6 %s is already assigned to peer %s", newIPv6, p.Name)
+		}
+	}
+	return nil
+}
+
+func (am *DefaultAccountManager) updatePeerIPv6Addresses(ctx context.Context, transaction store.Store, accountID string, settings *types.Settings) error {
+	peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthUpdate, accountID, "", "")
+	if err != nil {
+		return fmt.Errorf("get peers: %w", err)
+	}
+
+	network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID)
+	if err != nil {
+		return fmt.Errorf("get network: %w", err)
+	}
+
+	if err := am.ensureIPv6Subnet(ctx, transaction, accountID, settings, network); err != nil {
+		return err
+	}
+
+	allowedPeers, err := am.buildIPv6AllowedPeers(ctx, transaction, accountID, settings)
+	if err != nil {
+		return err
+	}
+
+	v6Prefix, err := netip.ParsePrefix(network.NetV6.String())
+	if err != nil {
+		return fmt.Errorf("parse IPv6 prefix: %w", err)
+	}
+
+	if err := am.assignPeerIPv6Addresses(ctx, transaction, accountID, peers, network, allowedPeers, v6Prefix); err != nil {
+		return err
+	}
+
+	log.WithContext(ctx).Infof("updated IPv6 addresses for %d peers in account %s (groups=%d)",
+		len(peers), accountID,
len(settings.IPv6EnabledGroups)) + + return nil +} + +// reconcileIPv6ForGroupChanges checks whether the given group IDs overlap with +// the account's IPv6EnabledGroups. If they do, it runs a full IPv6 address +// reconciliation so that peers gaining or losing membership in an IPv6-enabled +// group get their addresses assigned or removed. +func (am *DefaultAccountManager) reconcileIPv6ForGroupChanges(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) error { + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("get account settings: %w", err) + } + + if len(settings.IPv6EnabledGroups) == 0 { + return nil + } + + enabledSet := make(map[string]struct{}, len(settings.IPv6EnabledGroups)) + for _, gid := range settings.IPv6EnabledGroups { + enabledSet[gid] = struct{}{} + } + + affected := false + for _, gid := range groupIDs { + if _, ok := enabledSet[gid]; ok { + affected = true + break + } + } + + if !affected { + return nil + } + + return am.updatePeerIPv6Addresses(ctx, transaction, accountID, settings) +} + +func (am *DefaultAccountManager) ensureIPv6Subnet(ctx context.Context, transaction store.Store, accountID string, settings *types.Settings, network *types.Network) error { + if settings.NetworkRangeV6.IsValid() { + network.NetV6 = net.IPNet{ + IP: settings.NetworkRangeV6.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(settings.NetworkRangeV6.Bits(), 128), + } + return transaction.UpdateAccountNetworkV6(ctx, accountID, network.NetV6) + } + if network.NetV6.IP == nil { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + network.NetV6 = types.AllocateIPv6Subnet(r) + + // Sync settings to match the allocated subnet so SaveAccountSettings persists it. 
+ ones, _ := network.NetV6.Mask.Size() + addr, _ := netip.AddrFromSlice(network.NetV6.IP) + settings.NetworkRangeV6 = netip.PrefixFrom(addr.Unmap(), ones) + + return transaction.UpdateAccountNetworkV6(ctx, accountID, network.NetV6) + } + return nil +} + +func (am *DefaultAccountManager) assignPeerIPv6Addresses( + ctx context.Context, transaction store.Store, accountID string, + peers []*nbpeer.Peer, network *types.Network, + allowedPeers map[string]struct{}, v6Prefix netip.Prefix, +) error { + takenV6 := make(map[netip.Addr]struct{}) + for _, peer := range peers { + if _, ok := allowedPeers[peer.ID]; ok && peer.IPv6.IsValid() && network.NetV6.Contains(peer.IPv6.AsSlice()) { + takenV6[peer.IPv6] = struct{}{} + } + } + + for _, peer := range peers { + _, allowed := allowedPeers[peer.ID] + oldIPv6 := peer.IPv6 + + if !allowed { + peer.IPv6 = netip.Addr{} + } else if !peer.IPv6.IsValid() || !network.NetV6.Contains(peer.IPv6.AsSlice()) { + newIP, err := allocateIPv6WithRetry(v6Prefix, takenV6, peer.ID) + if err != nil { + return err + } + peer.IPv6 = newIP + } + + if peer.IPv6 == oldIPv6 { + continue + } + + if err := transaction.SavePeer(ctx, accountID, peer); err != nil { + return fmt.Errorf("save peer %s: %w", peer.ID, err) + } + } + return nil +} + +func allocateIPv6WithRetry(prefix netip.Prefix, taken map[netip.Addr]struct{}, peerID string) (netip.Addr, error) { + for attempts := 0; attempts < 10; attempts++ { + newIP, err := types.AllocateRandomPeerIPv6(prefix) + if err != nil { + return netip.Addr{}, fmt.Errorf("allocate v6 for peer %s: %w", peerID, err) + } + if _, ok := taken[newIP]; !ok { + taken[newIP] = struct{}{} + return newIP, nil + } + } + return netip.Addr{}, fmt.Errorf("allocate v6 for peer %s: exhausted 10 attempts", peerID) +} + +func (am *DefaultAccountManager) buildIPv6AllowedPeers(ctx context.Context, transaction store.Store, accountID string, settings *types.Settings) (map[string]struct{}, error) { + if len(settings.IPv6EnabledGroups) == 0 { + 
return make(map[string]struct{}), nil + } + + groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return nil, fmt.Errorf("get groups: %w", err) + } + + enabledSet := make(map[string]struct{}, len(settings.IPv6EnabledGroups)) + for _, gid := range settings.IPv6EnabledGroups { + enabledSet[gid] = struct{}{} + } + + allowedPeers := make(map[string]struct{}) + for _, group := range groups { + if _, ok := enabledSet[group.ID]; !ok { + continue + } + for _, peerID := range group.Peers { + allowedPeers[peerID] = struct{}{} + } + } + return allowedPeers, nil +} + func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, peers []*nbpeer.Peer, peerID string, newIP netip.Addr) error { if !account.Network.Net.Contains(newIP.AsSlice()) { return status.Errorf(status.InvalidArgument, "IP %s is not within the account network range %s", newIP.String(), account.Network.Net.String()) } for _, peer := range peers { - if peer.ID != peerID && peer.IP.Equal(newIP.AsSlice()) { + if peer.ID != peerID && peer.IP == newIP { return status.Errorf(status.InvalidArgument, "IP %s is already assigned to peer %s", newIP.String(), peer.ID) } } @@ -2236,7 +2517,7 @@ func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context, return fmt.Errorf("get peer: %w", err) } - if existingPeer.IP.Equal(newIP.AsSlice()) { + if existingPeer.IP == newIP { return nil } @@ -2271,7 +2552,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti eventMeta := peer.EventMeta(dnsDomain) oldIP := peer.IP.String() - peer.IP = newIP.AsSlice() + peer.IP = newIP err = transaction.SavePeer(ctx, accountID, peer) if err != nil { return fmt.Errorf("save peer: %w", err) @@ -2284,6 +2565,84 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti return nil } +// UpdatePeerIPv6 updates the IPv6 overlay address of a peer, validating it's +// within the account's v6 network range and not 
already taken. +func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update) + if err != nil { + return fmt.Errorf("validate user permissions: %w", err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var updateNetworkMap bool + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var txErr error + updateNetworkMap, txErr = am.updatePeerIPv6InTransaction(ctx, transaction, accountID, peerID, newIPv6) + return txErr + }) + if err != nil { + return err + } + + if updateNetworkMap { + if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peerID}); err != nil { + return fmt.Errorf("notify network map controller: %w", err) + } + } + return nil +} + +// updatePeerIPv6InTransaction validates and applies an IPv6 address change within a store transaction. 
+func (am *DefaultAccountManager) updatePeerIPv6InTransaction(ctx context.Context, transaction store.Store, accountID, peerID string, newIPv6 netip.Addr) (bool, error) { + network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return false, fmt.Errorf("get network: %w", err) + } + + if network.NetV6.IP == nil { + return false, status.Errorf(status.PreconditionFailed, "IPv6 is not configured for this account") + } + + if !network.NetV6.Contains(newIPv6.AsSlice()) { + return false, status.Errorf(status.InvalidArgument, "IP %s is not within the account IPv6 range %s", newIPv6, network.NetV6.String()) + } + + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return false, fmt.Errorf("get settings: %w", err) + } + + allowedPeers, err := am.buildIPv6AllowedPeers(ctx, transaction, accountID, settings) + if err != nil { + return false, err + } + if _, ok := allowedPeers[peerID]; !ok { + return false, status.Errorf(status.PreconditionFailed, "peer is not in any IPv6-enabled group") + } + + peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID) + if err != nil { + return false, fmt.Errorf("get peer: %w", err) + } + + if peer.IPv6.IsValid() && peer.IPv6 == newIPv6 { + return false, nil + } + + if err := am.checkIPv6Collision(ctx, transaction, accountID, peerID, newIPv6); err != nil { + return false, err + } + + peer.IPv6 = newIPv6 + if err := transaction.SavePeer(ctx, accountID, peer); err != nil { + return false, fmt.Errorf("save peer: %w", err) + } + + return true, nil +} + func (am *DefaultAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) { return am.Store.GetUserIDByPeerKey(ctx, store.LockingStrengthNone, peerKey) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 626ed222d..71af0645c 100644 --- a/management/server/account/manager.go +++ 
b/management/server/account/manager.go @@ -65,6 +65,7 @@ type Manager interface { DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error + UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) diff --git a/management/server/account/manager_mock.go b/management/server/account/manager_mock.go index 8f3b22ecc..7ffc41d73 100644 --- a/management/server/account/manager_mock.go +++ b/management/server/account/manager_mock.go @@ -1709,6 +1709,18 @@ func (mr *MockManagerMockRecorder) UpdatePeerIP(ctx, accountID, userID, peerID, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePeerIP", reflect.TypeOf((*MockManager)(nil).UpdatePeerIP), ctx, accountID, userID, peerID, newIP) } +func (m *MockManager) UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdatePeerIPv6", ctx, accountID, userID, peerID, newIPv6) + ret0, _ := ret[0].(error) + return ret0 +} + +func (mr *MockManagerMockRecorder) UpdatePeerIPv6(ctx, accountID, userID, peerID, newIPv6 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePeerIPv6", reflect.TypeOf((*MockManager)(nil).UpdatePeerIPv6), ctx, accountID, userID, peerID, newIPv6) +} + // UpdateToPrimaryAccount mocks base method. 
func (m *MockManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error { m.ctrl.T.Helper() diff --git a/management/server/account_test.go b/management/server/account_test.go index e259856e3..6bb875f99 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -160,7 +160,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { "peer-1": { ID: peerID1, Key: "peer-1-key", - IP: net.IP{100, 64, 0, 1}, + IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), + IPv6: netip.MustParseAddr("fd00::6440:1"), Name: peerID1, DNSLabel: peerID1, Status: &nbpeer.PeerStatus{ @@ -174,7 +175,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { "peer-2": { ID: peerID2, Key: "peer-2-key", - IP: net.IP{100, 64, 0, 1}, + IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), + IPv6: netip.MustParseAddr("fd00::6440:1"), Name: peerID2, DNSLabel: peerID2, Status: &nbpeer.PeerStatus{ @@ -198,7 +200,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { "peer-1": { ID: peerID1, Key: "peer-1-key", - IP: net.IP{100, 64, 0, 1}, + IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), + IPv6: netip.MustParseAddr("fd00::6440:1"), Name: peerID1, DNSLabel: peerID1, Status: &nbpeer.PeerStatus{ @@ -213,7 +216,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { "peer-2": { ID: peerID2, Key: "peer-2-key", - IP: net.IP{100, 64, 0, 1}, + IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), + IPv6: netip.MustParseAddr("fd00::6440:1"), Name: peerID2, DNSLabel: peerID2, Status: &nbpeer.PeerStatus{ @@ -237,7 +241,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-1": { // ID: peerID1, // Key: "peer-1-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID1, // DNSLabel: peerID1, // Status: &PeerStatus{ @@ -251,7 +255,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-2": { // ID: peerID2, // Key: "peer-2-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID2, // 
DNSLabel: peerID2, // Status: &PeerStatus{ @@ -265,7 +269,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-3": { // ID: peerID3, // Key: "peer-3-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID3, // DNSLabel: peerID3, // Status: &PeerStatus{ @@ -288,7 +292,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-1": { // ID: peerID1, // Key: "peer-1-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID1, // DNSLabel: peerID1, // Status: &PeerStatus{ @@ -302,7 +306,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-2": { // ID: peerID2, // Key: "peer-2-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID2, // DNSLabel: peerID2, // Status: &PeerStatus{ @@ -316,7 +320,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-3": { // ID: peerID3, // Key: "peer-3-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID3, // DNSLabel: peerID3, // Status: &PeerStatus{ @@ -339,7 +343,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-1": { // ID: peerID1, // Key: "peer-1-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID1, // DNSLabel: peerID1, // Status: &PeerStatus{ @@ -353,7 +357,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-2": { // ID: peerID2, // Key: "peer-2-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID2, // DNSLabel: peerID2, // Status: &PeerStatus{ @@ -367,7 +371,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // "peer-3": { // ID: peerID3, // Key: "peer-3-key", - // IP: net.IP{100, 64, 0, 1}, + // IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), // Name: peerID3, // DNSLabel: peerID3, // Status: &PeerStatus{ @@ -1084,7 +1088,7 @@ func TestAccountManager_AddPeer(t 
*testing.T) { t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key) } - if !account.Network.Net.Contains(peer.IP) { + if !account.Network.Net.Contains(peer.IP.AsSlice()) { t.Errorf("expecting just added peer's IP %s to be in a network range %s", peer.IP.String(), account.Network.Net.String()) } @@ -1148,7 +1152,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key) } - if !account.Network.Net.Contains(peer.IP) { + if !account.Network.Net.Contains(peer.IP.AsSlice()) { t.Errorf("expecting just added peer's IP %s to be in a network range %s", peer.IP.String(), account.Network.Net.String()) } @@ -2788,11 +2792,46 @@ func TestAccount_SetJWTGroups(t *testing.T) { account := &types.Account{ Id: "accountID", Peers: map[string]*nbpeer.Peer{ - "peer1": {ID: "peer1", Key: "key1", UserID: "user1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}, - "peer2": {ID: "peer2", Key: "key2", UserID: "user1", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}, - "peer3": {ID: "peer3", Key: "key3", UserID: "user1", IP: net.IP{3, 3, 3, 3}, DNSLabel: "peer3.domain.test"}, - "peer4": {ID: "peer4", Key: "key4", UserID: "user2", IP: net.IP{4, 4, 4, 4}, DNSLabel: "peer4.domain.test"}, - "peer5": {ID: "peer5", Key: "key5", UserID: "user2", IP: net.IP{5, 5, 5, 5}, DNSLabel: "peer5.domain.test"}, + "peer1": { + ID: "peer1", + Key: "key1", + UserID: "user1", + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1"), + DNSLabel: "peer1.domain.test", + }, + "peer2": { + ID: "peer2", + Key: "key2", + UserID: "user1", + IP: netip.AddrFrom4([4]byte{2, 2, 2, 2}), + IPv6: netip.MustParseAddr("fd00::2"), + DNSLabel: "peer2.domain.test", + }, + "peer3": { + ID: "peer3", + Key: "key3", + UserID: "user1", + IP: netip.AddrFrom4([4]byte{3, 3, 3, 3}), + IPv6: netip.MustParseAddr("fd00::3"), + DNSLabel: "peer3.domain.test", + }, + "peer4": { + 
ID: "peer4", + Key: "key4", + UserID: "user2", + IP: netip.AddrFrom4([4]byte{4, 4, 4, 4}), + IPv6: netip.MustParseAddr("fd00::4"), + DNSLabel: "peer4.domain.test", + }, + "peer5": { + ID: "peer5", + Key: "key5", + UserID: "user2", + IP: netip.AddrFrom4([4]byte{5, 5, 5, 5}), + IPv6: netip.MustParseAddr("fd00::5"), + DNSLabel: "peer5.domain.test", + }, }, Groups: map[string]*types.Group{ "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, @@ -3549,16 +3588,32 @@ func TestPropagateUserGroupMemberships(t *testing.T) { account, err := manager.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: initiatorId, Domain: domain}) require.NoError(t, err) - peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, Key: "key1", UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"} + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: account.Id, + Key: "key1", + UserID: initiatorId, + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1"), + DNSLabel: "peer1.domain.test", + } err = manager.Store.AddPeerToAccount(ctx, peer1) require.NoError(t, err) - peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, Key: "key2", UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"} + peer2 := &nbpeer.Peer{ + ID: "peer2", + AccountID: account.Id, + Key: "key2", + UserID: initiatorId, + IP: netip.AddrFrom4([4]byte{2, 2, 2, 2}), + IPv6: netip.MustParseAddr("fd00::2"), + DNSLabel: "peer2.domain.test", + } err = manager.Store.AddPeerToAccount(ctx, peer2) require.NoError(t, err) t.Run("should skip propagation when the user has no groups", func(t *testing.T) { - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.False(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) @@ -3574,7 
+3629,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { user.AutoGroups = append(user.AutoGroups, group1.ID) require.NoError(t, manager.Store.SaveUser(ctx, user)) - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.True(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) @@ -3612,7 +3667,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { }, true) require.NoError(t, err) - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.True(t, groupsUpdated) assert.True(t, groupChangesAffectPeers) @@ -3627,7 +3682,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { }) t.Run("should not update membership or account peers when no changes", func(t *testing.T) { - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.False(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) @@ -3640,7 +3695,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { user.AutoGroups = []string{"group1"} require.NoError(t, manager.Store.SaveUser(ctx, user)) - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.False(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) @@ -3754,11 +3809,10 @@ func TestDefaultAccountManager_UpdatePeerIP(t 
*testing.T) { account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get account") - newIP, err := types.AllocatePeerIP(account.Network.Net, []net.IP{peer1.IP, peer2.IP}) + newIP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), []netip.Addr{peer1.IP, peer2.IP}) require.NoError(t, err, "unable to allocate new IP") - newAddr := netip.MustParseAddr(newIP.String()) - err = manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, newAddr) + err = manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, newIP) require.NoError(t, err, "unable to update peer IP") updatedPeer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, userID) @@ -3916,6 +3970,109 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi } } +func TestDefaultAccountManager_UpdateAccountSettings_IPv6EnabledGroups(t *testing.T) { + manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + ctx := context.Background() + accountID := account.Id + + // New accounts default to All group in IPv6EnabledGroups, so all 3 peers should have IPv6. + settings, err := manager.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + require.NoError(t, err) + require.NotEmpty(t, settings.IPv6EnabledGroups, "new account should have IPv6 enabled for All group") + + peers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + for _, p := range peers { + assert.True(t, p.IPv6.IsValid(), "peer %s should have IPv6 with All group enabled", p.ID) + } + + // Create a group with only peer1 and peer2. 
+ partialGroup := &types.Group{ + ID: "ipv6-partial-group", + AccountID: accountID, + Name: "IPv6Partial", + } + err = manager.Store.CreateGroup(ctx, partialGroup) + require.NoError(t, err) + require.NoError(t, manager.Store.AddPeerToGroup(ctx, accountID, peer1.ID, partialGroup.ID)) + require.NoError(t, manager.Store.AddPeerToGroup(ctx, accountID, peer2.ID, partialGroup.ID)) + + // Switch IPv6EnabledGroups to only the partial group. + updatedSettings, err := manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{ + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + PeerLoginExpirationEnabled: true, + IPv6EnabledGroups: []string{partialGroup.ID}, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + assert.Equal(t, []string{partialGroup.ID}, updatedSettings.IPv6EnabledGroups) + + // peer1 and peer2 should have IPv6; peer3 should not. + peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + peerMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, p := range peers { + peerMap[p.ID] = p + } + assert.True(t, peerMap[peer1.ID].IPv6.IsValid(), "peer1 in partial group should keep IPv6") + assert.True(t, peerMap[peer2.ID].IPv6.IsValid(), "peer2 in partial group should keep IPv6") + assert.False(t, peerMap[peer3.ID].IPv6.IsValid(), "peer3 not in partial group should lose IPv6") + + // Clearing all groups disables IPv6 for everyone. 
+ updatedSettings, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{ + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + PeerLoginExpirationEnabled: true, + IPv6EnabledGroups: []string{}, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + assert.Empty(t, updatedSettings.IPv6EnabledGroups) + + peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + for _, p := range peers { + assert.False(t, p.IPv6.IsValid(), "peer %s should have no IPv6 when groups cleared", p.ID) + } + + // Re-enabling with the partial group should allocate IPv6 only for peer1 and peer2. + _, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{ + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + PeerLoginExpirationEnabled: true, + IPv6EnabledGroups: []string{partialGroup.ID}, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + + peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + peerMap = make(map[string]*nbpeer.Peer, len(peers)) + for _, p := range peers { + peerMap[p.ID] = p + } + assert.True(t, peerMap[peer1.ID].IPv6.IsValid(), "peer1 should get IPv6 back") + assert.True(t, peerMap[peer2.ID].IPv6.IsValid(), "peer2 should get IPv6 back") + assert.False(t, peerMap[peer3.ID].IPv6.IsValid(), "peer3 still excluded") + + // No-op update with the same groups should not cause errors. + _, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{ + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + PeerLoginExpirationEnabled: true, + IPv6EnabledGroups: []string{partialGroup.ID}, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + + // Setting a nonexistent group ID should fail. 
+ _, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{ + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + PeerLoginExpirationEnabled: true, + IPv6EnabledGroups: []string{"nonexistent-group-id"}, + Extra: &types.ExtraSettings{}, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "does not exist") +} + func TestUpdateUserAuthWithSingleMode(t *testing.T) { t.Run("sets defaults and overrides domain from store", func(t *testing.T) { ctrl := gomock.NewController(t) diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index ddc3e00c3..2388115ff 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -231,6 +231,10 @@ const ( DomainDeleted Activity = 119 // DomainValidated indicates that a custom domain was validated DomainValidated Activity = 120 + // AccountIPv6Enabled indicates that a user enabled IPv6 overlay for the account + AccountIPv6Enabled Activity = 121 + // AccountIPv6Disabled indicates that a user disabled IPv6 overlay for the account + AccountIPv6Disabled Activity = 122 AccountDeleted Activity = 99999 ) @@ -347,6 +351,9 @@ var activityMap = map[Activity]Code{ AccountAutoUpdateAlwaysEnabled: {"Account auto-update always enabled", "account.setting.auto.update.always.enable"}, AccountAutoUpdateAlwaysDisabled: {"Account auto-update always disabled", "account.setting.auto.update.always.disable"}, + AccountIPv6Enabled: {"Account IPv6 overlay enabled", "account.setting.ipv6.enable"}, + AccountIPv6Disabled: {"Account IPv6 overlay disabled", "account.setting.ipv6.disable"}, + IdentityProviderCreated: {"Identity provider created", "identityprovider.create"}, IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"}, IdentityProviderDeleted: {"Identity provider deleted", "identityprovider.delete"}, diff --git a/management/server/group.go b/management/server/group.go index e1d05171e..870a441ac 100644 --- 
a/management/server/group.go +++ b/management/server/group.go @@ -174,6 +174,10 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use return err } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{newGroup.ID}); err != nil { + return err + } + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -278,37 +282,17 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us var globalErr error groupIDs := make([]string, 0, len(groups)) for _, newGroup := range groups { - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { - return err - } - - newGroup.AccountID = accountID - - if err = transaction.UpdateGroup(ctx, newGroup); err != nil { - return err - } - - err = transaction.IncrementNetworkSerial(ctx, accountID) - if err != nil { - return err - } - - events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) - eventsToStore = append(eventsToStore, events...) - - groupIDs = append(groupIDs, newGroup.ID) - - return nil - }) + events, err := am.updateSingleGroup(ctx, accountID, userID, newGroup) if err != nil { log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err) if len(groups) == 1 { return err } globalErr = errors.Join(globalErr, err) - // continue updating other groups + continue } + eventsToStore = append(eventsToStore, events...) 
+ groupIDs = append(groupIDs, newGroup.ID) } updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) @@ -327,6 +311,33 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us return globalErr } +func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) ([]func(), error) { + var events []func() + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err + } + + newGroup.AccountID = accountID + + if err := transaction.UpdateGroup(ctx, newGroup); err != nil { + return err + } + + if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{newGroup.ID}); err != nil { + return err + } + + if err := transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return err + } + + events = am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + return nil + }) + return events, err +} + // prepareGroupEvents prepares a list of event functions to be stored. 
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *types.Group) []func() { var eventsToStore []func() @@ -458,6 +469,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us return err } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, groupIDsToDelete); err != nil { + return err + } + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -486,6 +501,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil { + return err + } + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -552,6 +571,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil { + return err + } + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { diff --git a/management/server/group_ipv6_test.go b/management/server/group_ipv6_test.go new file mode 100644 index 000000000..e4603c879 --- /dev/null +++ b/management/server/group_ipv6_test.go @@ -0,0 +1,125 @@ +package server + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +// TestGroupIPv6Assignment verifies that peers gain or lose IPv6 addresses +// when they are added to or removed from an IPv6-enabled group. 
+func TestGroupIPv6Assignment(t *testing.T) { + am, _, err := createManager(t) + require.NoError(t, err) + + ctx := context.Background() + userID := groupAdminUserID + + account, err := createAccount(am, "ipv6-grp-test", userID, "ipv6test.example.com") + require.NoError(t, err) + + // Allocate IPv6 subnet for the account + account.Network.NetV6 = types.AllocateIPv6Subnet(rand.New(rand.NewSource(time.Now().UnixNano()))) + require.NoError(t, am.Store.SaveAccount(ctx, account)) + + // Create setup key + setupKey, err := am.CreateSetupKey(ctx, account.Id, "ipv6-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false) + require.NoError(t, err) + + // Create an IPv6-enabled group + ipv6GroupID := "ipv6-enabled-grp" + err = am.CreateGroup(ctx, account.Id, userID, &types.Group{ + ID: ipv6GroupID, + Name: "IPv6 Enabled", + Issued: types.GroupIssuedAPI, + Peers: []string{}, + }) + require.NoError(t, err) + + // Enable IPv6 on that group + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, account.Id) + require.NoError(t, err) + settings.IPv6EnabledGroups = []string{ipv6GroupID} + require.NoError(t, am.Store.SaveAccountSettings(ctx, account.Id, settings)) + + // Register a peer (will be in "All" group, not the IPv6 group) + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + peer, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{ + Key: key.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "ipv6-test-host"}, + }, false) + require.NoError(t, err) + assert.False(t, peer.IPv6.IsValid(), "peer should not have IPv6 before joining an IPv6-enabled group") + + t.Run("GroupAddPeer assigns IPv6", func(t *testing.T) { + err := am.GroupAddPeer(ctx, account.Id, ipv6GroupID, peer.ID) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.True(t, p.IPv6.IsValid(), "peer should have an IPv6 address after joining 
the group") + }) + + t.Run("GroupDeletePeer clears IPv6", func(t *testing.T) { + err := am.GroupDeletePeer(ctx, account.Id, ipv6GroupID, peer.ID) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.False(t, p.IPv6.IsValid(), "peer should not have IPv6 after removal from the group") + }) + + t.Run("UpdateGroup with peer addition assigns IPv6", func(t *testing.T) { + grp, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, ipv6GroupID) + require.NoError(t, err) + + grp.Peers = append(grp.Peers, peer.ID) + err = am.UpdateGroup(ctx, account.Id, userID, grp) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.True(t, p.IPv6.IsValid(), "peer should have IPv6 after UpdateGroup adds it") + }) + + t.Run("UpdateGroup with peer removal clears IPv6", func(t *testing.T) { + grp, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, ipv6GroupID) + require.NoError(t, err) + + grp.Peers = []string{} + err = am.UpdateGroup(ctx, account.Id, userID, grp) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.False(t, p.IPv6.IsValid(), "peer should lose IPv6 after UpdateGroup removes it") + }) + + t.Run("non-IPv6 group changes do not affect IPv6", func(t *testing.T) { + err := am.CreateGroup(ctx, account.Id, userID, &types.Group{ + ID: "regular-grp", + Name: "Regular Group", + Issued: types.GroupIssuedAPI, + Peers: []string{}, + }) + require.NoError(t, err) + + err = am.GroupAddPeer(ctx, account.Id, "regular-grp", peer.ID) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.False(t, p.IPv6.IsValid(), "peer should not get IPv6 from a non-IPv6 group") + }) +} diff --git 
a/management/server/group_test.go b/management/server/group_test.go index 5821b90a3..22fda2671 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "errors" "fmt" - "net" "net/netip" "strconv" "sync" @@ -999,10 +998,10 @@ func Test_AddPeerAndAddToAll(t *testing.T) { assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers)) } -func uint32ToIP(n uint32) net.IP { - ip := make(net.IP, 4) - binary.BigEndian.PutUint32(ip, n) - return ip +func uint32ToIP(n uint32) netip.Addr { + var b [4]byte + binary.BigEndian.PutUint32(b[:], n) + return netip.AddrFrom4(b) } func Test_IncrementNetworkSerial(t *testing.T) { diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index cc5567e3d..31820b9fb 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -4,10 +4,13 @@ import ( "context" "encoding/json" "fmt" + "math" "net/http" "net/netip" "time" + log "github.com/sirupsen/logrus" + "github.com/gorilla/mux" goversion "github.com/hashicorp/go-version" @@ -29,7 +32,9 @@ const ( // MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16) MinNetworkBitsIPv4 = 28 // MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges - MinNetworkBitsIPv6 = 120 + MinNetworkBitsIPv6 = 120 + // MaxNetworkSizeIPv6 is the smallest allowed IPv6 prefix length, i.e. the largest permitted network size (/48) + MaxNetworkSizeIPv6 = 48 disableAutoUpdate = "disabled" autoUpdateLatestVersion = "latest" ) @@ -76,12 +81,35 @@ func validateMinimumSize(prefix netip.Prefix) error { if addr.Is4() && prefix.Bits() > MinNetworkBitsIPv4 { return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv4", MinNetworkBitsIPv4) } - if addr.Is6() 
&& prefix.Bits() > MinNetworkBitsIPv6 { - return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv6", MinNetworkBitsIPv6) + if addr.Is6() { + if prefix.Bits() > MinNetworkBitsIPv6 { + return status.Errorf(status.InvalidArgument, "network range too small: minimum size is /%d for IPv6", MinNetworkBitsIPv6) + } + if prefix.Bits() < MaxNetworkSizeIPv6 { + return status.Errorf(status.InvalidArgument, "network range too large: maximum size is /%d for IPv6", MaxNetworkSizeIPv6) + } } return nil } +func (h *handler) parseAndValidateNetworkRange(ctx context.Context, accountID, userID, rangeStr string, requireV6 bool) (netip.Prefix, error) { + prefix, err := netip.ParsePrefix(rangeStr) + if err != nil { + return netip.Prefix{}, status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err) + } + prefix = prefix.Masked() + if requireV6 && !prefix.Addr().Is6() { + return netip.Prefix{}, status.Errorf(status.InvalidArgument, "network range must be an IPv6 address") + } + if !requireV6 && prefix.Addr().Is6() { + return netip.Prefix{}, status.Errorf(status.InvalidArgument, "network range must be an IPv4 address") + } + if err := h.validateNetworkRange(ctx, accountID, userID, prefix); err != nil { + return netip.Prefix{}, err + } + return prefix, nil +} + func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID string, networkRange netip.Prefix) error { if !networkRange.IsValid() { return nil @@ -117,9 +145,12 @@ func (h *handler) validateCapacity(ctx context.Context, accountID, userID string } func calculateMaxHosts(prefix netip.Prefix) int64 { - availableAddresses := prefix.Addr().BitLen() - prefix.Bits() - maxHosts := int64(1) << availableAddresses + hostBits := prefix.Addr().BitLen() - prefix.Bits() + if hostBits >= 63 { + return math.MaxInt64 + } + maxHosts := int64(1) << hostBits if prefix.Addr().Is4() { maxHosts -= 2 // network and broadcast addresses } @@ -164,6 +195,24 @@ func (h *handler) 
getAllAccounts(w http.ResponseWriter, r *http.Request) { } resp := toAccountResponse(accountID, settings, meta, onboarding) + + // Populate effective network ranges when settings don't have explicit overrides. + if resp.Settings.NetworkRange == nil || resp.Settings.NetworkRangeV6 == nil { + v4, v6, err := h.settingsManager.GetEffectiveNetworkRanges(r.Context(), accountID) + if err != nil { + log.WithContext(r.Context()).Warnf("get effective network ranges: %v", err) + } else { + if resp.Settings.NetworkRange == nil && v4.IsValid() { + s := v4.String() + resp.Settings.NetworkRange = &s + } + if resp.Settings.NetworkRangeV6 == nil && v6.IsValid() { + s := v6.String() + resp.Settings.NetworkRangeV6 = &s + } + } + } + util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } @@ -228,6 +277,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS if req.Settings.AutoUpdateAlways != nil { returnSettings.AutoUpdateAlways = *req.Settings.AutoUpdateAlways } + if req.Settings.Ipv6EnabledGroups != nil { + returnSettings.IPv6EnabledGroups = *req.Settings.Ipv6EnabledGroups + } return returnSettings, nil } @@ -262,18 +314,23 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" { - prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange) + prefix, err := h.parseAndValidateNetworkRange(r.Context(), accountID, userID, *req.Settings.NetworkRange, false) if err != nil { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w) - return - } - if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil { util.WriteError(r.Context(), err, w) return } settings.NetworkRange = prefix } + if req.Settings.NetworkRangeV6 != nil && *req.Settings.NetworkRangeV6 != "" { + prefix, err := h.parseAndValidateNetworkRange(r.Context(), accountID, userID, *req.Settings.NetworkRangeV6, true) + 
if err != nil { + util.WriteError(r.Context(), err, w) + return + } + settings.NetworkRangeV6 = prefix + } + var onboarding *types.AccountOnboarding if req.Onboarding != nil { onboarding = &types.AccountOnboarding{ @@ -352,6 +409,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A DnsDomain: &settings.DNSDomain, AutoUpdateVersion: &settings.AutoUpdateVersion, AutoUpdateAlways: &settings.AutoUpdateAlways, + Ipv6EnabledGroups: &settings.IPv6EnabledGroups, EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled, LocalAuthDisabled: &settings.LocalAuthDisabled, } @@ -360,6 +418,10 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A networkRangeStr := settings.NetworkRange.String() apiSettings.NetworkRange = &networkRangeStr } + if settings.NetworkRangeV6.IsValid() { + networkRangeV6Str := settings.NetworkRangeV6.String() + apiSettings.NetworkRangeV6 = &networkRangeV6Str + } apiOnboarding := api.AccountOnboarding{ OnboardingFlowPending: onboarding.OnboardingFlowPending, diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 739dfe2f6..fc1517a30 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -5,8 +5,10 @@ import ( "context" "encoding/json" "io" + "math" "net/http" "net/http/httptest" + "net/netip" "testing" "time" @@ -31,6 +33,10 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { GetSettings(gomock.Any(), account.Id, "test_user"). Return(account.Settings, nil). AnyTimes() + settingsMockManager.EXPECT(). + GetEffectiveNetworkRanges(gomock.Any(), account.Id). + Return(netip.Prefix{}, netip.Prefix{}, nil). 
+ AnyTimes() return &handler{ accountManager: &mock_server.MockAccountManager{ @@ -336,3 +342,27 @@ func TestAccounts_AccountsHandler(t *testing.T) { }) } } + +func TestCalculateMaxHosts(t *testing.T) { + tests := []struct { + name string + prefix string + min int64 + }{ + {"v4 /24", "100.64.0.0/24", 254}, + {"v4 /16", "100.64.0.0/16", 65534}, + {"v4 /28", "100.64.0.0/28", 14}, + {"v6 /64", "fd00::/64", math.MaxInt64}, + {"v6 /120", "fd00::/120", 256}, + {"v6 /112", "fd00::/112", 65536}, + {"v6 /48", "fd00::/48", math.MaxInt64}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prefix := netip.MustParsePrefix(tt.prefix) + got := calculateMaxHosts(prefix) + assert.Equal(t, tt.min, got) + }) + } +} diff --git a/management/server/http/handlers/dns/nameservers_handler.go b/management/server/http/handlers/dns/nameservers_handler.go index bce1c4b78..dbbdf3ed9 100644 --- a/management/server/http/handlers/dns/nameservers_handler.go +++ b/management/server/http/handlers/dns/nameservers_handler.go @@ -3,7 +3,10 @@ package dns import ( "encoding/json" "fmt" + "net" "net/http" + "strconv" + "strings" "github.com/gorilla/mux" log "github.com/sirupsen/logrus" @@ -201,7 +204,11 @@ func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.R func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) { var nsList []nbdns.NameServer for _, apiNS := range apiNSList { - parsed, err := nbdns.ParseNameServerURL(fmt.Sprintf("%s://%s:%d", apiNS.NsType, apiNS.Ip, apiNS.Port)) + host, err := unwrapBracketedHost(apiNS.Ip) + if err != nil { + return nil, err + } + parsed, err := nbdns.ParseNameServerURL(fmt.Sprintf("%s://%s", apiNS.NsType, net.JoinHostPort(host, strconv.Itoa(apiNS.Port)))) if err != nil { return nil, err } @@ -211,6 +218,18 @@ func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) { return nsList, nil } +// unwrapBracketedHost returns ip with surrounding brackets stripped, rejecting +// inputs 
with mismatched brackets. +func unwrapBracketedHost(ip string) (string, error) { + if !strings.ContainsAny(ip, "[]") { + return ip, nil + } + if !strings.HasPrefix(ip, "[") || !strings.HasSuffix(ip, "]") { + return "", fmt.Errorf("malformed bracketed address: %s", ip) + } + return ip[1 : len(ip)-1], nil +} + func toNameserverGroupResponse(serverNSGroup *nbdns.NameServerGroup) *api.NameserverGroup { var nsList []api.Nameserver for _, ns := range serverNSGroup.NameServers { diff --git a/management/server/http/handlers/dns/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go index 4716782f3..a165f009b 100644 --- a/management/server/http/handlers/dns/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -233,3 +233,37 @@ func TestNameserversHandlers(t *testing.T) { }) } } + +func TestToServerNSList_IPv6(t *testing.T) { + tests := []struct { + name string + input []api.Nameserver + expectIP netip.Addr + }{ + { + name: "IPv4", + input: []api.Nameserver{ + {Ip: "1.1.1.1", NsType: "udp", Port: 53}, + }, + expectIP: netip.MustParseAddr("1.1.1.1"), + }, + { + name: "IPv6", + input: []api.Nameserver{ + {Ip: "2001:4860:4860::8888", NsType: "udp", Port: 53}, + }, + expectIP: netip.MustParseAddr("2001:4860:4860::8888"), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := toServerNSList(tc.input) + assert.NoError(t, err) + if assert.Len(t, result, 1) { + assert.Equal(t, tc.expectIP, result[0].IP) + assert.Equal(t, 53, result[0].Port) + } + }) + } +} diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index c7b4cbcdd..57e238630 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -7,8 +7,8 @@ import ( "errors" "fmt" "io" - "net" "net/http" "net/http/httptest" + "net/netip" 
"strings" "testing" @@ -29,8 +29,8 @@ import ( ) var TestPeers = map[string]*nbpeer.Peer{ - "A": {Key: "A", ID: "peer-A-ID", IP: net.ParseIP("100.100.100.100")}, - "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, + "A": {Key: "A", ID: "peer-A-ID", IP: netip.MustParseAddr("100.100.100.100")}, + "B": {Key: "B", ID: "peer-B-ID", IP: netip.MustParseAddr("200.200.200.200")}, } func initGroupTestData(initGroups ...*types.Group) *handler { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index bf6937a49..91026a374 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -220,6 +220,18 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri } } + if req.Ipv6 != nil { + v6Addr, err := parseIPv6(req.Ipv6) + if err != nil { + util.WriteError(ctx, status.Errorf(status.InvalidArgument, "%v", err), w) + return + } + if err = h.accountManager.UpdatePeerIPv6(ctx, accountID, userID, peerID, v6Addr); err != nil { + util.WriteError(ctx, err, w) + return + } + } + peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update) if err != nil { util.WriteError(ctx, err, w) @@ -355,6 +367,21 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersM } } +func parseIPv6(s *string) (netip.Addr, error) { + if s == nil { + return netip.Addr{}, fmt.Errorf("IPv6 address is nil") + } + addr, err := netip.ParseAddr(*s) + if err != nil { + return netip.Addr{}, fmt.Errorf("invalid IPv6 address %s: %w", *s, err) + } + addr = addr.Unmap() + if !addr.Is6() { + return netip.Addr{}, fmt.Errorf("address %s is not IPv6", *s) + } + return addr, nil +} + // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. 
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) @@ -529,6 +556,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee GeonameId: int(peer.Location.GeoNameID), Id: peer.ID, Ip: peer.IP.String(), + Ipv6: peerIPv6String(peer), LastSeen: peer.Status.LastSeen, Name: peer.Name, Os: peer.Meta.OS, @@ -547,6 +575,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD Id: peer.ID, Name: peer.Name, Ip: peer.IP.String(), + Ipv6: peerIPv6String(peer), ConnectionIp: peer.Location.ConnectionIP.String(), Connected: peer.Status.Connected, LastSeen: peer.Status.LastSeen, @@ -601,6 +630,7 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn Id: peer.ID, Name: peer.Name, Ip: peer.IP.String(), + Ipv6: peerIPv6String(peer), ConnectionIp: peer.Location.ConnectionIP.String(), Connected: peer.Status.Connected, LastSeen: peer.Status.LastSeen, @@ -677,3 +707,11 @@ func fqdnList(extraLabels []string, dnsDomain string) []string { } return fqdnList } + +func peerIPv6String(peer *nbpeer.Peer) *string { + if !peer.IPv6.IsValid() { + return nil + } + s := peer.IPv6.String() + return &s +} diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 6b3616597..9db095c8d 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -146,7 +146,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { UpdatePeerIPFunc: func(_ context.Context, accountID, userID, peerID string, newIP netip.Addr) error { for _, peer := range peers { if peer.ID == peerID { - peer.IP = net.IP(newIP.AsSlice()) + peer.IP = newIP return nil } } @@ -228,7 +228,8 @@ func TestGetPeers(t *testing.T) { peer := &nbpeer.Peer{ ID: testPeerID, Key: "key", - IP: 
net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), Status: &nbpeer.PeerStatus{Connected: true}, Name: "PeerName", LoginExpirationEnabled: false, @@ -368,7 +369,8 @@ func TestGetAccessiblePeers(t *testing.T) { peer1 := &nbpeer.Peer{ ID: "peer1", Key: "key1", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00:1234::1"), Status: &nbpeer.PeerStatus{Connected: true}, Name: "peer1", LoginExpirationEnabled: false, @@ -378,7 +380,8 @@ func TestGetAccessiblePeers(t *testing.T) { peer2 := &nbpeer.Peer{ ID: "peer2", Key: "key2", - IP: net.ParseIP("100.64.0.2"), + IP: netip.MustParseAddr("100.64.0.2"), + IPv6: netip.MustParseAddr("fd00:1234::2"), Status: &nbpeer.PeerStatus{Connected: true}, Name: "peer2", LoginExpirationEnabled: false, @@ -388,7 +391,8 @@ func TestGetAccessiblePeers(t *testing.T) { peer3 := &nbpeer.Peer{ ID: "peer3", Key: "key3", - IP: net.ParseIP("100.64.0.3"), + IP: netip.MustParseAddr("100.64.0.3"), + IPv6: netip.MustParseAddr("fd00:1234::3"), Status: &nbpeer.PeerStatus{Connected: true}, Name: "peer3", LoginExpirationEnabled: false, @@ -532,7 +536,8 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) { testPeer := &nbpeer.Peer{ ID: testPeerID, Key: "key", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, Name: "test-host@netbird.io", LoginExpirationEnabled: false, diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go index b7a63b104..9a78620c9 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -5,9 +5,9 @@ import ( "context" "fmt" "io" - "net" "net/http" "net/http/httptest" + "net/netip" "os" "strconv" "testing" @@ -133,7 +133,7 @@ func PopulateTestData(b *testing.B, am 
account.Manager, peers, groups, users, se ID: fmt.Sprintf("oldpeer-%d", i), DNSLabel: fmt.Sprintf("oldpeer-%d", i), Key: peerKey.PublicKey().String(), - IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), + IP: netip.MustParseAddr(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: TestUserId, } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index ac4d0c6d6..08091d4b7 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -63,6 +63,7 @@ type MockAccountManager struct { UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIPFunc func(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error + UpdatePeerIPv6Func func(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, isSelected bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error @@ -539,6 +540,13 @@ func (am *MockAccountManager) UpdatePeerIP(ctx context.Context, accountID, userI return status.Errorf(codes.Unimplemented, "method UpdatePeerIP is not implemented") } +func (am *MockAccountManager) UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error { + if am.UpdatePeerIPv6Func 
!= nil { + return am.UpdatePeerIPv6Func(ctx, accountID, userID, peerID, newIPv6) + } + return status.Errorf(codes.Unimplemented, "method UpdatePeerIPv6 is not implemented") +} + // CreateRoute mock implementation of CreateRoute from server.AccountManager interface func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool, isSelected bool) (*route.Route, error) { if am.CreateRouteFunc != nil { diff --git a/management/server/peer.go b/management/server/peer.go index 25c6ecd8c..8a39fbbb8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -6,6 +6,7 @@ import ( b64 "encoding/base64" "fmt" "net" + "net/netip" "slices" "strings" "time" @@ -521,6 +522,27 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri return account.Network.Copy(), err } +// peerWillHaveIPv6 checks whether the peer's future group memberships +// (auto-groups + allGroupID) overlap with IPv6EnabledGroups. 
+func peerWillHaveIPv6(settings *types.Settings, groupsToAdd []string, allGroupID string) bool { + enabledSet := make(map[string]struct{}, len(settings.IPv6EnabledGroups)) + for _, gid := range settings.IPv6EnabledGroups { + enabledSet[gid] = struct{}{} + } + + if allGroupID != "" { + if _, ok := enabledSet[allGroupID]; ok { + return true + } + } + for _, gid := range groupsToAdd { + if _, ok := enabledSet[gid]; ok { + return true + } + } + return false +} + type peerAddAuthConfig struct { AccountID string SetupKeyID string @@ -715,8 +737,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe maxAttempts := 10 for attempt := 1; attempt <= maxAttempts; attempt++ { - var freeIP net.IP - freeIP, err = types.AllocateRandomPeerIP(network.Net) + netPrefix, err := netip.ParsePrefix(network.Net.String()) + if err != nil { + return nil, nil, nil, fmt.Errorf("parse network prefix: %w", err) + } + freeIP, err := types.AllocateRandomPeerIP(netPrefix) if err != nil { return nil, nil, nil, fmt.Errorf("failed to get free IP: %w", err) } @@ -736,6 +761,29 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe newPeer.DNSLabel = freeLabel newPeer.IP = freeIP + if len(settings.IPv6EnabledGroups) > 0 && network.NetV6.IP != nil { + var allGroupID string + if !peer.ProxyMeta.Embedded { + allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, "All") + if err != nil { + log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err) + } else { + allGroupID = allGroup.ID + } + } + if peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID) { + v6Prefix, err := netip.ParsePrefix(network.NetV6.String()) + if err != nil { + return nil, nil, nil, fmt.Errorf("parse IPv6 prefix: %w", err) + } + freeIPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) + if err != nil { + return nil, nil, nil, fmt.Errorf("allocate peer IPv6: %w", err) + } + newPeer.IPv6 = freeIPv6 + } + } + err = 
am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = transaction.AddPeerToAccount(ctx, newPeer) if err != nil { @@ -805,10 +853,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err) } - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err) - } - if newPeer == nil { return nil, nil, nil, fmt.Errorf("new peer is nil") } @@ -834,21 +878,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return p, nmap, pc, err } -func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { - ip = ip.To4() +func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) { + if !ip.Is4() { + return "", fmt.Errorf("DNS label generation requires an IPv4 address, got %s", ip) + } + b := ip.As4() dnsName, err := nbdns.GetParsedDomainLabel(peerHostName) if err != nil { return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err) } - return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil + return fmt.Sprintf("%s-%d-%d", dnsName, b[2], b[3]), nil } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { var peer *nbpeer.Peer - var updated, versionChanged bool + var updated, versionChanged, ipv6CapabilityChanged bool var err error var postureChecks []*posture.Checks var peerGroupIDs []string @@ -884,7 +931,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return err } + oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay) updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta) + ipv6CapabilityChanged = oldHasIPv6Cap != 
peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay) if updated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) @@ -908,7 +957,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return nil, nil, nil, 0, err } - if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { + if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) { err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) if err != nil { return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err) @@ -958,6 +1007,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer var peer *nbpeer.Peer var updateRemotePeers bool var isPeerUpdated bool + var ipv6CapabilityChanged bool var postureChecks []*posture.Checks var peerGroupIDs []string @@ -997,7 +1047,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return err } + oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay) isPeerUpdated, _ = peer.UpdateMetaIfNew(login.Meta) + ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay) if isPeerUpdated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() shouldStorePeer = true @@ -1035,7 +1087,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } - if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { + if updateRemotePeers || isStatusChanged || ipv6CapabilityChanged || (isPeerUpdated && len(postureChecks) > 0) { err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) if err != nil { return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err) diff --git 
a/management/server/peer/peer.go b/management/server/peer/peer.go index db392ddda..17df761a1 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -11,6 +11,12 @@ import ( "github.com/netbirdio/netbird/shared/management/http/api" ) +// Peer capability constants mirror the proto enum values. +const ( + PeerCapabilitySourcePrefixes int32 = 1 + PeerCapabilityIPv6Overlay int32 = 2 +) + // Peer represents a machine connected to the network. // The Peer is a WireGuard peer identified by a public key type Peer struct { @@ -21,7 +27,9 @@ type Peer struct { // WireGuard public key Key string // uniqueness index (check migrations) // IP address of the Peer - IP net.IP `gorm:"serializer:json"` // uniqueness index per accountID (check migrations) + IP netip.Addr `gorm:"serializer:json"` // uniqueness index per accountID (check migrations) + // IPv6 overlay address of the Peer, zero value if IPv6 is not enabled for the account. + IPv6 netip.Addr `gorm:"serializer:json"` // Meta is a Peer system meta data Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` // ProxyMeta is metadata related to proxy peers @@ -115,6 +123,7 @@ type Flags struct { DisableFirewall bool BlockLANAccess bool BlockInbound bool + DisableIPv6 bool LazyConnectionEnabled bool } @@ -138,6 +147,7 @@ type PeerSystemMeta struct { //nolint:revive Environment Environment `gorm:"serializer:json"` Flags Flags `gorm:"serializer:json"` Files []File `gorm:"serializer:json"` + Capabilities []int32 `gorm:"serializer:json"` } func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { @@ -182,7 +192,8 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { p.SystemManufacturer == other.SystemManufacturer && p.Environment.Cloud == other.Environment.Cloud && p.Environment.Platform == other.Environment.Platform && - p.Flags.isEqual(other.Flags) + p.Flags.isEqual(other.Flags) && + capabilitiesEqual(p.Capabilities, other.Capabilities) } func (p PeerSystemMeta) isEmpty() bool { @@ 
-210,6 +221,37 @@ func (p *Peer) AddedWithSSOLogin() bool { return p.UserID != "" } +// HasCapability reports whether the peer has the given capability. +func (p *Peer) HasCapability(capability int32) bool { + return slices.Contains(p.Meta.Capabilities, capability) +} + +// SupportsIPv6 reports whether the peer supports IPv6 overlay. +func (p *Peer) SupportsIPv6() bool { + return !p.Meta.Flags.DisableIPv6 && p.HasCapability(PeerCapabilityIPv6Overlay) +} + +// SupportsSourcePrefixes reports whether the peer reads SourcePrefixes. +func (p *Peer) SupportsSourcePrefixes() bool { + return p.HasCapability(PeerCapabilitySourcePrefixes) +} + +func capabilitiesEqual(a, b []int32) bool { + if len(a) != len(b) { + return false + } + set := make(map[int32]struct{}, len(a)) + for _, c := range a { + set[c] = struct{}{} + } + for _, c := range b { + if _, ok := set[c]; !ok { + return false + } + } + return true +} + // Copy copies Peer object func (p *Peer) Copy() *Peer { peerStatus := p.Status @@ -221,6 +263,7 @@ func (p *Peer) Copy() *Peer { AccountID: p.AccountID, Key: p.Key, IP: p.IP, + IPv6: p.IPv6, Meta: p.Meta, Name: p.Name, DNSLabel: p.DNSLabel, @@ -323,9 +366,13 @@ func (p *Peer) FQDN(dnsDomain string) string { // EventMeta returns activity event meta related to the peer func (p *Peer) EventMeta(dnsDomain string) map[string]any { - return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP, "created_at": p.CreatedAt, + meta := map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP, "created_at": p.CreatedAt, "location_city_name": p.Location.CityName, "location_country_code": p.Location.CountryCode, "location_geo_name_id": p.Location.GeoNameID, "location_connection_ip": p.Location.ConnectionIP} + if p.IPv6.IsValid() { + meta["ipv6"] = p.IPv6.String() + } + return meta } // Copy PeerStatus @@ -369,5 +416,6 @@ func (f Flags) isEqual(other Flags) bool { f.DisableFirewall == other.DisableFirewall && f.BlockLANAccess == other.BlockLANAccess && 
f.BlockInbound == other.BlockInbound && - f.LazyConnectionEnabled == other.LazyConnectionEnabled + f.LazyConnectionEnabled == other.LazyConnectionEnabled && + f.DisableIPv6 == other.DisableIPv6 } diff --git a/management/server/peer/peer_test.go b/management/server/peer/peer_test.go index 1aa3f6ffc..c5b512069 100644 --- a/management/server/peer/peer_test.go +++ b/management/server/peer/peer_test.go @@ -5,6 +5,7 @@ import ( "net/netip" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -141,3 +142,25 @@ func TestFlags_IsEqual(t *testing.T) { }) } } + +func TestPeerCapabilities(t *testing.T) { + tests := []struct { + name string + capabilities []int32 + ipv6 bool + srcPrefixes bool + }{ + {"no capabilities", nil, false, false}, + {"only source prefixes", []int32{PeerCapabilitySourcePrefixes}, false, true}, + {"only ipv6", []int32{PeerCapabilityIPv6Overlay}, true, false}, + {"both", []int32{PeerCapabilitySourcePrefixes, PeerCapabilityIPv6Overlay}, true, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Peer{Meta: PeerSystemMeta{Capabilities: tt.capabilities}} + assert.Equal(t, tt.ipv6, p.SupportsIPv6()) + assert.Equal(t, tt.srcPrefixes, p.SupportsSourcePrefixes()) + }) + } +} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 36809d354..07acf865f 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -754,7 +754,8 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou ID: fmt.Sprintf("peer-%d", i), DNSLabel: fmt.Sprintf("peer-%d", i), Key: peerKey.PublicKey().String(), - IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), + IP: netip.MustParseAddr(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), + IPv6: netip.MustParseAddr(fmt.Sprintf("fd00::%d", i+1)), Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: regularUser, } @@ -783,7 +784,15 @@ func setupTestAccountManager(b 
testing.TB, peers int, groups int) (*DefaultAccou account.Networks = append(account.Networks, network) ips := account.GetTakenIPs() - peerIP, err := types.AllocatePeerIP(account.Network.Net, ips) + peerIP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), ips) + if err != nil { + return nil, nil, "", "", err + } + v6Prefix, err := netip.ParsePrefix(account.Network.NetV6.String()) + if err != nil { + return nil, nil, "", "", err + } + peerIPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) if err != nil { return nil, nil, "", "", err } @@ -794,6 +803,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou DNSLabel: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), Key: peerKey.PublicKey().String(), IP: peerIP, + IPv6: peerIPv6, Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: regularUser, Meta: nbpeer.PeerSystemMeta{ @@ -1068,7 +1078,8 @@ func TestToSyncResponse(t *testing.T) { }, } peer := &nbpeer.Peer{ - IP: net.ParseIP("192.168.1.1"), + IP: netip.MustParseAddr("192.168.1.1"), + IPv6: netip.MustParseAddr("fd00::1"), SSHEnabled: true, Key: "peer-key", DNSLabel: "peer1", @@ -1079,9 +1090,21 @@ func TestToSyncResponse(t *testing.T) { Signature: "turn-pass", } networkMap := &types.NetworkMap{ - Network: &types.Network{Net: *ipnet, Serial: 1000}, - Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, - OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, + Network: &types.Network{Net: *ipnet, Serial: 1000}, + Peers: []*nbpeer.Peer{{ + IP: netip.MustParseAddr("192.168.1.2"), + IPv6: netip.MustParseAddr("fd00::2"), + Key: "peer2-key", + DNSLabel: "peer2", + SSHEnabled: true, + SSHKey: "peer2-ssh-key"}}, + OfflinePeers: []*nbpeer.Peer{{ + IP: netip.MustParseAddr("192.168.1.3"), + IPv6: 
netip.MustParseAddr("fd00::3"), + Key: "peer3-key", + DNSLabel: "peer3", + SSHEnabled: true, + SSHKey: "peer3-ssh-key"}}, Routes: []*nbroute.Route{ { ID: "route1", @@ -1228,6 +1251,7 @@ func TestToSyncResponse(t *testing.T) { assert.Equal(t, int64(53), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetPort()) // assert network map Firewall assert.Equal(t, 1, len(response.NetworkMap.FirewallRules)) + //nolint:staticcheck // testing backward-compatible field assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP) assert.Equal(t, proto.RuleDirection_IN, response.NetworkMap.FirewallRules[0].Direction) assert.Equal(t, proto.RuleAction_ACCEPT, response.NetworkMap.FirewallRules[0].Action) @@ -1290,7 +1314,8 @@ func Test_RegisterPeerByUser(t *testing.T) { ID: xid.New().String(), AccountID: existingAccountID, Key: "newPeerKey", - IP: net.IP{123, 123, 123, 123}, + IP: netip.AddrFrom4([4]byte{123, 123, 123, 123}), + IPv6: netip.MustParseAddr("fd00::7b:7b:7b:7b"), Meta: nbpeer.PeerSystemMeta{ Hostname: "newPeer", GoOS: "linux", @@ -1378,7 +1403,8 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { newPeerTemplate := &nbpeer.Peer{ AccountID: existingAccountID, UserID: "", - IP: net.IP{123, 123, 123, 123}, + IP: netip.AddrFrom4([4]byte{123, 123, 123, 123}), + IPv6: netip.MustParseAddr("fd00::7b:7b:7b:7b"), Meta: nbpeer.PeerSystemMeta{ Hostname: "newPeer", GoOS: "linux", @@ -1539,7 +1565,8 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { AccountID: existingAccountID, Key: "newPeerKey", UserID: "", - IP: net.IP{123, 123, 123, 123}, + IP: netip.AddrFrom4([4]byte{123, 123, 123, 123}), + IPv6: netip.MustParseAddr("fd00::7b:7b:7b:7b"), Meta: nbpeer.PeerSystemMeta{ Hostname: "newPeer", GoOS: "linux", @@ -1624,7 +1651,8 @@ func Test_LoginPeer(t *testing.T) { newPeerTemplate := &nbpeer.Peer{ AccountID: existingAccountID, UserID: "", - IP: net.IP{123, 123, 123, 123}, + IP: netip.AddrFrom4([4]byte{123, 123, 123, 123}), + IPv6: 
netip.MustParseAddr("fd00::7b:7b:7b:7b"), Meta: nbpeer.PeerSystemMeta{ Hostname: "newPeer", GoOS: "linux", @@ -2126,14 +2154,16 @@ func Test_DeletePeer(t *testing.T) { ID: "peer1", AccountID: accountID, Key: "key1", - IP: net.IP{1, 1, 1, 1}, + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1"), DNSLabel: "peer1.test", }, "peer2": { ID: "peer2", AccountID: accountID, Key: "key2", - IP: net.IP{2, 2, 2, 2}, + IP: netip.AddrFrom4([4]byte{2, 2, 2, 2}), + IPv6: netip.MustParseAddr("fd00::2"), DNSLabel: "peer2.test", }, } @@ -2730,6 +2760,67 @@ func TestProcessPeerAddAuth(t *testing.T) { }) } +func TestPeerWillHaveIPv6(t *testing.T) { + settings := &types.Settings{ + IPv6EnabledGroups: []string{"all-group-id", "group-a"}, + } + + assert.True(t, peerWillHaveIPv6(settings, nil, "all-group-id"), "peer in All group should get IPv6") + assert.True(t, peerWillHaveIPv6(settings, []string{"group-a"}, ""), "peer with matching auto-group should get IPv6") + assert.False(t, peerWillHaveIPv6(settings, []string{"group-b"}, "other-all"), "peer with no matching groups should not get IPv6") + assert.False(t, peerWillHaveIPv6(settings, nil, ""), "embedded peer with no groups should not get IPv6") + + emptySettings := &types.Settings{IPv6EnabledGroups: []string{}} + assert.False(t, peerWillHaveIPv6(emptySettings, []string{"group-a"}, "all-group-id"), "no IPv6 groups means no IPv6") +} + +// TestSyncPeer_IPv6CapabilityChangePropagates ensures that when a peer reports +// a new IPv6 overlay capability via SyncPeer (e.g. after a client upgrade or +// flipping --disable-ipv6) without bumping its WtVersion, other account peers +// receive a fresh network map so their AAAA records for it become unstale. 
+func TestSyncPeer_IPv6CapabilityChangePropagates(t *testing.T) { + manager, updateManager, _, peer1, peer2, _ := setupNetworkMapTest(t) + + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + updateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Drain any initial updates from setup. + drain := func() { + for { + select { + case <-updMsg: + case <-time.After(200 * time.Millisecond): + return + } + } + } + drain() + + t.Run("no propagation when capabilities are unchanged", func(t *testing.T) { + _, _, _, _, err := manager.SyncPeer(context.Background(), types.PeerSync{ + WireGuardPubKey: peer2.Key, + Meta: peer2.Meta, + }, peer2.AccountID) + require.NoError(t, err) + peerShouldNotReceiveUpdate(t, updMsg) + }) + + t.Run("propagation when IPv6 capability is added", func(t *testing.T) { + newMeta := peer2.Meta + newMeta.Capabilities = append([]int32{}, peer2.Meta.Capabilities...) + newMeta.Capabilities = append(newMeta.Capabilities, nbpeer.PeerCapabilityIPv6Overlay) + + _, _, _, _, err := manager.SyncPeer(context.Background(), types.PeerSync{ + WireGuardPubKey: peer2.Key, + Meta: newMeta, + }, peer2.AccountID) + require.NoError(t, err) + peerShouldReceiveUpdate(t, updMsg) + }) +} + func TestUpdatePeer_DnsLabelCollisionWithFQDN(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") diff --git a/management/server/policy_test.go b/management/server/policy_test.go index a553b7d05..1eae07e79 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -3,7 +3,7 @@ package server import ( "context" "fmt" - "net" + "net/netip" "testing" "time" @@ -20,53 +20,53 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", - IP: net.ParseIP("100.65.14.88"), + IP: netip.MustParseAddr("100.65.14.88"), Status: &nbpeer.PeerStatus{}, }, "peerB": { ID: "peerB", - IP: net.ParseIP("100.65.80.39"), + 
IP: netip.MustParseAddr("100.65.80.39"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{WtVersion: "0.48.0"}, }, "peerC": { ID: "peerC", - IP: net.ParseIP("100.65.254.139"), + IP: netip.MustParseAddr("100.65.254.139"), Status: &nbpeer.PeerStatus{}, }, "peerD": { ID: "peerD", - IP: net.ParseIP("100.65.62.5"), + IP: netip.MustParseAddr("100.65.62.5"), Status: &nbpeer.PeerStatus{}, }, "peerE": { ID: "peerE", - IP: net.ParseIP("100.65.32.206"), + IP: netip.MustParseAddr("100.65.32.206"), Status: &nbpeer.PeerStatus{}, }, "peerF": { ID: "peerF", - IP: net.ParseIP("100.65.250.202"), + IP: netip.MustParseAddr("100.65.250.202"), Status: &nbpeer.PeerStatus{}, }, "peerG": { ID: "peerG", - IP: net.ParseIP("100.65.13.186"), + IP: netip.MustParseAddr("100.65.13.186"), Status: &nbpeer.PeerStatus{}, }, "peerH": { ID: "peerH", - IP: net.ParseIP("100.65.29.55"), + IP: netip.MustParseAddr("100.65.29.55"), Status: &nbpeer.PeerStatus{}, }, "peerI": { ID: "peerI", - IP: net.ParseIP("100.65.31.2"), + IP: netip.MustParseAddr("100.65.31.2"), Status: &nbpeer.PeerStatus{}, }, "peerK": { ID: "peerK", - IP: net.ParseIP("100.32.80.1"), + IP: netip.MustParseAddr("100.32.80.1"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{WtVersion: "0.30.0"}, }, @@ -540,17 +540,17 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", - IP: net.ParseIP("100.65.14.88"), + IP: netip.MustParseAddr("100.65.14.88"), Status: &nbpeer.PeerStatus{}, }, "peerB": { ID: "peerB", - IP: net.ParseIP("100.65.80.39"), + IP: netip.MustParseAddr("100.65.80.39"), Status: &nbpeer.PeerStatus{}, }, "peerC": { ID: "peerC", - IP: net.ParseIP("100.65.254.139"), + IP: netip.MustParseAddr("100.65.254.139"), Status: &nbpeer.PeerStatus{}, }, }, @@ -746,7 +746,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", - IP: net.ParseIP("100.65.14.88"), + IP: netip.MustParseAddr("100.65.14.88"), 
Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -756,7 +756,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerB": { ID: "peerB", - IP: net.ParseIP("100.65.80.39"), + IP: netip.MustParseAddr("100.65.80.39"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -766,7 +766,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerC": { ID: "peerC", - IP: net.ParseIP("100.65.254.139"), + IP: netip.MustParseAddr("100.65.254.139"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -776,7 +776,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerD": { ID: "peerD", - IP: net.ParseIP("100.65.62.5"), + IP: netip.MustParseAddr("100.65.62.5"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -786,7 +786,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerE": { ID: "peerE", - IP: net.ParseIP("100.65.32.206"), + IP: netip.MustParseAddr("100.65.32.206"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -796,7 +796,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerF": { ID: "peerF", - IP: net.ParseIP("100.65.250.202"), + IP: netip.MustParseAddr("100.65.250.202"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -806,7 +806,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerG": { ID: "peerG", - IP: net.ParseIP("100.65.13.186"), + IP: netip.MustParseAddr("100.65.13.186"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -816,7 +816,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, "peerH": { ID: "peerH", - IP: net.ParseIP("100.65.29.55"), + IP: netip.MustParseAddr("100.65.29.55"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -826,7 +826,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) 
{ }, "peerI": { ID: "peerI", - IP: net.ParseIP("100.65.21.56"), + IP: netip.MustParseAddr("100.65.21.56"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "windows", diff --git a/management/server/route_test.go b/management/server/route_test.go index d0caf4b9b..79014790f 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -2,7 +2,6 @@ package server import ( "context" - "net" "net/netip" "testing" "time" @@ -1333,14 +1332,24 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou return nil, err } + v6Prefix, err := netip.ParsePrefix(account.Network.NetV6.String()) + if err != nil { + return nil, err + } + ips := account.GetTakenIPs() - peer1IP, err := types.AllocatePeerIP(account.Network.Net, ips) + peer1IP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), ips) + if err != nil { + return nil, err + } + peer1IPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) if err != nil { return nil, err } peer1 := &nbpeer.Peer{ IP: peer1IP, + IPv6: peer1IPv6, ID: peer1ID, Key: peer1Key, Name: "test-host1@netbird.io", @@ -1361,13 +1370,18 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou account.Peers[peer1.ID] = peer1 ips = account.GetTakenIPs() - peer2IP, err := types.AllocatePeerIP(account.Network.Net, ips) + peer2IP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), ips) + if err != nil { + return nil, err + } + peer2IPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) if err != nil { return nil, err } peer2 := &nbpeer.Peer{ IP: peer2IP, + IPv6: peer2IPv6, ID: peer2ID, Key: peer2Key, Name: "test-host2@netbird.io", @@ -1388,13 +1402,18 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou account.Peers[peer2.ID] = peer2 ips = account.GetTakenIPs() - peer3IP, err := types.AllocatePeerIP(account.Network.Net, ips) + peer3IP, err := 
types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), ips) + if err != nil { + return nil, err + } + peer3IPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) if err != nil { return nil, err } peer3 := &nbpeer.Peer{ IP: peer3IP, + IPv6: peer3IPv6, ID: peer3ID, Key: peer3Key, Name: "test-host3@netbird.io", @@ -1415,13 +1434,18 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou account.Peers[peer3.ID] = peer3 ips = account.GetTakenIPs() - peer4IP, err := types.AllocatePeerIP(account.Network.Net, ips) + peer4IP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), ips) + if err != nil { + return nil, err + } + peer4IPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) if err != nil { return nil, err } peer4 := &nbpeer.Peer{ IP: peer4IP, + IPv6: peer4IPv6, ID: peer4ID, Key: peer4Key, Name: "test-host4@netbird.io", @@ -1442,13 +1466,18 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou account.Peers[peer4.ID] = peer4 ips = account.GetTakenIPs() - peer5IP, err := types.AllocatePeerIP(account.Network.Net, ips) + peer5IP, err := types.AllocatePeerIP(netip.MustParsePrefix(account.Network.Net.String()), ips) + if err != nil { + return nil, err + } + peer5IPv6, err := types.AllocateRandomPeerIPv6(v6Prefix) if err != nil { return nil, err } peer5 := &nbpeer.Peer{ IP: peer5IP, + IPv6: peer5IPv6, ID: peer5ID, Key: peer5Key, Name: "test-host5@netbird.io", @@ -1549,7 +1578,8 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", - IP: net.ParseIP("100.65.14.88"), + IP: netip.MustParseAddr("100.65.14.88"), + IPv6: netip.MustParseAddr("fd00::1"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -1557,18 +1587,21 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, "peerB": { ID: "peerB", - IP: net.ParseIP(peerBIp), + IP: netip.MustParseAddr(peerBIp), + IPv6: 
netip.MustParseAddr("fd00::2"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{}, }, "peerC": { ID: "peerC", - IP: net.ParseIP(peerCIp), + IP: netip.MustParseAddr(peerCIp), + IPv6: netip.MustParseAddr("fd00::3"), Status: &nbpeer.PeerStatus{}, }, "peerD": { ID: "peerD", - IP: net.ParseIP("100.65.62.5"), + IP: netip.MustParseAddr("100.65.62.5"), + IPv6: netip.MustParseAddr("fd00::4"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ GoOS: "linux", @@ -1576,7 +1609,8 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, "peerE": { ID: "peerE", - IP: net.ParseIP("100.65.32.206"), + IP: netip.MustParseAddr("100.65.32.206"), + IPv6: netip.MustParseAddr("fd00::5"), Key: peer1Key, Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ @@ -1585,27 +1619,32 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, "peerF": { ID: "peerF", - IP: net.ParseIP("100.65.250.202"), + IP: netip.MustParseAddr("100.65.250.202"), + IPv6: netip.MustParseAddr("fd00::6"), Status: &nbpeer.PeerStatus{}, }, "peerG": { ID: "peerG", - IP: net.ParseIP("100.65.13.186"), + IP: netip.MustParseAddr("100.65.13.186"), + IPv6: netip.MustParseAddr("fd00::7"), Status: &nbpeer.PeerStatus{}, }, "peerH": { ID: "peerH", - IP: net.ParseIP(peerHIp), + IP: netip.MustParseAddr(peerHIp), + IPv6: netip.MustParseAddr("fd00::8"), Status: &nbpeer.PeerStatus{}, }, "peerJ": { ID: "peerJ", - IP: net.ParseIP(peerJIp), + IP: netip.MustParseAddr(peerJIp), + IPv6: netip.MustParseAddr("fd00::a"), Status: &nbpeer.PeerStatus{}, }, "peerK": { ID: "peerK", - IP: net.ParseIP(peerKIp), + IP: netip.MustParseAddr(peerKIp), + IPv6: netip.MustParseAddr("fd00::b"), Status: &nbpeer.PeerStatus{}, }, }, @@ -2129,84 +2168,101 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", - IP: net.ParseIP("100.65.14.88"), + IP: netip.MustParseAddr("100.65.14.88"), + IPv6: netip.MustParseAddr("fd00::1"), Key: "peerA", Status: 
&nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", + GoOS: "linux", + Capabilities: []int32{nbpeer.PeerCapabilityIPv6Overlay}, }, }, "peerB": { ID: "peerB", - IP: net.ParseIP(peerBIp), + IP: netip.MustParseAddr(peerBIp), + IPv6: netip.MustParseAddr("fd00::2"), Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{}, }, "peerC": { ID: "peerC", - IP: net.ParseIP(peerCIp), + IP: netip.MustParseAddr(peerCIp), + IPv6: netip.MustParseAddr("fd00::3"), Status: &nbpeer.PeerStatus{}, }, "peerD": { ID: "peerD", - IP: net.ParseIP("100.65.62.5"), + IP: netip.MustParseAddr("100.65.62.5"), + IPv6: netip.MustParseAddr("fd00::4"), Key: "peerD", Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", + GoOS: "linux", + Capabilities: []int32{nbpeer.PeerCapabilityIPv6Overlay}, }, }, "peerE": { ID: "peerE", - IP: net.ParseIP("100.65.32.206"), + IP: netip.MustParseAddr("100.65.32.206"), + IPv6: netip.MustParseAddr("fd00::5"), Key: "peerE", Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", + GoOS: "linux", + Capabilities: []int32{nbpeer.PeerCapabilityIPv6Overlay}, }, }, "peerF": { ID: "peerF", - IP: net.ParseIP("100.65.250.202"), + IP: netip.MustParseAddr("100.65.250.202"), + IPv6: netip.MustParseAddr("fd00::6"), Status: &nbpeer.PeerStatus{}, }, "peerG": { ID: "peerG", - IP: net.ParseIP("100.65.13.186"), + IP: netip.MustParseAddr("100.65.13.186"), + IPv6: netip.MustParseAddr("fd00::7"), Status: &nbpeer.PeerStatus{}, }, "peerH": { ID: "peerH", - IP: net.ParseIP(peerHIp), + IP: netip.MustParseAddr(peerHIp), + IPv6: netip.MustParseAddr("fd00::8"), Status: &nbpeer.PeerStatus{}, }, "peerJ": { ID: "peerJ", - IP: net.ParseIP(peerJIp), + IP: netip.MustParseAddr(peerJIp), + IPv6: netip.MustParseAddr("fd00::a"), Status: &nbpeer.PeerStatus{}, }, "peerK": { ID: "peerK", - IP: net.ParseIP(peerKIp), + IP: netip.MustParseAddr(peerKIp), + IPv6: netip.MustParseAddr("fd00::b"), Status: &nbpeer.PeerStatus{}, }, "peerL": { ID: "peerL", - IP: 
net.ParseIP("100.65.19.186"), + IP: netip.MustParseAddr("100.65.19.186"), + IPv6: netip.MustParseAddr("fd00::d"), Key: "peerL", Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", + GoOS: "linux", + Capabilities: []int32{nbpeer.PeerCapabilityIPv6Overlay}, }, }, "peerM": { ID: "peerM", - IP: net.ParseIP(peerMIp), + IP: netip.MustParseAddr(peerMIp), + IPv6: netip.MustParseAddr("fd00::e"), Status: &nbpeer.PeerStatus{}, }, "peerN": { ID: "peerN", - IP: net.ParseIP("100.65.20.18"), + IP: netip.MustParseAddr("100.65.20.18"), + IPv6: netip.MustParseAddr("fd00::f"), Key: "peerN", Status: &nbpeer.PeerStatus{}, Meta: nbpeer.PeerSystemMeta{ @@ -2215,7 +2271,8 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { }, "peerO": { ID: "peerO", - IP: net.ParseIP(peerOIp), + IP: netip.MustParseAddr(peerOIp), + IPv6: netip.MustParseAddr("fd00::10"), Status: &nbpeer.PeerStatus{}, }, }, diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index 74af0a3ef..345d857f9 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -5,6 +5,7 @@ package settings import ( "context" "fmt" + "net/netip" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/extra_settings" @@ -22,6 +23,9 @@ type Manager interface { GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) GetExtraSettings(ctx context.Context, accountID string) (*types.ExtraSettings, error) UpdateExtraSettings(ctx context.Context, accountID, userID string, extraSettings *types.ExtraSettings) (bool, error) + // GetEffectiveNetworkRanges returns the actual allocated network ranges (v4 and v6). + // This includes auto-allocated ranges even when no custom override was set. 
+ GetEffectiveNetworkRanges(ctx context.Context, accountID string) (v4, v6 netip.Prefix, err error) } // IdpConfig holds IdP-related configuration that is set at runtime @@ -115,3 +119,28 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (* func (m *managerImpl) UpdateExtraSettings(ctx context.Context, accountID, userID string, extraSettings *types.ExtraSettings) (bool, error) { return m.extraSettingsManager.UpdateExtraSettings(ctx, accountID, userID, extraSettings) } + +// GetEffectiveNetworkRanges returns the actual allocated network ranges from the account's network object. +func (m *managerImpl) GetEffectiveNetworkRanges(ctx context.Context, accountID string) (netip.Prefix, netip.Prefix, error) { + network, err := m.store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return netip.Prefix{}, netip.Prefix{}, fmt.Errorf("get account network: %w", err) + } + + var v4, v6 netip.Prefix + if network.Net.IP != nil { + addr, ok := netip.AddrFromSlice(network.Net.IP) + if ok { + ones, _ := network.Net.Mask.Size() + v4 = netip.PrefixFrom(addr.Unmap(), ones) + } + } + if network.NetV6.IP != nil { + addr, ok := netip.AddrFromSlice(network.NetV6.IP) + if ok { + ones, _ := network.NetV6.Mask.Size() + v6 = netip.PrefixFrom(addr.Unmap(), ones) + } + } + return v4, v6, nil +} diff --git a/management/server/settings/manager_mock.go b/management/server/settings/manager_mock.go index dc2f2ebfe..4bedb2cf7 100644 --- a/management/server/settings/manager_mock.go +++ b/management/server/settings/manager_mock.go @@ -6,6 +6,7 @@ package settings import ( context "context" + netip "net/netip" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -94,3 +95,19 @@ func (mr *MockManagerMockRecorder) UpdateExtraSettings(ctx, accountID, userID, e mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateExtraSettings", reflect.TypeOf((*MockManager)(nil).UpdateExtraSettings), ctx, accountID, userID, 
extraSettings) } + +// GetEffectiveNetworkRanges mocks base method. +func (m *MockManager) GetEffectiveNetworkRanges(ctx context.Context, accountID string) (netip.Prefix, netip.Prefix, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEffectiveNetworkRanges", ctx, accountID) + ret0, _ := ret[0].(netip.Prefix) + ret1, _ := ret[1].(netip.Prefix) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetEffectiveNetworkRanges indicates an expected call of GetEffectiveNetworkRanges. +func (mr *MockManagerMockRecorder) GetEffectiveNetworkRanges(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEffectiveNetworkRanges", reflect.TypeOf((*MockManager)(nil).GetEffectiveNetworkRanges), ctx, accountID) +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 1fa3d08ee..973101ce3 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net" + "net/netip" "os" "path/filepath" "runtime" @@ -1503,7 +1504,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc SELECT id, created_by, created_at, domain, domain_category, is_domain_primary_account, -- Embedded Network - network_identifier, network_net, network_dns, network_serial, + network_identifier, network_net, network_net_v6, network_dns, network_serial, -- Embedded DNSSettings dns_settings_disabled_management_groups, -- Embedded Settings @@ -1512,7 +1513,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc settings_regular_users_view_blocked, settings_groups_propagation_enabled, settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, - settings_lazy_connection_enabled, + settings_network_range_v6, settings_ipv6_enabled_groups, 
settings_lazy_connection_enabled, -- Embedded ExtraSettings settings_extra_peer_approval_enabled, settings_extra_user_approval_required, settings_extra_integrated_validator, settings_extra_integrated_validator_groups @@ -1531,12 +1532,15 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc sRoutingPeerDNSResolutionEnabled sql.NullBool sDNSDomain sql.NullString sNetworkRange sql.NullString + sNetworkRangeV6 sql.NullString + sIPv6EnabledGroups sql.NullString sLazyConnectionEnabled sql.NullBool sExtraPeerApprovalEnabled sql.NullBool sExtraUserApprovalRequired sql.NullBool sExtraIntegratedValidator sql.NullString sExtraIntegratedValidatorGroups sql.NullString networkNet sql.NullString + networkNetV6 sql.NullString dnsSettingsDisabledGroups sql.NullString networkIdentifier sql.NullString networkDns sql.NullString @@ -1545,14 +1549,14 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc ) err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( &account.Id, &account.CreatedBy, &createdAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, - &networkIdentifier, &networkNet, &networkDns, &networkSerial, + &networkIdentifier, &networkNet, &networkNetV6, &networkDns, &networkSerial, &dnsSettingsDisabledGroups, &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, - &sLazyConnectionEnabled, + &sNetworkRangeV6, &sIPv6EnabledGroups, &sLazyConnectionEnabled, &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, ) @@ -1621,6 +1625,15 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc if sNetworkRange.Valid { _ = json.Unmarshal([]byte(sNetworkRange.String), 
&account.Settings.NetworkRange) } + if networkNetV6.Valid { + _ = json.Unmarshal([]byte(networkNetV6.String), &account.Network.NetV6) + } + if sNetworkRangeV6.Valid { + _ = json.Unmarshal([]byte(sNetworkRangeV6.String), &account.Settings.NetworkRangeV6) + } + if sIPv6EnabledGroups.Valid { + _ = json.Unmarshal([]byte(sIPv6EnabledGroups.String), &account.Settings.IPv6EnabledGroups) + } if sExtraPeerApprovalEnabled.Valid { account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool @@ -1702,12 +1715,12 @@ func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types. func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) { const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, - inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, - meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, + inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, + meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, - meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, - peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, - location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster FROM peers WHERE account_id = $1` + meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_connected, peer_status_login_expired, + peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, + location_geo_name_id, 
proxy_meta_embedded, proxy_meta_cluster, ipv6 FROM peers WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err @@ -1721,7 +1734,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool peerStatusLastSeen sql.NullTime peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool - ip, extraDNS, netAddr, env, flags, files, connIP []byte + ip, extraDNS, netAddr, env, flags, files, capabilities, connIP, ipv6 []byte metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString metaOS, metaOSVersion, metaWtVersion, metaUIVersion, metaKernelVersion sql.NullString metaSystemSerialNumber, metaSystemProductName, metaSystemManufacturer sql.NullString @@ -1733,9 +1746,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &createdAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform, &metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr, - &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, + &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, &capabilities, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, - &locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster) + &locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster, &ipv6) if err == nil { if lastLogin.Valid { @@ -1828,6 +1841,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee if ip != nil { _ = json.Unmarshal(ip, &p.IP) } + if ipv6 != nil { + _ = 
json.Unmarshal(ipv6, &p.IPv6) + } if extraDNS != nil { _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) } @@ -1843,6 +1859,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee if files != nil { _ = json.Unmarshal(files, &p.Meta.Files) } + if capabilities != nil { + _ = json.Unmarshal(capabilities, &p.Meta.Capabilities) + } if connIP != nil { _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) } @@ -2586,7 +2605,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) return accountID, nil } -func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { +func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]netip.Addr, error) { tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -2594,7 +2613,6 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength var ipJSONStrings []string - // Fetch the IP addresses as JSON strings result := tx.Model(&nbpeer.Peer{}). Where("account_id = ?", accountID). 
Pluck("ip", &ipJSONStrings) @@ -2605,14 +2623,13 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) } - // Convert the JSON strings to net.IP objects - ips := make([]net.IP, len(ipJSONStrings)) + ips := make([]netip.Addr, len(ipJSONStrings)) for i, ipJSON := range ipJSONStrings { - var ip net.IP + var ip netip.Addr if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil { return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store") } - ips[i] = ip + ips[i] = ip.Unmap() } return ips, nil @@ -3214,7 +3231,7 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre query = query.Where("name LIKE ?", "%"+nameFilter+"%") } if ipFilter != "" { - query = query.Where("ip LIKE ?", "%"+ipFilter+"%") + query = query.Where("ip LIKE ? OR ipv6 LIKE ?", "%"+ipFilter+"%", "%"+ipFilter+"%") } if err := query.Find(&peers).Error; err != nil { @@ -4090,9 +4107,10 @@ func (s *SqlStore) SaveAccountSettings(ctx context.Context, accountID string, se return status.Errorf(status.Internal, "failed to save account settings to store") } - if result.RowsAffected == 0 { - return status.NewAccountNotFoundError(accountID) - } + // MySQL reports RowsAffected=0 for no-op updates where values don't change, + // unlike SQLite/Postgres which report matched rows. Skip the check since the + // caller (UpdateAccountSettings) already verified the account exists via + // GetAccountSettings with LockingStrengthUpdate. return nil } @@ -4517,11 +4535,15 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } + column := "ip" + if ip.To4() == nil { + column = "ipv6" + } jsonValue := fmt.Sprintf(`"%s"`, ip.String()) var peer nbpeer.Peer result := tx. - Take(&peer, "account_id = ? 
AND ip = ?", accountID, jsonValue) + Take(&peer, fmt.Sprintf("account_id = ? AND %s = ?", column), accountID, jsonValue) if result.Error != nil { // no logging here return nil, status.Errorf(status.Internal, "failed to get peer from store") @@ -4643,6 +4665,27 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i return nil } +// UpdateAccountNetworkV6 updates the IPv6 network range for the account. +func (s *SqlStore) UpdateAccountNetworkV6(ctx context.Context, accountID string, ipNet net.IPNet) error { + patch := accountNetworkPatch{ + Network: &types.Network{NetV6: ipNet}, + } + + result := s.db. + Model(&types.Account{}). + Where(idQueryCondition, accountID). + Updates(&patch) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to update account network v6: %v", result.Error) + return status.Errorf(status.Internal, "update account network v6") + } + if result.RowsAffected == 0 { + return status.NewAccountNotFoundError(accountID) + } + return nil +} + func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) { if len(groupIDs) == 0 { return []*nbpeer.Peer{}, nil diff --git a/management/server/store/sql_store_get_account_test.go b/management/server/store/sql_store_get_account_test.go index 69e346ae7..9a9de8cdd 100644 --- a/management/server/store/sql_store_get_account_test.go +++ b/management/server/store/sql_store_get_account_test.go @@ -148,7 +148,8 @@ func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { AccountID: accountID, Key: "peer-key-1-AAAA", Name: "Peer 1", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{ Hostname: "peer1.example.com", GoOS: "linux", @@ -195,7 +196,8 @@ func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { AccountID: accountID, Key: "peer-key-2-BBBB", Name: "Peer 2", - IP: net.ParseIP("100.64.0.2"), + IP: 
netip.MustParseAddr("100.64.0.2"), + IPv6: netip.MustParseAddr("fd00::2"), Meta: nbpeer.PeerSystemMeta{ Hostname: "peer2.example.com", GoOS: "darwin", @@ -232,7 +234,8 @@ func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { AccountID: accountID, Key: "peer-key-3-CCCC", Name: "Peer 3 (Ephemeral)", - IP: net.ParseIP("100.64.0.3"), + IP: netip.MustParseAddr("100.64.0.3"), + IPv6: netip.MustParseAddr("fd00::3"), Meta: nbpeer.PeerSystemMeta{ Hostname: "peer3.example.com", GoOS: "windows", @@ -710,7 +713,7 @@ func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { require.True(t, exists, "Peer 1 should exist") assert.Equal(t, "Peer 1", p1.Name, "Peer 1 name mismatch") assert.Equal(t, "peer-key-1-AAAA", p1.Key, "Peer 1 key mismatch") - assert.True(t, p1.IP.Equal(net.ParseIP("100.64.0.1")), "Peer 1 IP mismatch") + assert.Equal(t, netip.MustParseAddr("100.64.0.1"), p1.IP, "Peer 1 IP mismatch") assert.Equal(t, userID1, p1.UserID, "Peer 1 user ID mismatch") assert.True(t, p1.SSHEnabled, "Peer 1 SSH should be enabled") assert.Equal(t, "ssh-rsa AAAAB3NzaC1...", p1.SSHKey, "Peer 1 SSH key mismatch") diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 5a5616abc..2819265c3 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -94,11 +94,12 @@ func runLargeTest(t *testing.T, store Store) { for n := 0; n < numPerAccount; n++ { netIP := randomIPv4() peerID := fmt.Sprintf("%s-peer-%d", account.Id, n) + addr, _ := netip.AddrFromSlice(netIP) peer := &nbpeer.Peer{ ID: peerID, Key: peerID, - IP: netIP, + IP: addr.Unmap(), Name: peerID, DNSLabel: peerID, UserID: "testuser", @@ -235,7 +236,8 @@ func Test_SaveAccount(t *testing.T) { account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: 
nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -249,7 +251,8 @@ func Test_SaveAccount(t *testing.T) { account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ Key: "peerkey2", - IP: net.IP{127, 0, 0, 2}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 2}), + IPv6: netip.MustParseAddr("fd00::2"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name 2", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -316,7 +319,8 @@ func TestSqlite_DeleteAccount(t *testing.T) { account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -499,7 +503,8 @@ func TestSqlStore_SavePeer(t *testing.T) { peer := &nbpeer.Peer{ Key: "peerkey", ID: "testpeer", - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -556,7 +561,8 @@ func TestSqlStore_SavePeerStatus(t *testing.T) { account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", ID: "testpeer", - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -784,7 +790,8 @@ func newAccount(store Store, id int) error { account.SetupKeys[setupKey.Key] = setupKey account.Peers["p"+str] = &nbpeer.Peer{ Key: "peerkey" + str, - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{}, Name: 
"peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -823,7 +830,8 @@ func TestPostgresql_SaveAccount(t *testing.T) { account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -837,7 +845,8 @@ func TestPostgresql_SaveAccount(t *testing.T) { account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ Key: "peerkey2", - IP: net.IP{127, 0, 0, 2}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 2}), + IPv6: netip.MustParseAddr("fd00::2"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name 2", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -903,7 +912,8 @@ func TestPostgresql_DeleteAccount(t *testing.T) { account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{127, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::1"), Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, @@ -1010,37 +1020,39 @@ func TestSqlite_GetTakenIPs(t *testing.T) { takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthNone, existingAccountID) require.NoError(t, err) - assert.Equal(t, []net.IP{}, takenIPs) + assert.Equal(t, []netip.Addr{}, takenIPs) peer1 := &nbpeer.Peer{ ID: "peer1", AccountID: existingAccountID, Key: "key1", DNSLabel: "peer1", - IP: net.IP{1, 1, 1, 1}, + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1:1:1:1"), } err = store.AddPeerToAccount(context.Background(), peer1) require.NoError(t, err) takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthNone, existingAccountID) require.NoError(t, 
err) - ip1 := net.IP{1, 1, 1, 1}.To16() - assert.Equal(t, []net.IP{ip1}, takenIPs) + ip1 := netip.AddrFrom4([4]byte{1, 1, 1, 1}) + assert.Equal(t, []netip.Addr{ip1}, takenIPs) peer2 := &nbpeer.Peer{ ID: "peer1second", AccountID: existingAccountID, Key: "key2", DNSLabel: "peer1-1", - IP: net.IP{2, 2, 2, 2}, + IP: netip.AddrFrom4([4]byte{2, 2, 2, 2}), + IPv6: netip.MustParseAddr("fd00::2:2:2:2"), } err = store.AddPeerToAccount(context.Background(), peer2) require.NoError(t, err) takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthNone, existingAccountID) require.NoError(t, err) - ip2 := net.IP{2, 2, 2, 2}.To16() - assert.Equal(t, []net.IP{ip1, ip2}, takenIPs) + ip2 := netip.AddrFrom4([4]byte{2, 2, 2, 2}) + assert.Equal(t, []netip.Addr{ip1, ip2}, takenIPs) } func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { @@ -1060,7 +1072,8 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { AccountID: existingAccountID, Key: "key1", DNSLabel: "peer1", - IP: net.IP{1, 1, 1, 1}, + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1:1:1:1"), } err = store.AddPeerToAccount(context.Background(), peer1) require.NoError(t, err) @@ -1074,7 +1087,8 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { AccountID: existingAccountID, Key: "key2", DNSLabel: "peer1-1", - IP: net.IP{2, 2, 2, 2}, + IP: netip.AddrFrom4([4]byte{2, 2, 2, 2}), + IPv6: netip.MustParseAddr("fd00::2:2:2:2"), } err = store.AddPeerToAccount(context.Background(), peer2) require.NoError(t, err) @@ -1127,7 +1141,8 @@ func Test_AddPeerWithSameIP(t *testing.T) { ID: "peer1", AccountID: existingAccountID, Key: "key1", - IP: net.IP{1, 1, 1, 1}, + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1:1:1:1"), } err = store.AddPeerToAccount(context.Background(), peer1) require.NoError(t, err) @@ -1136,7 +1151,8 @@ func Test_AddPeerWithSameIP(t *testing.T) { ID: "peer1second", AccountID: existingAccountID, Key: "key2", - IP: net.IP{1, 1, 1, 1}, + 
IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::2:2:2:2"), } err = store.AddPeerToAccount(context.Background(), peer2) require.Error(t, err) @@ -2640,7 +2656,8 @@ func TestSqlStore_AddPeerToAccount(t *testing.T) { ID: "peer1", AccountID: accountID, Key: "key", - IP: net.IP{1, 1, 1, 1}, + IP: netip.AddrFrom4([4]byte{1, 1, 1, 1}), + IPv6: netip.MustParseAddr("fd00::1:1:1:1"), Meta: nbpeer.PeerSystemMeta{ Hostname: "hostname", GoOS: "linux", @@ -3815,10 +3832,10 @@ func BenchmarkGetAccountPeers(b *testing.B) { } } -func intToIPv4(n uint32) net.IP { - ip := make(net.IP, 4) - binary.BigEndian.PutUint32(ip, n) - return ip +func intToIPv4(n uint32) netip.Addr { + var b [4]byte + binary.BigEndian.PutUint32(b[:], n) + return netip.AddrFrom4(b) } func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { @@ -3945,7 +3962,8 @@ func TestSqlStore_GetUserIDByPeerKey(t *testing.T) { Key: peerKey, AccountID: existingAccountID, UserID: userID, - IP: net.IP{10, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{10, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::a00:1"), DNSLabel: "test-peer-1", } @@ -3982,7 +4000,8 @@ func TestSqlStore_GetUserIDByPeerKey_NoUserID(t *testing.T) { Key: peerKey, AccountID: existingAccountID, UserID: "", - IP: net.IP{10, 0, 0, 1}, + IP: netip.AddrFrom4([4]byte{10, 0, 0, 1}), + IPv6: netip.MustParseAddr("fd00::a00:1"), DNSLabel: "test-peer-1", } @@ -4009,7 +4028,8 @@ func TestSqlStore_ApproveAccountPeers(t *testing.T) { AccountID: accountID, DNSLabel: "peer1.netbird.cloud", Key: "peer1-key", - IP: net.ParseIP("100.64.0.1"), + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), Status: &nbpeer.PeerStatus{ RequiresApproval: true, LastSeen: time.Now().UTC(), @@ -4020,7 +4040,8 @@ func TestSqlStore_ApproveAccountPeers(t *testing.T) { AccountID: accountID, DNSLabel: "peer2.netbird.cloud", Key: "peer2-key", - IP: net.ParseIP("100.64.0.2"), + IP: netip.MustParseAddr("100.64.0.2"), + IPv6: netip.MustParseAddr("fd00::2"), 
Status: &nbpeer.PeerStatus{ RequiresApproval: true, LastSeen: time.Now().UTC(), @@ -4031,7 +4052,8 @@ func TestSqlStore_ApproveAccountPeers(t *testing.T) { AccountID: accountID, DNSLabel: "peer3.netbird.cloud", Key: "peer3-key", - IP: net.ParseIP("100.64.0.3"), + IP: netip.MustParseAddr("100.64.0.3"), + IPv6: netip.MustParseAddr("fd00::3"), Status: &nbpeer.PeerStatus{ RequiresApproval: false, LastSeen: time.Now().UTC(), diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index 81c4b33ae..a38b4a8c1 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -344,7 +344,8 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) { ID: fmt.Sprintf("peer-%d", i), AccountID: accountID, Key: fmt.Sprintf("peerkey-%d", i), - IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)), + IP: netip.MustParseAddr(fmt.Sprintf("100.64.0.%d", i+1)), + IPv6: netip.MustParseAddr(fmt.Sprintf("fd00::%d", i+1)), Name: fmt.Sprintf("peer-name-%d", i), Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()}, }) diff --git a/management/server/store/store.go b/management/server/store/store.go index 447c85547..db98bc644 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -185,7 +185,7 @@ type Store interface { SaveNameServerGroup(ctx context.Context, nameServerGroup *dns.NameServerGroup) error DeleteNameServerGroup(ctx context.Context, accountID, nameServerGroupID string) error - GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) + GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]netip.Addr, error) IncrementNetworkSerial(ctx context.Context, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error) @@ -225,6 +225,7 @@ type Store interface { IsPrimaryAccount(ctx 
context.Context, accountID string) (bool, string, error) MarkAccountPrimary(ctx context.Context, accountID string) error UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error + UpdateAccountNetworkV6(ctx context.Context, accountID string, ipNet net.IPNet) error GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error) // SetFieldEncrypt sets the field encryptor for encrypting sensitive user data. diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index d8bd826a8..6c2c9bbc3 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -7,6 +7,7 @@ package store import ( context "context" net "net" + netip "net/netip" reflect "reflect" time "time" @@ -2138,10 +2139,10 @@ func (mr *MockStoreMockRecorder) GetStoreEngine() *gomock.Call { } // GetTakenIPs mocks base method. -func (m *MockStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) { +func (m *MockStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]netip.Addr, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetTakenIPs", ctx, lockStrength, accountId) - ret0, _ := ret[0].([]net.IP) + ret0, _ := ret[0].([]netip.Addr) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -2952,6 +2953,20 @@ func (mr *MockStoreMockRecorder) UpdateAccountNetwork(ctx, accountID, ipNet inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountNetwork", reflect.TypeOf((*MockStore)(nil).UpdateAccountNetwork), ctx, accountID, ipNet) } +// UpdateAccountNetworkV6 mocks base method. 
+func (m *MockStore) UpdateAccountNetworkV6(ctx context.Context, accountID string, ipNet net.IPNet) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAccountNetworkV6", ctx, accountID, ipNet) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAccountNetworkV6 indicates an expected call of UpdateAccountNetworkV6. +func (mr *MockStoreMockRecorder) UpdateAccountNetworkV6(ctx, accountID, ipNet interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountNetworkV6", reflect.TypeOf((*MockStore)(nil).UpdateAccountNetworkV6), ctx, accountID, ipNet) +} + // UpdateCustomDomain mocks base method. func (m *MockStore) UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error) { m.ctrl.T.Helper() diff --git a/management/server/types/account.go b/management/server/types/account.go index e7c1e2dce..49600163a 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -3,7 +3,6 @@ package types import ( "context" "fmt" - "net" "net/netip" "slices" "strconv" @@ -270,6 +269,8 @@ func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdn domainSuffix := "." + dnsDomain + ipv6AllowedPeers := a.peerIPv6AllowedSet() + var sb strings.Builder for _, peer := range a.Peers { if peer.DNSLabel == "" { @@ -281,13 +282,31 @@ func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdn sb.WriteString(peer.DNSLabel) sb.WriteString(domainSuffix) + fqdn := sb.String() customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ - Name: sb.String(), + Name: fqdn, Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: defaultTTL, RData: peer.IP.String(), }) + // Only advertise AAAA for peers that have a valid IPv6, whose client supports it, + // and that belong to an IPv6-enabled group. Old clients don't configure v6 on their + // WireGuard interface, so resolving their AAAA causes connections to hang. 
+ // Capability changes (client upgrade/downgrade, --disable-ipv6 toggle) propagate + // to other peers via SyncPeer/LoginPeer regardless of version change, so AAAA + // records refresh when a peer first reports the IPv6 overlay capability. + _, peerAllowed := ipv6AllowedPeers[peer.ID] + hasIPv6 := peer.IPv6.IsValid() && peer.SupportsIPv6() && peerAllowed + if hasIPv6 { + customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ + Name: fqdn, + Type: int(dns.TypeAAAA), + Class: nbdns.DefaultClass, + TTL: defaultTTL, + RData: peer.IPv6.String(), + }) + } sb.Reset() for _, extraLabel := range peer.ExtraDNSLabels { @@ -295,13 +314,23 @@ func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdn sb.WriteString(extraLabel) sb.WriteString(domainSuffix) + extraFqdn := sb.String() customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ - Name: sb.String(), + Name: extraFqdn, Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: defaultTTL, RData: peer.IP.String(), }) + if hasIPv6 { + customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ + Name: extraFqdn, + Type: int(dns.TypeAAAA), + Class: nbdns.DefaultClass, + TTL: defaultTTL, + RData: peer.IPv6.String(), + }) + } sb.Reset() } @@ -569,8 +598,43 @@ func (a *Account) GetPeerGroups(peerID string) LookupMap { return groupList } -func (a *Account) GetTakenIPs() []net.IP { - var takenIps []net.IP +// PeerIPv6Allowed reports whether the given peer is in any of the account's IPv6 enabled groups. +// Returns false if IPv6 is disabled or no groups are configured. +func (a *Account) PeerIPv6Allowed(peerID string) bool { + if len(a.Settings.IPv6EnabledGroups) == 0 { + return false + } + + for _, groupID := range a.Settings.IPv6EnabledGroups { + group, ok := a.Groups[groupID] + if !ok { + continue + } + if slices.Contains(group.Peers, peerID) { + return true + } + } + return false +} + +// peerIPv6AllowedSet returns a set of peer IDs that belong to any IPv6-enabled group. 
+func (a *Account) peerIPv6AllowedSet() map[string]struct{} { + result := make(map[string]struct{}) + for _, groupID := range a.Settings.IPv6EnabledGroups { + group, ok := a.Groups[groupID] + if !ok { + continue + } + for _, peerID := range group.Peers { + result[peerID] = struct{}{} + } + } + return result +} + +// GetTakenIPs returns all peer IP addresses currently allocated in the account. +func (a *Account) GetTakenIPs() []netip.Addr { + takenIps := make([]netip.Addr, 0, len(a.Peers)) for _, existingPeer := range a.Peers { takenIps = append(takenIps, existingPeer.IP) } @@ -927,10 +991,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { rules = append(rules, &fr) - continue + } else { + rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) } - rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) + rules = appendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, firewallRuleContext{ + direction: direction, + dirStr: strconv.Itoa(direction), + protocolStr: string(protocol), + actionStr: string(rule.Action), + portsJoined: strings.Join(rule.Ports, ","), + }) } }, func() ([]*nbpeer.Peer, []*FirewallRule) { return peers, rules @@ -1045,7 +1116,7 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks { return nil } -func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { +func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}, includeIPv6 bool) []*RouteFirewallRule { var fwRules []*RouteFirewallRule for _, policy := range policies { if !policy.Enabled { @@ -1058,7 +1129,7 @@ func (a *Account) getRouteFirewallRules(ctx 
context.Context, peerID string, poli } rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap) - rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN, includeIPv6) fwRules = append(fwRules, rules...) } } @@ -1140,7 +1211,7 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer resourceAppliedPolicies := resourcePolicies[string(route.GetResourceID())] distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) - rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) + rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers, peer.SupportsIPv6() && peer.IPv6.IsValid()) for _, rule := range rules { if len(rule.SourceRanges) > 0 { routesFirewallRules = append(routesFirewallRules, rule) @@ -1595,24 +1666,32 @@ func peerSupportedFirewallFeatures(peerVer string) supportedFeatures { } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. +// AAAA records are excluded when the requesting peer lacks IPv6 capability. 
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord { filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) - peerIPs := make(map[string]struct{}) + peerIPs := make(map[netip.Addr]struct{}, len(peersToConnect)+len(expiredPeers)+2) + includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid() - // Add peer's own IP to include its own DNS records - peerIPs[peer.IP.String()] = struct{}{} - - for _, peerToConnect := range peersToConnect { - peerIPs[peerToConnect.IP.String()] = struct{}{} + addPeerIPs := func(p *nbpeer.Peer) { + peerIPs[p.IP] = struct{}{} + if includeIPv6 && p.IPv6.IsValid() { + peerIPs[p.IPv6] = struct{}{} + } } - for _, expiredPeer := range expiredPeers { - peerIPs[expiredPeer.IP.String()] = struct{}{} + addPeerIPs(peer) + for _, p := range peersToConnect { + addPeerIPs(p) + } + for _, p := range expiredPeers { + addPeerIPs(p) } for _, record := range customZone.Records { - if _, exists := peerIPs[record.RData]; exists { - filteredRecords = append(filteredRecords, record) + if addr, err := netip.ParseAddr(record.RData); err == nil { + if _, exists := peerIPs[addr.Unmap()]; exists { + filteredRecords = append(filteredRecords, record) + } } } diff --git a/management/server/types/account_components.go b/management/server/types/account_components.go index bd4244546..2b4f7e051 100644 --- a/management/server/types/account_components.go +++ b/management/server/types/account_components.go @@ -115,7 +115,7 @@ func (a *Account) GetPeerNetworkMapComponents( components.Groups = relevantGroups components.Policies = relevantPolicies components.Routes = relevantRoutes - components.AllDNSRecords = filterDNSRecordsByPeers(peersCustomZone.Records, relevantPeers) + components.AllDNSRecords = filterDNSRecordsByPeers(peersCustomZone.Records, relevantPeers, peer.SupportsIPv6() && peer.IPv6.IsValid()) peerGroups := a.GetPeerGroups(peerID) components.AccountZones = 
filterPeerAppliedZones(ctx, accountZones, peerGroups) @@ -539,15 +539,22 @@ func filterPostureFailedPeers(postureFailedPeers *map[string]map[string]struct{} } } -func filterDNSRecordsByPeers(records []nbdns.SimpleRecord, peers map[string]*nbpeer.Peer) []nbdns.SimpleRecord { +func filterDNSRecordsByPeers(records []nbdns.SimpleRecord, peers map[string]*nbpeer.Peer, includeIPv6 bool) []nbdns.SimpleRecord { if len(records) == 0 || len(peers) == 0 { return nil } - peerIPs := make(map[string]struct{}, len(peers)) + // Include both v4 and v6 addresses so AAAA records (whose RData is an IPv6 + // address) are not filtered out when peers have IPv6 assigned. When the + // requesting peer doesn't have IPv6, omit v6 IPs so AAAA records get dropped. + peerIPs := make(map[string]struct{}, len(peers)*2) for _, peer := range peers { - if peer != nil { - peerIPs[peer.IP.String()] = struct{}{} + if peer == nil { + continue + } + peerIPs[peer.IP.String()] = struct{}{} + if includeIPv6 && peer.IPv6.IsValid() { + peerIPs[peer.IPv6.String()] = struct{}{} } } diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index 9b1c9e31d..a1a616882 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -3,7 +3,7 @@ package types import ( "context" "fmt" - "net" + "net/netip" "testing" "github.com/miekg/dns" @@ -921,7 +921,11 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, peersToConnect: []*nbpeer.Peer{}, expiredPeers: []*nbpeer.Peer{}, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + peer: &nbpeer.Peer{ + ID: "router", + IP: netip.MustParseAddr("10.0.0.100"), + IPv6: netip.MustParseAddr("fd00::a00:64"), + }, expectedRecords: []nbdns.SimpleRecord{ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, }, @@ -948,14 +952,19 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { var peers []*nbpeer.Peer for _, i := range 
[]int{1, 5, 10, 25, 50, 75, 100} { peers = append(peers, &nbpeer.Peer{ - ID: fmt.Sprintf("peer%d", i), - IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)), + ID: fmt.Sprintf("peer%d", i), + IP: netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", i/256, i%256)), + IPv6: netip.MustParseAddr(fmt.Sprintf("fd00::%d", i)), }) } return peers }(), expiredPeers: []*nbpeer.Peer{}, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + peer: &nbpeer.Peer{ + ID: "router", + IP: netip.MustParseAddr("10.0.0.100"), + IPv6: netip.MustParseAddr("fd00::a00:64"), + }, expectedRecords: func() []nbdns.SimpleRecord { var records []nbdns.SimpleRecord for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { @@ -986,11 +995,27 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: []*nbpeer.Peer{ - {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, - {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, + { + ID: "peer1", + IP: netip.MustParseAddr("10.0.0.1"), + IPv6: netip.MustParseAddr("fd00::a00:1"), + DNSLabel: "peer1", + ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}, + }, + { + ID: "peer2", + IP: netip.MustParseAddr("10.0.0.2"), + IPv6: netip.MustParseAddr("fd00::a00:2"), + DNSLabel: "peer2", + ExtraDNSLabels: []string{"peer2-service"}, + }, }, expiredPeers: []*nbpeer.Peer{}, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + peer: &nbpeer.Peer{ + ID: "router", + IP: netip.MustParseAddr("10.0.0.100"), + IPv6: netip.MustParseAddr("fd00::a00:64"), + }, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, @@ -1012,12 +1037,24 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: 
[]*nbpeer.Peer{ - {ID: "peer1", IP: net.ParseIP("10.0.0.1")}, + { + ID: "peer1", + IP: netip.MustParseAddr("10.0.0.1"), + IPv6: netip.MustParseAddr("fd00::a00:1"), + }, }, expiredPeers: []*nbpeer.Peer{ - {ID: "expired-peer", IP: net.ParseIP("10.0.0.99")}, + { + ID: "expired-peer", + IP: netip.MustParseAddr("10.0.0.99"), + IPv6: netip.MustParseAddr("fd00::a00:63"), + }, + }, + peer: &nbpeer.Peer{ + ID: "router", + IP: netip.MustParseAddr("10.0.0.100"), + IPv6: netip.MustParseAddr("fd00::a00:64"), }, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go index 19222a607..b76a94290 100644 --- a/management/server/types/firewall_rule.go +++ b/management/server/types/firewall_rule.go @@ -48,16 +48,26 @@ func (r *FirewallRule) Equal(other *FirewallRule) bool { } // generateRouteFirewallRules generates a list of firewall rules for a given route. -func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { +// For static routes, source ranges match the destination family (v4 or v6). +// For dynamic routes (domain-based), separate v4 and v6 rules are generated +// so the routing peer's forwarding chain allows both address families. 
+func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, includeIPv6 bool) []*RouteFirewallRule { rulesExists := make(map[string]struct{}) rules := make([]*RouteFirewallRule, 0) - sourceRanges := make([]string, 0, len(groupPeers)) - for _, peer := range groupPeers { - if peer == nil { - continue - } - sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) + v4Sources, v6Sources := splitPeerSourcesByFamily(groupPeers) + + isV6Route := route.Network.Addr().Is6() + + // Skip v6 destination routes entirely for peers without IPv6 support + if isV6Route && !includeIPv6 { + return rules + } + + // Pick sources matching the destination family + sourceRanges := v4Sources + if isV6Route { + sourceRanges = v6Sources } baseRule := RouteFirewallRule{ @@ -71,18 +81,47 @@ func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule IsDynamic: route.IsDynamic(), } - // generate rule for port range if len(rule.Ports) == 0 { rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) } else { rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) } - // TODO: generate IPv6 rules for dynamic routes + // Generate v6 counterpart for dynamic routes and 0.0.0.0/0 exit node routes. + isDefaultV4 := !isV6Route && route.Network.Bits() == 0 + if includeIPv6 && (route.IsDynamic() || isDefaultV4) && len(v6Sources) > 0 { + v6Rule := baseRule + v6Rule.SourceRanges = v6Sources + if isDefaultV4 { + v6Rule.Destination = "::/0" + v6Rule.RouteID = route.ID + "-v6-default" + } + if len(rule.Ports) == 0 { + rules = append(rules, generateRulesWithPortRanges(v6Rule, rule, rulesExists)...) + } else { + rules = append(rules, generateRulesWithPorts(ctx, v6Rule, rule, rulesExists)...) + } + } return rules } +// splitPeerSourcesByFamily separates peer IPs into v4 (/32) and v6 (/128) source ranges. 
+func splitPeerSourcesByFamily(groupPeers []*nbpeer.Peer) (v4, v6 []string) { + v4 = make([]string, 0, len(groupPeers)) + v6 = make([]string, 0, len(groupPeers)) + for _, peer := range groupPeers { + if peer == nil { + continue + } + v4 = append(v4, fmt.Sprintf(AllowedIPsFormat, peer.IP)) + if peer.IPv6.IsValid() { + v6 = append(v6, fmt.Sprintf(AllowedIPsV6Format, peer.IPv6)) + } + } + return +} + // generateRulesForPeer generates rules for a given peer based on ports and port ranges. func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { rules := make([]*RouteFirewallRule, 0) diff --git a/management/server/types/firewall_rule_test.go b/management/server/types/firewall_rule_test.go new file mode 100644 index 000000000..8d97a46bc --- /dev/null +++ b/management/server/types/firewall_rule_test.go @@ -0,0 +1,197 @@ +package types + +import ( + "context" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" +) + +func TestSplitPeerSourcesByFamily(t *testing.T) { + peers := []*nbpeer.Peer{ + { + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), + }, + { + IP: netip.MustParseAddr("100.64.0.2"), + }, + { + IP: netip.MustParseAddr("100.64.0.3"), + IPv6: netip.MustParseAddr("fd00::3"), + }, + nil, + } + + v4, v6 := splitPeerSourcesByFamily(peers) + + assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32", "100.64.0.3/32"}, v4) + assert.Equal(t, []string{"fd00::1/128", "fd00::3/128"}, v6) +} + +func TestGenerateRouteFirewallRules_V4Route(t *testing.T) { + peers := []*nbpeer.Peer{ + { + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), + }, + { + IP: netip.MustParseAddr("100.64.0.2"), + }, + } + + r := &route.Route{ + ID: 
"route1", + Network: netip.MustParsePrefix("10.0.0.0/24"), + } + rule := &PolicyRule{ + PolicyID: "policy1", + ID: "rule1", + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, + } + + rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true) + + require.Len(t, rules, 1) + assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges, "v4 route should only have v4 sources") + assert.Equal(t, "10.0.0.0/24", rules[0].Destination) +} + +func TestGenerateRouteFirewallRules_V6Route(t *testing.T) { + peers := []*nbpeer.Peer{ + { + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), + }, + { + IP: netip.MustParseAddr("100.64.0.2"), + }, + } + + r := &route.Route{ + ID: "route1", + Network: netip.MustParsePrefix("2001:db8::/32"), + } + rule := &PolicyRule{ + PolicyID: "policy1", + ID: "rule1", + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, + } + + rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true) + + require.Len(t, rules, 1) + assert.Equal(t, []string{"fd00::1/128"}, rules[0].SourceRanges, "v6 route should only have v6 sources") +} + +func TestGenerateRouteFirewallRules_DynamicRoute_DualStack(t *testing.T) { + peers := []*nbpeer.Peer{ + { + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), + }, + { + IP: netip.MustParseAddr("100.64.0.2"), + }, + } + + r := &route.Route{ + ID: "route1", + NetworkType: route.DomainNetwork, + Domains: domain.List{"example.com"}, + } + rule := &PolicyRule{ + PolicyID: "policy1", + ID: "rule1", + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, + } + + rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true) + + require.Len(t, rules, 2, "dynamic route should produce both v4 and v6 rules") + assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, 
rules[0].SourceRanges) + assert.Equal(t, []string{"fd00::1/128"}, rules[1].SourceRanges) + assert.Equal(t, rules[0].Domains, rules[1].Domains) + assert.True(t, rules[0].IsDynamic) + assert.True(t, rules[1].IsDynamic) +} + +func TestGenerateRouteFirewallRules_DynamicRoute_NoV6Peers(t *testing.T) { + peers := []*nbpeer.Peer{ + {IP: netip.MustParseAddr("100.64.0.1")}, + {IP: netip.MustParseAddr("100.64.0.2")}, + } + + r := &route.Route{ + ID: "route1", + NetworkType: route.DomainNetwork, + Domains: domain.List{"example.com"}, + } + rule := &PolicyRule{ + PolicyID: "policy1", + ID: "rule1", + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, + } + + rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true) + + require.Len(t, rules, 1, "no v6 peers means only v4 rule") + assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges) +} + +func TestGenerateRouteFirewallRules_IncludeIPv6False(t *testing.T) { + peers := []*nbpeer.Peer{ + { + IP: netip.MustParseAddr("100.64.0.1"), + IPv6: netip.MustParseAddr("fd00::1"), + }, + { + IP: netip.MustParseAddr("100.64.0.2"), + IPv6: netip.MustParseAddr("fd00::2"), + }, + } + + t.Run("v6 route excluded", func(t *testing.T) { + r := &route.Route{ + ID: "route1", + Network: netip.MustParsePrefix("2001:db8::/32"), + } + rule := &PolicyRule{ + PolicyID: "policy1", + ID: "rule1", + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, + } + + rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, false) + assert.Empty(t, rules, "v6 route should produce no rules when includeIPv6 is false") + }) + + t.Run("dynamic route only v4", func(t *testing.T) { + r := &route.Route{ + ID: "route1", + NetworkType: route.DomainNetwork, + Domains: domain.List{"example.com"}, + } + rule := &PolicyRule{ + PolicyID: "policy1", + ID: "rule1", + Action: PolicyTrafficActionAccept, + Protocol: 
PolicyRuleProtocolALL, + } + + rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, false) + require.Len(t, rules, 1, "dynamic route with includeIPv6=false should produce only v4 rule") + assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges) + }) +} diff --git a/management/server/types/ipv6_endtoend_test.go b/management/server/types/ipv6_endtoend_test.go new file mode 100644 index 000000000..ddd1f649f --- /dev/null +++ b/management/server/types/ipv6_endtoend_test.go @@ -0,0 +1,156 @@ +package types_test + +import ( + "net/netip" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func TestNetworkMapComponents_IPv6EndToEnd(t *testing.T) { + account := createComponentTestAccount() + + // Make all peers IPv6-capable and assign IPv6 addrs. + v6Caps := []int32{nbpeer.PeerCapabilityIPv6Overlay, nbpeer.PeerCapabilitySourcePrefixes} + account.Peers["peer-src-1"].Meta.Capabilities = v6Caps + account.Peers["peer-src-1"].IPv6 = netip.MustParseAddr("fd00::1") + account.Peers["peer-src-2"].Meta.Capabilities = v6Caps + account.Peers["peer-src-2"].IPv6 = netip.MustParseAddr("fd00::2") + account.Peers["peer-dst-1"].Meta.Capabilities = v6Caps + account.Peers["peer-dst-1"].IPv6 = netip.MustParseAddr("fd00::3") + + // Mark group-src and group-dst as IPv6-enabled. 
+ account.Settings.IPv6EnabledGroups = []string{"group-src", "group-dst"} + + validated := allPeersValidated(account) + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + require.NotNil(t, nm) + + t.Run("v6 AAAA records emitted", func(t *testing.T) { + require.NotEmpty(t, nm.DNSConfig.CustomZones, "expected at least one custom zone") + var hasAAAA bool + var hasA bool + for _, z := range nm.DNSConfig.CustomZones { + for _, r := range z.Records { + if r.Type == int(dns.TypeAAAA) { + hasAAAA = true + } + if r.Type == int(dns.TypeA) { + hasA = true + } + } + } + assert.True(t, hasA, "expected A records") + assert.True(t, hasAAAA, "expected AAAA records for IPv6-enabled peers") + }) + + t.Run("v6 AllowedIPs would be advertised", func(t *testing.T) { + // nm.Peers contains *nbpeer.Peer; IPv6 should be set on those peers + var foundV6 bool + for _, p := range nm.Peers { + if p.IPv6.IsValid() { + foundV6 = true + } + } + assert.True(t, foundV6, "remote peers should have IPv6 set so AllowedIPs gets v6") + }) + + t.Run("v6 firewall rules emitted", func(t *testing.T) { + require.NotEmpty(t, nm.FirewallRules, "expected firewall rules") + var hasV4 bool + var hasV6 bool + for _, r := range nm.FirewallRules { + addr, err := netip.ParseAddr(r.PeerIP) + if err != nil { + continue + } + if addr.Is4() { + hasV4 = true + } + if addr.Is6() { + hasV6 = true + } + } + assert.True(t, hasV4, "expected at least one v4 firewall rule (peer IP)") + assert.True(t, hasV6, "expected at least one v6 firewall rule (peer IPv6)") + }) +} + +// TestNetworkMapComponents_RemotePeerWithoutCapability checks the asymmetric +// case where the target peer is IPv6-capable but a remote peer has an IPv6 +// address assigned in the DB without yet reporting the capability flag. +// In that case the remote peer's v6 still appears in AllowedIPs (gated on +// the target peer's capability) but its AAAA record does not (gated on the +// remote peer's own capability). 
+func TestNetworkMapComponents_RemotePeerWithoutCapability(t *testing.T) { + account := createComponentTestAccount() + + v6Caps := []int32{nbpeer.PeerCapabilityIPv6Overlay, nbpeer.PeerCapabilitySourcePrefixes} + // Target is fully capable. + account.Peers["peer-src-1"].Meta.Capabilities = v6Caps + account.Peers["peer-src-1"].IPv6 = netip.MustParseAddr("fd00::1") + // Remote peer has v6 assigned but no capability flag yet (e.g. old client). + account.Peers["peer-dst-1"].IPv6 = netip.MustParseAddr("fd00::3") + + account.Settings.IPv6EnabledGroups = []string{"group-src", "group-dst"} + + validated := allPeersValidated(account) + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + require.NotNil(t, nm) + + t.Run("AllowedIPs include remote v6", func(t *testing.T) { + var dst *nbpeer.Peer + for _, p := range nm.Peers { + if p.ID == "peer-dst-1" { + dst = p + } + } + require.NotNil(t, dst) + assert.True(t, dst.IPv6.IsValid(), "remote peer's v6 should still be present so AllowedIPs gets v6/128 (gated on target peer cap)") + }) + + t.Run("no AAAA for non-capable remote peer", func(t *testing.T) { + for _, z := range nm.DNSConfig.CustomZones { + for _, r := range z.Records { + if r.Type == int(dns.TypeAAAA) && r.RData == "fd00::3" { + t.Errorf("AAAA record for non-capable remote peer should NOT be emitted, got %+v", r) + } + } + } + }) +} + +// TestNetworkMapComponents_IPv6Disabled_NoV6Output asserts that a peer that +// does not support IPv6 (e.g. older client without the capability flag) gets +// no v6 firewall rules and no AAAA records, even if other peers have IPv6. 
+func TestNetworkMapComponents_IPv6Disabled_NoV6Output(t *testing.T) { + account := createComponentTestAccount() + + v6Caps := []int32{nbpeer.PeerCapabilityIPv6Overlay} + account.Peers["peer-src-2"].Meta.Capabilities = v6Caps + account.Peers["peer-src-2"].IPv6 = netip.MustParseAddr("fd00::2") + account.Peers["peer-dst-1"].Meta.Capabilities = v6Caps + account.Peers["peer-dst-1"].IPv6 = netip.MustParseAddr("fd00::3") + // peer-src-1 (target) intentionally has no capability and no IPv6. + + account.Settings.IPv6EnabledGroups = []string{"group-src", "group-dst"} + + validated := allPeersValidated(account) + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + require.NotNil(t, nm) + + t.Run("no v6 firewall rules", func(t *testing.T) { + for _, r := range nm.FirewallRules { + addr, err := netip.ParseAddr(r.PeerIP) + if err != nil { + continue + } + assert.False(t, addr.Is6(), "v6 firewall rules should not be emitted for non-IPv6 peer (got %s)", r.PeerIP) + } + }) +} diff --git a/management/server/types/ipv6_groups_test.go b/management/server/types/ipv6_groups_test.go new file mode 100644 index 000000000..5151e1b1f --- /dev/null +++ b/management/server/types/ipv6_groups_test.go @@ -0,0 +1,234 @@ +package types + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func TestPeerIPv6Allowed(t *testing.T) { + account := &Account{ + Groups: map[string]*Group{ + "group-all": {ID: "group-all", Name: "All", Peers: []string{"peer1", "peer2", "peer3"}}, + "group-devs": {ID: "group-devs", Name: "Devs", Peers: []string{"peer1", "peer2"}}, + "group-infra": {ID: "group-infra", Name: "Infra", Peers: []string{"peer2", "peer3"}}, + "group-empty": {ID: "group-empty", Name: "Empty", Peers: []string{}}, + }, + Settings: &Settings{}, + } + + tests := []struct { + name string + enabledGroups []string + peerID string + expected bool + }{ + { + name: "empty groups list disables IPv6 
for all", + enabledGroups: []string{}, + peerID: "peer1", + expected: false, + }, + { + name: "All group enables IPv6 for everyone", + enabledGroups: []string{"group-all"}, + peerID: "peer1", + expected: true, + }, + { + name: "peer in enabled group gets IPv6", + enabledGroups: []string{"group-devs"}, + peerID: "peer1", + expected: true, + }, + { + name: "peer not in any enabled group denied IPv6", + enabledGroups: []string{"group-devs"}, + peerID: "peer3", + expected: false, + }, + { + name: "peer in multiple groups, one enabled", + enabledGroups: []string{"group-infra"}, + peerID: "peer2", + expected: true, + }, + { + name: "peer in multiple groups, other one enabled", + enabledGroups: []string{"group-devs"}, + peerID: "peer2", + expected: true, + }, + { + name: "multiple enabled groups, peer in one", + enabledGroups: []string{"group-devs", "group-infra"}, + peerID: "peer1", + expected: true, + }, + { + name: "multiple enabled groups, peer in both", + enabledGroups: []string{"group-devs", "group-infra"}, + peerID: "peer2", + expected: true, + }, + { + name: "nonexistent group ID in enabled list", + enabledGroups: []string{"group-deleted"}, + peerID: "peer1", + expected: false, + }, + { + name: "empty group in enabled list", + enabledGroups: []string{"group-empty"}, + peerID: "peer1", + expected: false, + }, + { + name: "unknown peer ID", + enabledGroups: []string{"group-all"}, + peerID: "peer-unknown", + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + account.Settings.IPv6EnabledGroups = tc.enabledGroups + result := account.PeerIPv6Allowed(tc.peerID) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestIPv6RecalculationOnGroupChange(t *testing.T) { + peerWithV6 := func(id string, v6 string) *nbpeer.Peer { + p := &nbpeer.Peer{ + ID: id, + IP: netip.MustParseAddr("100.64.0.1"), + } + if v6 != "" { + p.IPv6 = netip.MustParseAddr(v6) + } + return p + } + + t.Run("peer loses IPv6 when removed from enabled 
groups", func(t *testing.T) { + peer := peerWithV6("peer1", "fd00::1") + + account := &Account{ + Peers: map[string]*nbpeer.Peer{"peer1": peer}, + Groups: map[string]*Group{ + "group-a": {ID: "group-a", Peers: []string{"peer1"}}, + "group-b": {ID: "group-b", Peers: []string{}}, + }, + Settings: &Settings{ + IPv6EnabledGroups: []string{"group-a"}, + }, + } + + assert.True(t, account.PeerIPv6Allowed("peer1"), "peer should be allowed before change") + + // Move peer out of enabled group + account.Groups["group-a"].Peers = []string{} + account.Groups["group-b"].Peers = []string{"peer1"} + + assert.False(t, account.PeerIPv6Allowed("peer1"), "peer should be denied after group change") + }) + + t.Run("peer gains IPv6 when added to enabled group", func(t *testing.T) { + peer := peerWithV6("peer1", "") + + account := &Account{ + Peers: map[string]*nbpeer.Peer{"peer1": peer}, + Groups: map[string]*Group{ + "group-a": {ID: "group-a", Peers: []string{}}, + "group-b": {ID: "group-b", Peers: []string{"peer1"}}, + }, + Settings: &Settings{ + IPv6EnabledGroups: []string{"group-a"}, + }, + } + + assert.False(t, account.PeerIPv6Allowed("peer1"), "peer should be denied before change") + + // Add peer to enabled group + account.Groups["group-a"].Peers = []string{"peer1"} + + assert.True(t, account.PeerIPv6Allowed("peer1"), "peer should be allowed after joining enabled group") + }) + + t.Run("peer in two groups, one leaves enabled list", func(t *testing.T) { + peer := peerWithV6("peer1", "fd00::1") + + account := &Account{ + Peers: map[string]*nbpeer.Peer{"peer1": peer}, + Groups: map[string]*Group{ + "group-a": {ID: "group-a", Peers: []string{"peer1"}}, + "group-b": {ID: "group-b", Peers: []string{"peer1"}}, + }, + Settings: &Settings{ + IPv6EnabledGroups: []string{"group-a", "group-b"}, + }, + } + + assert.True(t, account.PeerIPv6Allowed("peer1")) + + // Remove group-a from enabled list, peer still in group-b + account.Settings.IPv6EnabledGroups = []string{"group-b"} + + 
assert.True(t, account.PeerIPv6Allowed("peer1"), "peer should still be allowed via group-b") + }) + + t.Run("peer in two groups, both leave enabled list", func(t *testing.T) { + peer := peerWithV6("peer1", "fd00::1") + + account := &Account{ + Peers: map[string]*nbpeer.Peer{"peer1": peer}, + Groups: map[string]*Group{ + "group-a": {ID: "group-a", Peers: []string{"peer1"}}, + "group-b": {ID: "group-b", Peers: []string{"peer1"}}, + }, + Settings: &Settings{ + IPv6EnabledGroups: []string{"group-a", "group-b"}, + }, + } + + assert.True(t, account.PeerIPv6Allowed("peer1")) + + // Clear all enabled groups + account.Settings.IPv6EnabledGroups = []string{} + + assert.False(t, account.PeerIPv6Allowed("peer1"), "peer should be denied when no groups enabled") + }) + + t.Run("enabling a group gives only its peers IPv6", func(t *testing.T) { + account := &Account{ + Peers: map[string]*nbpeer.Peer{ + "peer1": peerWithV6("peer1", ""), + "peer2": peerWithV6("peer2", ""), + "peer3": peerWithV6("peer3", ""), + }, + Groups: map[string]*Group{ + "group-devs": {ID: "group-devs", Peers: []string{"peer1", "peer2"}}, + "group-infra": {ID: "group-infra", Peers: []string{"peer2", "peer3"}}, + }, + Settings: &Settings{ + IPv6EnabledGroups: []string{"group-devs"}, + }, + } + + assert.True(t, account.PeerIPv6Allowed("peer1"), "peer1 in devs") + assert.True(t, account.PeerIPv6Allowed("peer2"), "peer2 in devs") + assert.False(t, account.PeerIPv6Allowed("peer3"), "peer3 not in devs") + + // Add infra group + account.Settings.IPv6EnabledGroups = []string{"group-devs", "group-infra"} + + assert.True(t, account.PeerIPv6Allowed("peer1"), "peer1 still in devs") + assert.True(t, account.PeerIPv6Allowed("peer2"), "peer2 in both") + assert.True(t, account.PeerIPv6Allowed("peer3"), "peer3 now in infra") + }) +} diff --git a/management/server/types/network.go b/management/server/types/network.go index 0d13de10f..fe67bfd97 100644 --- a/management/server/types/network.go +++ 
b/management/server/types/network.go @@ -2,8 +2,11 @@ package types import ( "encoding/binary" + "fmt" "math/rand" "net" + "net/netip" + "slices" "sync" "time" @@ -27,6 +30,12 @@ const ( // AllowedIPsFormat generates Wireguard AllowedIPs format (e.g. 100.64.30.1/32) AllowedIPsFormat = "%s/32" + // AllowedIPsV6Format generates AllowedIPs format for v6 (e.g. fd12:3456:7890::1/128) + AllowedIPsV6Format = "%s/128" + + // IPv6SubnetSize is the prefix length of per-account IPv6 subnets. + // Each account gets a /64 from its unique /48 ULA prefix. + IPv6SubnetSize = 64 ) type NetworkMap struct { @@ -111,7 +120,9 @@ func ipToBytes(ip net.IP) []byte { type Network struct { Identifier string `json:"id"` Net net.IPNet `gorm:"serializer:json"` - Dns string + // NetV6 is the IPv6 ULA subnet for this account's overlay. Empty if not yet allocated. + NetV6 net.IPNet `gorm:"serializer:json"` + Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Used to synchronize state to the client apps. Serial uint64 @@ -121,20 +132,45 @@ type Network struct { // NewNetwork creates a new Network initializing it with a Serial=0 // It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets) +// and a random /64 subnet from fd00:4e42::/32 for IPv6. func NewNetwork() *Network { - n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize) sub, _ := n.Subnet(SubnetSize) - s := rand.NewSource(time.Now().Unix()) + s := rand.NewSource(time.Now().UnixNano()) r := rand.New(s) intn := r.Intn(len(sub)) return &Network{ Identifier: xid.New().String(), Net: sub[intn].IPNet, + NetV6: AllocateIPv6Subnet(r), Dns: "", - Serial: 0} + Serial: 0, + } +} + +// AllocateIPv6Subnet generates a random RFC 4193 ULA /64 prefix. +// The format follows RFC 4193 section 3.1: fd + 40-bit Global ID + 16-bit Subnet ID. 
+// The Global ID and Subnet ID are randomized (simplified from the SHA-1 algorithm +// in section 3.2.2), giving 2^56 possible /64 subnets across all accounts. +func AllocateIPv6Subnet(r *rand.Rand) net.IPNet { + ip := make(net.IP, 16) + ip[0] = 0xfd + // Bytes 1-5: 40-bit random Global ID + ip[1] = byte(r.Intn(256)) + ip[2] = byte(r.Intn(256)) + ip[3] = byte(r.Intn(256)) + ip[4] = byte(r.Intn(256)) + ip[5] = byte(r.Intn(256)) + // Bytes 6-7: 16-bit random Subnet ID + ip[6] = byte(r.Intn(256)) + ip[7] = byte(r.Intn(256)) + + return net.IPNet{ + IP: ip, + Mask: net.CIDRMask(IPv6SubnetSize, 128), + } } // IncSerial increments Serial by 1 reflecting that the network state has been changed @@ -157,19 +193,19 @@ func (n *Network) Copy() *Network { return &Network{ Identifier: n.Identifier, Net: n.Net, + NetV6: n.NetV6, Dns: n.Dns, Serial: n.Serial, } } -// AllocatePeerIP pics an available IP from an net.IPNet. -// This method considers already taken IPs and reuses IPs if there are gaps in takenIps -// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3 -func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { - baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask)) - - ones, bits := ipNet.Mask.Size() - hostBits := bits - ones +// AllocatePeerIP picks an available IP from a netip.Prefix. +// This method considers already taken IPs and reuses IPs if there are gaps in takenIps. +// E.g. if prefix=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3. 
+func AllocatePeerIP(prefix netip.Prefix, takenIps []netip.Addr) (netip.Addr, error) { + b := prefix.Masked().Addr().As4() + baseIP := binary.BigEndian.Uint32(b[:]) + hostBits := 32 - prefix.Bits() totalIPs := uint32(1 << hostBits) taken := make(map[uint32]struct{}, len(takenIps)+1) @@ -177,7 +213,8 @@ func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { taken[baseIP+totalIPs-1] = struct{}{} // reserve broadcast IP for _, ip := range takenIps { - taken[ipToUint32(ip)] = struct{}{} + ab := ip.As4() + taken[binary.BigEndian.Uint32(ab[:])] = struct{}{} } rng := rand.New(rand.NewSource(time.Now().UnixNano())) @@ -198,15 +235,14 @@ func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { } } - return nil, status.Errorf(status.PreconditionFailed, "network %s is out of IPs", ipNet.String()) + return netip.Addr{}, status.Errorf(status.PreconditionFailed, "network %s is out of IPs", prefix.String()) } -func AllocateRandomPeerIP(ipNet net.IPNet) (net.IP, error) { - baseIP := ipToUint32(ipNet.IP.Mask(ipNet.Mask)) - - ones, bits := ipNet.Mask.Size() - hostBits := bits - ones - +// AllocateRandomPeerIP picks a random available IP from a netip.Prefix. +func AllocateRandomPeerIP(prefix netip.Prefix) (netip.Addr, error) { + b := prefix.Masked().Addr().As4() + baseIP := binary.BigEndian.Uint32(b[:]) + hostBits := 32 - prefix.Bits() totalIPs := uint32(1 << hostBits) rng := rand.New(rand.NewSource(time.Now().UnixNano())) @@ -216,18 +252,75 @@ func AllocateRandomPeerIP(ipNet net.IPNet) (net.IP, error) { return uint32ToIP(candidate), nil } -func ipToUint32(ip net.IP) uint32 { - ip = ip.To4() - if len(ip) < 4 { - return 0 +// AllocateRandomPeerIPv6 picks a random host address within the given IPv6 prefix. +// Only the host bits (after the prefix length) are randomized. 
+func AllocateRandomPeerIPv6(prefix netip.Prefix) (netip.Addr, error) { + ones := prefix.Bits() + if ones == 0 || ones > 126 || !prefix.Addr().Is6() { + return netip.Addr{}, fmt.Errorf("invalid IPv6 subnet: %s", prefix.String()) } - return binary.BigEndian.Uint32(ip) + + ip := prefix.Addr().As16() + + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + + // Determine which byte the host bits start in + firstHostByte := ones / 8 + // If the prefix doesn't end on a byte boundary, handle the partial byte + partialBits := ones % 8 + + if partialBits > 0 { + // Keep the network bits in the partial byte, randomize the rest + hostMask := byte(0xff >> partialBits) + ip[firstHostByte] = (ip[firstHostByte] & ^hostMask) | (byte(rng.Intn(256)) & hostMask) + firstHostByte++ + } + + // Randomize remaining full host bytes + for i := firstHostByte; i < 16; i++ { + ip[i] = byte(rng.Intn(256)) + } + + // Avoid all-zeros and all-ones host parts by checking only host bits. + if isHostAllZeroOrOnes(ip[:], ones) { + ip = prefix.Masked().Addr().As16() + ip[15] |= 0x01 + } + + return netip.AddrFrom16(ip).Unmap(), nil } -func uint32ToIP(n uint32) net.IP { - ip := make(net.IP, 4) - binary.BigEndian.PutUint32(ip, n) - return ip +// isHostAllZeroOrOnes checks whether all host bits (after prefixLen) are zero or all ones. 
+func isHostAllZeroOrOnes(ip []byte, prefixLen int) bool { + hostStart := prefixLen / 8 + partialBits := prefixLen % 8 + + hostSlice := slices.Clone(ip[hostStart:]) + if partialBits > 0 { + hostSlice[0] &= 0xff >> partialBits + } + + allZero := !slices.ContainsFunc(hostSlice, func(v byte) bool { return v != 0 }) + if allZero { + return true + } + + // Build the all-ones mask for host bits + onesMask := make([]byte, len(hostSlice)) + for i := range onesMask { + onesMask[i] = 0xff + } + if partialBits > 0 { + onesMask[0] = 0xff >> partialBits + } + + return slices.Equal(hostSlice, onesMask) +} + +func uint32ToIP(n uint32) netip.Addr { + var b [4]byte + binary.BigEndian.PutUint32(b[:], n) + return netip.AddrFrom4(b) } // generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list diff --git a/management/server/types/network_test.go b/management/server/types/network_test.go index 4c1459ce5..d8a06dbbc 100644 --- a/management/server/types/network_test.go +++ b/management/server/types/network_test.go @@ -1,7 +1,9 @@ package types import ( + "encoding/binary" "net" + "net/netip" "testing" "github.com/stretchr/testify/assert" @@ -17,10 +19,10 @@ func TestNewNetwork(t *testing.T) { } func TestAllocatePeerIP(t *testing.T) { - ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 255, 255, 0}} - var ips []net.IP + prefix := netip.MustParsePrefix("100.64.0.0/24") + var ips []netip.Addr for i := 0; i < 252; i++ { - ip, err := AllocatePeerIP(ipNet, ips) + ip, err := AllocatePeerIP(prefix, ips) if err != nil { t.Fatal(err) } @@ -41,19 +43,19 @@ func TestAllocatePeerIP(t *testing.T) { func TestAllocatePeerIPSmallSubnet(t *testing.T) { // Test /27 network (10.0.0.0/27) - should only have 30 usable IPs (10.0.0.1 to 10.0.0.30) - ipNet := net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.IPMask{255, 255, 255, 224}} - var ips []net.IP + prefix := netip.MustParsePrefix("10.0.0.0/27") + var ips []netip.Addr // 
Allocate all available IPs in the /27 network for i := 0; i < 30; i++ { - ip, err := AllocatePeerIP(ipNet, ips) + ip, err := AllocatePeerIP(prefix, ips) if err != nil { t.Fatal(err) } // Verify IP is within the correct range - if !ipNet.Contains(ip) { - t.Errorf("allocated IP %s is not within network %s", ip.String(), ipNet.String()) + if !prefix.Contains(ip) { + t.Errorf("allocated IP %s is not within network %s", ip.String(), prefix.String()) } ips = append(ips, ip) @@ -72,7 +74,7 @@ func TestAllocatePeerIPSmallSubnet(t *testing.T) { } // Try to allocate one more IP - should fail as network is full - _, err := AllocatePeerIP(ipNet, ips) + _, err := AllocatePeerIP(prefix, ips) if err == nil { t.Error("expected error when network is full, but got none") } @@ -95,10 +97,11 @@ func TestAllocatePeerIPVariousCIDRs(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - _, ipNet, err := net.ParseCIDR(tc.cidr) + prefix, err := netip.ParsePrefix(tc.cidr) require.NoError(t, err) + prefix = prefix.Masked() - var ips []net.IP + var ips []netip.Addr // For larger networks, test only a subset to avoid long test runs testCount := tc.expectedUsable @@ -108,21 +111,21 @@ func TestAllocatePeerIPVariousCIDRs(t *testing.T) { // Allocate IPs and verify they're within the correct range for i := 0; i < testCount; i++ { - ip, err := AllocatePeerIP(*ipNet, ips) + ip, err := AllocatePeerIP(prefix, ips) require.NoError(t, err, "failed to allocate IP %d", i) // Verify IP is within the correct range - assert.True(t, ipNet.Contains(ip), "allocated IP %s is not within network %s", ip.String(), ipNet.String()) + assert.True(t, prefix.Contains(ip), "allocated IP %s is not within network %s", ip.String(), prefix.String()) // Verify IP is not network or broadcast address - networkIP := ipNet.IP.Mask(ipNet.Mask) - ones, bits := ipNet.Mask.Size() - hostBits := bits - ones - broadcastInt := uint32(ipToUint32(networkIP)) + (1 << hostBits) - 1 - broadcastIP := 
uint32ToIP(broadcastInt) + networkAddr := prefix.Masked().Addr() + hostBits := 32 - prefix.Bits() + b := networkAddr.As4() + baseIP := binary.BigEndian.Uint32(b[:]) + broadcastIP := uint32ToIP(baseIP + (1 << hostBits) - 1) - assert.False(t, ip.Equal(networkIP), "allocated network address %s", ip.String()) - assert.False(t, ip.Equal(broadcastIP), "allocated broadcast address %s", ip.String()) + assert.NotEqual(t, networkAddr, ip, "allocated network address %s", ip.String()) + assert.NotEqual(t, broadcastIP, ip, "allocated broadcast address %s", ip.String()) ips = append(ips, ip) } @@ -151,3 +154,111 @@ func TestGenerateIPs(t *testing.T) { t.Errorf("expected last ip to be: 100.64.0.253, got %s", ips[len(ips)-1].String()) } } + +func TestNewNetworkHasIPv6(t *testing.T) { + network := NewNetwork() + + assert.NotNil(t, network.NetV6.IP, "v6 subnet should be allocated") + assert.True(t, network.NetV6.IP.To4() == nil, "v6 subnet should be IPv6") + assert.Equal(t, byte(0xfd), network.NetV6.IP[0], "v6 subnet should be ULA (fd prefix)") + + ones, bits := network.NetV6.Mask.Size() + assert.Equal(t, 64, ones, "v6 subnet should be /64") + assert.Equal(t, 128, bits) +} + +func TestAllocateIPv6SubnetUniqueness(t *testing.T) { + seen := make(map[string]struct{}) + for i := 0; i < 100; i++ { + network := NewNetwork() + key := network.NetV6.IP.String() + _, duplicate := seen[key] + assert.False(t, duplicate, "duplicate v6 subnet: %s", key) + seen[key] = struct{}{} + } +} + +func TestAllocateRandomPeerIPv6(t *testing.T) { + prefix := netip.MustParsePrefix("fd12:3456:7890:abcd::/64") + + ip, err := AllocateRandomPeerIPv6(prefix) + require.NoError(t, err) + + assert.True(t, ip.Is6(), "should be IPv6") + assert.True(t, prefix.Contains(ip), "should be within subnet") + // First 8 bytes (network prefix) should match + b := ip.As16() + prefixBytes := prefix.Addr().As16() + assert.Equal(t, prefixBytes[:8], b[:8], "prefix should match") + // Interface ID should not be all zeros + allZero := 
true + for _, v := range b[8:] { + if v != 0 { + allZero = false + break + } + } + assert.False(t, allZero, "interface ID should not be all zeros") +} + +func TestAllocateRandomPeerIPv6_VariousPrefixes(t *testing.T) { + tests := []struct { + name string + cidr string + prefix int + }{ + {"standard /64", "fd00:1234:5678:abcd::/64", 64}, + {"small /112", "fd00:1234:5678:abcd::/112", 112}, + {"large /48", "fd00:1234::/48", 48}, + {"non-boundary /60", "fd00:1234:5670::/60", 60}, + {"non-boundary /52", "fd00:1230::/52", 52}, + {"minimum /120", "fd00:1234:5678:abcd::100/120", 120}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prefix, err := netip.ParsePrefix(tt.cidr) + require.NoError(t, err) + prefix = prefix.Masked() + + assert.Equal(t, tt.prefix, prefix.Bits()) + + for i := 0; i < 50; i++ { + ip, err := AllocateRandomPeerIPv6(prefix) + require.NoError(t, err) + assert.True(t, prefix.Contains(ip), "IP %s should be within %s", ip, prefix) + } + }) + } +} + +func TestAllocateRandomPeerIPv6_PreservesNetworkBits(t *testing.T) { + // For a /112, bytes 0-13 should be preserved, only bytes 14-15 should vary + prefix := netip.MustParsePrefix("fd00:1234:5678:abcd:ef01:2345:6789:0/112") + + prefixBytes := prefix.Addr().As16() + for i := 0; i < 20; i++ { + ip, err := AllocateRandomPeerIPv6(prefix) + require.NoError(t, err) + // First 14 bytes (112 bits = 14 bytes) must match the network + b := ip.As16() + assert.Equal(t, prefixBytes[:14], b[:14], "network bytes should be preserved for /112") + } +} + +func TestAllocateRandomPeerIPv6_NonByteBoundary(t *testing.T) { + // For a /60, the first 7.5 bytes are network, so byte 7 is partial + prefix := netip.MustParsePrefix("fd00:1234:5678:abc0::/60") + + prefixBytes := prefix.Addr().As16() + for i := 0; i < 50; i++ { + ip, err := AllocateRandomPeerIPv6(prefix) + require.NoError(t, err) + b := ip.As16() + assert.True(t, prefix.Contains(ip), "IP %s should be within %s", ip, prefix) + // First 7 bytes must 
match exactly + assert.Equal(t, prefixBytes[:7], b[:7], "full network bytes should match for /60") + // Byte 7: top 4 bits (0xc = 1100) must be preserved + assert.Equal(t, prefixBytes[7]&0xf0, b[7]&0xf0, "partial byte network bits should be preserved for /60") + } +} diff --git a/management/server/types/networkmap_components.go b/management/server/types/networkmap_components.go index 6f84c8d30..3a7e20ec5 100644 --- a/management/server/types/networkmap_components.go +++ b/management/server/types/networkmap_components.go @@ -3,7 +3,6 @@ package types import ( "context" "maps" - "net" "net/netip" "slices" "strconv" @@ -114,13 +113,17 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap { peersToConnect, expiredPeers := c.filterPeersByLoginExpiration(aclPeers) - routesUpdate := c.getRoutesToSync(targetPeerID, peersToConnect, peerGroups) - routesFirewallRules := c.getPeerRoutesFirewallRules(ctx, targetPeerID) + includeIPv6 := false + if p := c.Peers[targetPeerID]; p != nil { + includeIPv6 = p.SupportsIPv6() && p.IPv6.IsValid() + } + routesUpdate := filterAndExpandRoutes(c.getRoutesToSync(targetPeerID, peersToConnect, peerGroups), includeIPv6) + routesFirewallRules := c.getPeerRoutesFirewallRules(ctx, targetPeerID, includeIPv6) isRouter, networkResourcesRoutes, sourcePeers := c.getNetworkResourcesRoutesToSync(targetPeerID) var networkResourcesFirewallRules []*RouteFirewallRule if isRouter { - networkResourcesFirewallRules = c.getPeerNetworkResourceFirewallRules(ctx, targetPeerID, networkResourcesRoutes) + networkResourcesFirewallRules = c.getPeerNetworkResourceFirewallRules(ctx, targetPeerID, networkResourcesRoutes, includeIPv6) } peersToConnectIncludingRouters := c.addNetworksRoutingPeers( @@ -156,7 +159,7 @@ func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap { return &NetworkMap{ Peers: peersToConnectIncludingRouters, Network: c.Network.Copy(), - Routes: append(networkResourcesRoutes, routesUpdate...), + Routes: 
append(filterAndExpandRoutes(networkResourcesRoutes, includeIPv6), routesUpdate...), DNSConfig: dnsUpdate, OfflinePeers: expiredPeers, FirewallRules: firewallRules, @@ -296,7 +299,7 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( peersExists[peer.ID] = struct{}{} } - peerIP := net.IP(peer.IP).String() + peerIP := peer.IP.String() fr := FirewallRule{ PolicyID: rule.ID, @@ -315,10 +318,17 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) ( if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { rules = append(rules, &fr) - continue + } else { + rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) } - rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) + rules = appendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, firewallRuleContext{ + direction: direction, + dirStr: dirStr, + protocolStr: protocolStr, + actionStr: actionStr, + portsJoined: portsJoined, + }) } }, func() ([]*nbpeer.Peer, []*FirewallRule) { return peers, rules @@ -454,6 +464,29 @@ func (c *NetworkMapComponents) peerIsNameserver(peerIPStr string, nsGroup *nbdns return false } +// filterAndExpandRoutes drops v6 routes for non-capable peers and duplicates +// the default v4 route (0.0.0.0/0) as ::/0 for v6-capable peers. +// TODO: the "-v6" suffix on IDs could collide with user-supplied route IDs. 
+func filterAndExpandRoutes(routes []*route.Route, includeIPv6 bool) []*route.Route { + filtered := make([]*route.Route, 0, len(routes)) + for _, r := range routes { + if !includeIPv6 && r.Network.Addr().Is6() { + continue + } + filtered = append(filtered, r) + + if includeIPv6 && r.Network.Bits() == 0 && r.Network.Addr().Is4() { + v6 := r.Copy() + v6.ID = r.ID + "-v6-default" + v6.NetID = r.NetID + "-v6" + v6.Network = netip.MustParsePrefix("::/0") + v6.NetworkType = route.IPv6Network + filtered = append(filtered, v6) + } + } + return filtered +} + func (c *NetworkMapComponents) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route { routes, peerDisabledRoutes := c.getRoutingPeerRoutes(peerID) peerRoutesMembership := make(LookupMap) @@ -550,13 +583,13 @@ func (c *NetworkMapComponents) filterRoutesFromPeersOfSameHAGroup(routes []*rout return filteredRoutes } -func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, peerID string) []*RouteFirewallRule { +func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, peerID string, includeIPv6 bool) []*RouteFirewallRule { routesFirewallRules := make([]*RouteFirewallRule, 0) enabledRoutes, _ := c.getRoutingPeerRoutes(peerID) for _, r := range enabledRoutes { if len(r.AccessControlGroups) == 0 { - defaultPermit := c.getDefaultPermit(r) + defaultPermit := c.getDefaultPermit(r, includeIPv6) routesFirewallRules = append(routesFirewallRules, defaultPermit...) continue } @@ -565,7 +598,7 @@ func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, p for _, accessGroup := range r.AccessControlGroups { policies := c.getAllRoutePoliciesFromGroups([]string{accessGroup}) - rules := c.getRouteFirewallRules(ctx, peerID, policies, r, distributionPeers) + rules := c.getRouteFirewallRules(ctx, peerID, policies, r, distributionPeers, includeIPv6) routesFirewallRules = append(routesFirewallRules, rules...) 
} } @@ -573,8 +606,10 @@ func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, p return routesFirewallRules } -func (c *NetworkMapComponents) getDefaultPermit(r *route.Route) []*RouteFirewallRule { - var rules []*RouteFirewallRule +func (c *NetworkMapComponents) getDefaultPermit(r *route.Route, includeIPv6 bool) []*RouteFirewallRule { + if r.Network.Addr().Is6() && !includeIPv6 { + return nil + } sources := []string{"0.0.0.0/0"} if r.Network.Addr().Is6() { @@ -591,9 +626,9 @@ func (c *NetworkMapComponents) getDefaultPermit(r *route.Route) []*RouteFirewall RouteID: r.ID, } - rules = append(rules, &rule) + rules := []*RouteFirewallRule{&rule} - if r.IsDynamic() { + if includeIPv6 && r.IsDynamic() { ruleV6 := rule ruleV6.SourceRanges = []string{"::/0"} rules = append(rules, &ruleV6) @@ -632,7 +667,7 @@ func (c *NetworkMapComponents) getAllRoutePoliciesFromGroups(accessControlGroups return routePolicies } -func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, distributionPeers map[string]struct{}) []*RouteFirewallRule { +func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, distributionPeers map[string]struct{}, includeIPv6 bool) []*RouteFirewallRule { var fwRules []*RouteFirewallRule for _, policy := range policies { if !policy.Enabled { @@ -645,7 +680,7 @@ func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID } rulePeers := c.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers) - rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN, includeIPv6) fwRules = append(fwRules, rules...) 
} } @@ -710,33 +745,49 @@ func (c *NetworkMapComponents) getNetworkResourcesRoutesToSync(peerID string) (b } } - addedResourceRoute := false - for _, policy := range c.ResourcePoliciesMap[resource.ID] { - var peers []string - if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { - peers = []string{policy.Rules[0].SourceResource.ID} - } else { - peers = c.getUniquePeerIDsFromGroupsIDs(policy.SourceGroups()) - } - if addSourcePeers { - for _, pID := range c.getPostureValidPeers(peers, policy.SourcePostureChecks) { - allSourcePeers[pID] = struct{}{} - } - } else if slices.Contains(peers, peerID) && c.ValidatePostureChecksOnPeer(peerID, policy.SourcePostureChecks) { - for peerId, router := range networkRoutingPeers { - routes = append(routes, c.getNetworkResourcesRoutes(resource, peerId, router)...) - } - addedResourceRoute = true - } - if addedResourceRoute { - break - } - } + newRoutes := c.processResourcePolicies(peerID, resource, networkRoutingPeers, addSourcePeers, allSourcePeers) + routes = append(routes, newRoutes...) } return isRoutingPeer, routes, allSourcePeers } +func (c *NetworkMapComponents) processResourcePolicies( + peerID string, + resource *resourceTypes.NetworkResource, + networkRoutingPeers map[string]*routerTypes.NetworkRouter, + addSourcePeers bool, + allSourcePeers map[string]struct{}, +) []*route.Route { + var routes []*route.Route + + for _, policy := range c.ResourcePoliciesMap[resource.ID] { + peers := c.getResourcePolicyPeers(policy) + if addSourcePeers { + for _, pID := range c.getPostureValidPeers(peers, policy.SourcePostureChecks) { + allSourcePeers[pID] = struct{}{} + } + continue + } + + if slices.Contains(peers, peerID) && c.ValidatePostureChecksOnPeer(peerID, policy.SourcePostureChecks) { + for peerId, router := range networkRoutingPeers { + routes = append(routes, c.getNetworkResourcesRoutes(resource, peerId, router)...) 
+ } + break + } + } + + return routes +} + +func (c *NetworkMapComponents) getResourcePolicyPeers(policy *Policy) []string { + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + return []string{policy.Rules[0].SourceResource.ID} + } + return c.getUniquePeerIDsFromGroupsIDs(policy.SourceGroups()) +} + func (c *NetworkMapComponents) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerID string, router *routerTypes.NetworkRouter) []*route.Route { resourceAppliedPolicies := c.ResourcePoliciesMap[resource.ID] @@ -796,7 +847,7 @@ func (c *NetworkMapComponents) getPostureValidPeers(inputPeers []string, posture return dest } -func (c *NetworkMapComponents) getPeerNetworkResourceFirewallRules(ctx context.Context, peerID string, routes []*route.Route) []*RouteFirewallRule { +func (c *NetworkMapComponents) getPeerNetworkResourceFirewallRules(ctx context.Context, peerID string, routes []*route.Route, includeIPv6 bool) []*RouteFirewallRule { routesFirewallRules := make([]*RouteFirewallRule, 0) peerInfo := c.GetPeerInfo(peerID) @@ -813,7 +864,7 @@ func (c *NetworkMapComponents) getPeerNetworkResourceFirewallRules(ctx context.C resourcePolicies := c.ResourcePoliciesMap[resourceID] distributionPeers := c.getPoliciesSourcePeers(resourcePolicies) - rules := c.getRouteFirewallRules(ctx, peerID, resourcePolicies, r, distributionPeers) + rules := c.getRouteFirewallRules(ctx, peerID, resourcePolicies, r, distributionPeers, includeIPv6) for _, rule := range rules { if len(rule.SourceRanges) > 0 { routesFirewallRules = append(routesFirewallRules, rule) @@ -897,3 +948,36 @@ func (c *NetworkMapComponents) addNetworksRoutingPeers( return peersToConnect } + +type firewallRuleContext struct { + direction int + dirStr string + protocolStr string + actionStr string + portsJoined string +} + +func appendIPv6FirewallRule(rules []*FirewallRule, rulesExists map[string]struct{}, peer, targetPeer *nbpeer.Peer, rule *PolicyRule, 
rc firewallRuleContext) []*FirewallRule { + if !peer.IPv6.IsValid() || !targetPeer.SupportsIPv6() || !targetPeer.IPv6.IsValid() { + return rules + } + + v6IP := peer.IPv6.String() + v6RuleID := rule.ID + v6IP + rc.dirStr + rc.protocolStr + rc.actionStr + rc.portsJoined + if _, ok := rulesExists[v6RuleID]; ok { + return rules + } + rulesExists[v6RuleID] = struct{}{} + + v6fr := FirewallRule{ + PolicyID: rule.ID, + PeerIP: v6IP, + Direction: rc.direction, + Action: rc.actionStr, + Protocol: rc.protocolStr, + } + if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { + return append(rules, &v6fr) + } + return append(rules, expandPortsAndRanges(v6fr, rule, targetPeer)...) +} diff --git a/management/server/types/networkmap_components_correctness_test.go b/management/server/types/networkmap_components_correctness_test.go index 5cd41ff10..bcfb6fdf9 100644 --- a/management/server/types/networkmap_components_correctness_test.go +++ b/management/server/types/networkmap_components_correctness_test.go @@ -42,7 +42,7 @@ func buildScalableTestAccount(numPeers, numGroups int, withDefaultPolicy bool) ( for i := range numPeers { peerID := fmt.Sprintf("peer-%d", i) - ip := net.IP{100, byte(64 + i/65536), byte((i / 256) % 256), byte(i % 256)} + ip := netip.AddrFrom4([4]byte{100, byte(64 + i/65536), byte((i / 256) % 256), byte(i % 256)}) wtVersion := "0.25.0" if i%2 == 0 { wtVersion = "0.40.0" @@ -1083,7 +1083,7 @@ func TestComponents_PeerIsNameserverExcludedFromNSGroup(t *testing.T) { nsIP := account.Peers["peer-0"].IP account.NameServerGroups["ns-self"] = &nbdns.NameServerGroup{ ID: "ns-self", Name: "Self NS", Enabled: true, Groups: []string{"group-all"}, - NameServers: []nbdns.NameServer{{IP: netip.AddrFrom4([4]byte{nsIP[0], nsIP[1], nsIP[2], nsIP[3]}), NSType: nbdns.UDPNameServerType, Port: 53}}, + NameServers: []nbdns.NameServer{{IP: nsIP, NSType: nbdns.UDPNameServerType, Port: 53}}, } nm := componentsNetworkMap(account, "peer-0", validatedPeers) diff --git 
a/management/server/types/networkmap_components_test.go b/management/server/types/networkmap_components_test.go index dde639ccb..1a99b4511 100644 --- a/management/server/types/networkmap_components_test.go +++ b/management/server/types/networkmap_components_test.go @@ -681,22 +681,22 @@ func TestNetworkMapComponents_RouterExcludesOtherNetworkRoutes(t *testing.T) { func createComponentTestAccount() *types.Account { peers := map[string]*nbpeer.Peer{ "peer-src-1": { - ID: "peer-src-1", IP: net.IP{100, 64, 0, 1}, Key: "key-src-1", DNSLabel: "src1", + ID: "peer-src-1", IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}), Key: "key-src-1", DNSLabel: "src1", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, }, "peer-src-2": { - ID: "peer-src-2", IP: net.IP{100, 64, 0, 2}, Key: "key-src-2", DNSLabel: "src2", + ID: "peer-src-2", IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}), Key: "key-src-2", DNSLabel: "src2", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, }, "peer-dst-1": { - ID: "peer-dst-1", IP: net.IP{100, 64, 0, 3}, Key: "key-dst-1", DNSLabel: "dst1", + ID: "peer-dst-1", IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}), Key: "key-dst-1", DNSLabel: "dst1", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-2", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, }, "peer-router-1": { - ID: "peer-router-1", IP: net.IP{100, 64, 0, 10}, Key: "key-router-1", DNSLabel: "router1", + ID: "peer-router-1", IP: netip.AddrFrom4([4]byte{100, 64, 0, 10}), Key: "key-router-1", DNSLabel: "router1", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, }, diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 
4ea79ec72..264a018d4 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -46,6 +46,8 @@ type Settings struct { // NetworkRange is the custom network range for that account NetworkRange netip.Prefix `gorm:"serializer:json"` + // NetworkRangeV6 is the custom IPv6 network range for that account + NetworkRangeV6 netip.Prefix `gorm:"serializer:json"` // PeerExposeEnabled enables or disables peer-initiated service expose PeerExposeEnabled bool @@ -65,6 +67,12 @@ type Settings struct { // when false, updates require user interaction from the UI AutoUpdateAlways bool `gorm:"default:false"` + // IPv6EnabledGroups is the list of group IDs whose peers receive IPv6 overlay addresses. + // Peers not in any of these groups will not be allocated an IPv6 address. + // Empty list means IPv6 is disabled for the account. + // For new accounts this defaults to the All group. + IPv6EnabledGroups []string `gorm:"serializer:json"` + // EmbeddedIdpEnabled indicates if the embedded identity provider is enabled. // This is a runtime-only field, not stored in the database. 
EmbeddedIdpEnabled bool `gorm:"-"` @@ -94,8 +102,10 @@ func (s *Settings) Copy() *Settings { LazyConnectionEnabled: s.LazyConnectionEnabled, DNSDomain: s.DNSDomain, NetworkRange: s.NetworkRange, + NetworkRangeV6: s.NetworkRangeV6, AutoUpdateVersion: s.AutoUpdateVersion, AutoUpdateAlways: s.AutoUpdateAlways, + IPv6EnabledGroups: slices.Clone(s.IPv6EnabledGroups), EmbeddedIdpEnabled: s.EmbeddedIdpEnabled, LocalAuthDisabled: s.LocalAuthDisabled, } diff --git a/management/server/user.go b/management/server/user.go index 43e0a9821..892d982e7 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "strings" "time" "unicode" @@ -825,6 +826,11 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } } } + + allGroupChanges := slices.Concat(removedGroups, addedGroups) + if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, allGroupChanges); err != nil { + return false, nil, nil, nil, fmt.Errorf("reconcile IPv6 for group changes: %w", err) + } } updateAccountPeers := len(userPeers) > 0 diff --git a/proxy/cmd/proxy/cmd/debug.go b/proxy/cmd/proxy/cmd/debug.go index 1b1664490..49afc7638 100644 --- a/proxy/cmd/proxy/cmd/debug.go +++ b/proxy/cmd/proxy/cmd/debug.go @@ -8,6 +8,7 @@ import ( "strconv" "strings" "syscall" + "time" "github.com/spf13/cobra" @@ -62,7 +63,11 @@ var debugSyncCmd = &cobra.Command{ SilenceUsage: true, } -var pingTimeout string +var ( + pingTimeout time.Duration + pingIPv4 bool + pingIPv6 bool +) var debugPingCmd = &cobra.Command{ Use: "ping [port]", @@ -134,7 +139,10 @@ func init() { debugStatusCmd.Flags().StringVar(&statusFilterByStatus, "filter-by-status", "", "Filter by status (idle|connecting|connected)") debugStatusCmd.Flags().StringVar(&statusFilterByConnectionType, "filter-by-connection-type", "", "Filter by connection type (P2P|Relayed)") - debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)") + 
debugPingCmd.Flags().DurationVar(&pingTimeout, "timeout", 0, "Ping timeout (e.g., 10s)") + debugPingCmd.Flags().BoolVarP(&pingIPv4, "ipv4", "4", false, "Force IPv4") + debugPingCmd.Flags().BoolVarP(&pingIPv6, "ipv6", "6", false, "Force IPv6") + debugPingCmd.MarkFlagsMutuallyExclusive("ipv4", "ipv6") debugCaptureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = server default)") debugCaptureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)") @@ -190,7 +198,14 @@ func runDebugPing(cmd *cobra.Command, args []string) error { } port = p } - return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout) + var ipVersion string + switch { + case pingIPv4: + ipVersion = "4" + case pingIPv6: + ipVersion = "6" + } + return getDebugClient(cmd).PingTCP(cmd.Context(), args[0], args[1], port, pingTimeout, ipVersion) } func runDebugLogLevel(cmd *cobra.Command, args []string) error { diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index e01149522..09c25afb2 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -6,10 +6,12 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/url" "strings" "time" + ) // StatusFilters contains filter options for status queries. @@ -230,12 +232,16 @@ func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error } // PingTCP performs a TCP ping through a client. -func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error { +// ipVersion may be "4", "6", or "" for automatic. 
+func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout time.Duration, ipVersion string) error { params := url.Values{} params.Set("host", host) params.Set("port", fmt.Sprintf("%d", port)) - if timeout != "" { - params.Set("timeout", timeout) + if timeout > 0 { + params.Set("timeout", timeout.String()) + } + if ipVersion != "" { + params.Set("ip_version", ipVersion) } path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode()) @@ -244,11 +250,17 @@ func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, func (c *Client) printPingResult(data map[string]any) { success, _ := data["success"].(bool) + host := net.JoinHostPort(fmt.Sprint(data["host"]), fmt.Sprint(data["port"])) if success { - _, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"]) + remote, _ := data["remote"].(string) + if remote != "" && remote != host { + _, _ = fmt.Fprintf(c.out, "Success: %s (via %s)\n", host, remote) + } else { + _, _ = fmt.Fprintf(c.out, "Success: %s\n", host) + } _, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"]) } else { - _, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"]) + _, _ = fmt.Fprintf(c.out, "Failed: %s\n", host) c.printError(data) } } diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index 6cd124554..23ca4adbb 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -9,6 +9,7 @@ import ( "fmt" "html/template" "maps" + "net" "net/http" "slices" "strconv" @@ -527,13 +528,18 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI } } + network := "tcp" + if v := r.URL.Query().Get("ip_version"); v == "4" || v == "6" { + network += v + } + ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() - address := fmt.Sprintf("%s:%d", host, port) + address := net.JoinHostPort(host, strconv.Itoa(port)) start := time.Now() - conn, err := 
client.Dial(ctx, "tcp", address) + conn, err := client.Dial(ctx, network, address) if err != nil { h.writeJSON(w, map[string]interface{}{ "success": false, @@ -543,18 +549,22 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI }) return } + + remote := conn.RemoteAddr().String() if err := conn.Close(); err != nil { h.logger.Debugf("close tcp ping connection: %v", err) } latency := time.Since(start) - h.writeJSON(w, map[string]interface{}{ + resp := map[string]interface{}{ "success": true, "host": host, "port": port, + "remote": remote, "latency_ms": latency.Milliseconds(), "latency": formatDuration(latency), - }) + } + h.writeJSON(w, resp) } func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { diff --git a/relay/test/benchmark_test.go b/relay/test/benchmark_test.go index 4dfea6da1..6b1131f1e 100644 --- a/relay/test/benchmark_test.go +++ b/relay/test/benchmark_test.go @@ -337,7 +337,7 @@ func runTurnDataTransfer(t *testing.T, testData []byte) time.Duration { func getTurnClient(t *testing.T, address string, conn net.Conn) (*turn.Client, error) { t.Helper() // Dial TURN Server - addrStr := fmt.Sprintf("%s:%d", address, 443) + addrStr := net.JoinHostPort(address, "443") fac := logging.NewDefaultLoggerFactory() //fac.DefaultLogLevel = logging.LogLevelTrace diff --git a/relay/testec2/turn_allocator.go b/relay/testec2/turn_allocator.go index fd86208df..440f6222a 100644 --- a/relay/testec2/turn_allocator.go +++ b/relay/testec2/turn_allocator.go @@ -52,7 +52,7 @@ func AllocateTurnClient(serverAddr string) *TurnConn { func getTurnClient(address string, conn net.Conn) (*turn.Client, error) { // Dial TURN Server - addrStr := fmt.Sprintf("%s:%d", address, 443) + addrStr := net.JoinHostPort(address, "443") fac := logging.NewDefaultLoggerFactory() //fac.DefaultLogLevel = logging.LogLevelTrace diff --git a/route/route.go b/route/route.go index c724e7c7d..97b9721f6 100644 --- a/route/route.go +++ 
b/route/route.go @@ -20,6 +20,9 @@ const ( MaxMetric = 9999 // MaxNetIDChar Max Network Identifier MaxNetIDChar = 40 + + // V6ExitSuffix is appended to a v4 exit node NetID to form its v6 counterpart. + V6ExitSuffix = "-v6" ) const ( @@ -215,3 +218,61 @@ func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) { return IPv4Network, masked, nil } + +var ( + v4Default = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + v6Default = netip.PrefixFrom(netip.IPv6Unspecified(), 0) +) + +// IsV4DefaultRoute reports whether p is the IPv4 default route (0.0.0.0/0). +func IsV4DefaultRoute(p netip.Prefix) bool { return p == v4Default } + +// IsV6DefaultRoute reports whether p is the IPv6 default route (::/0). +func IsV6DefaultRoute(p netip.Prefix) bool { return p == v6Default } + +// ExpandV6ExitPairs appends the paired "-v6" exit node NetID for any v4 exit +// node (0.0.0.0/0) in ids that has a matching v6 counterpart (::/0) in routesMap. +// It modifies and returns the input slice. +func ExpandV6ExitPairs(ids []NetID, routesMap map[NetID][]*Route) []NetID { + for _, id := range ids { + rt, ok := routesMap[id] + if !ok || len(rt) == 0 || !IsV4DefaultRoute(rt[0].Network) { + continue + } + v6ID := NetID(string(id) + V6ExitSuffix) + if v6Rt, ok := routesMap[v6ID]; ok && len(v6Rt) > 0 && IsV6DefaultRoute(v6Rt[0].Network) { + if !slices.Contains(ids, v6ID) { + ids = append(ids, v6ID) + } + } + } + return ids +} + +// V6ExitMergeSet scans routesMap and returns the set of v6 exit node NetIDs +// that should be hidden from the UI because they are paired with a v4 exit node. +// A v6 ID is paired when it has suffix "-v6", its route is ::/0, and the base +// name (without "-v6") exists with route 0.0.0.0/0. 
+func V6ExitMergeSet(routesMap map[NetID][]*Route) map[NetID]struct{} { + merged := make(map[NetID]struct{}) + for id, rt := range routesMap { + if len(rt) == 0 { + continue + } + name := string(id) + if !IsV6DefaultRoute(rt[0].Network) || !strings.HasSuffix(name, V6ExitSuffix) { + continue + } + baseName := NetID(strings.TrimSuffix(name, V6ExitSuffix)) + if baseRt, ok := routesMap[baseName]; ok && len(baseRt) > 0 && IsV4DefaultRoute(baseRt[0].Network) { + merged[id] = struct{}{} + } + } + return merged +} + +// HasV6ExitPair reports whether id has a paired v6 exit node in the merge set. +func HasV6ExitPair(id NetID, v6Merged map[NetID]struct{}) bool { + _, ok := v6Merged[NetID(string(id)+"-v6")] + return ok +} diff --git a/route/route_test.go b/route/route_test.go new file mode 100644 index 000000000..dab707ed3 --- /dev/null +++ b/route/route_test.go @@ -0,0 +1,108 @@ +package route + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExpandV6ExitPairs(t *testing.T) { + v4ExitRoute := &Route{Network: netip.MustParsePrefix("0.0.0.0/0")} + v6ExitRoute := &Route{Network: netip.MustParsePrefix("::/0")} + regularRoute := &Route{Network: netip.MustParsePrefix("10.0.0.0/8")} + + tests := []struct { + name string + ids []NetID + routesMap map[NetID][]*Route + expected []NetID + }{ + { + name: "v4 exit node with matching v6 pair", + ids: []NetID{"exit-node"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + "exit-node-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-node", "exit-node-v6"}, + }, + { + name: "v4 exit node without v6 pair", + ids: []NetID{"exit-node"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + }, + expected: []NetID{"exit-node"}, + }, + { + name: "regular route is not expanded", + ids: []NetID{"office"}, + routesMap: map[NetID][]*Route{ + "office": {regularRoute}, + "office-v6": {v6ExitRoute}, + }, + expected: []NetID{"office"}, + }, + { + name: "v6 already included is not 
duplicated", + ids: []NetID{"exit-node", "exit-node-v6"}, + routesMap: map[NetID][]*Route{ + "exit-node": {v4ExitRoute}, + "exit-node-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-node", "exit-node-v6"}, + }, + { + name: "multiple exit nodes expanded independently", + ids: []NetID{"exit-a", "exit-b"}, + routesMap: map[NetID][]*Route{ + "exit-a": {v4ExitRoute}, + "exit-a-v6": {v6ExitRoute}, + "exit-b": {v4ExitRoute}, + "exit-b-v6": {v6ExitRoute}, + }, + expected: []NetID{"exit-a", "exit-b", "exit-a-v6", "exit-b-v6"}, + }, + { + name: "v6 suffix but not exit node network", + ids: []NetID{"office"}, + routesMap: map[NetID][]*Route{ + "office": {regularRoute}, + "office-v6": {regularRoute}, + }, + expected: []NetID{"office"}, + }, + { + name: "user-chosen name for exit node with v6 pair", + ids: []NetID{"my-exit"}, + routesMap: map[NetID][]*Route{ + "my-exit": {v4ExitRoute}, + "my-exit-v6": {v6ExitRoute}, + }, + expected: []NetID{"my-exit", "my-exit-v6"}, + }, + { + name: "real-world management-generated IDs", + ids: []NetID{"0.0.0.0/0"}, + routesMap: map[NetID][]*Route{ + "0.0.0.0/0": {v4ExitRoute}, + "0.0.0.0/0-v6": {v6ExitRoute}, + }, + expected: []NetID{"0.0.0.0/0", "0.0.0.0/0-v6"}, + }, + { + name: "empty input", + ids: []NetID{}, + routesMap: map[NetID][]*Route{}, + expected: []NetID{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExpandV6ExitPairs(tt.ids, tt.routesMap) + assert.ElementsMatch(t, tt.expected, result) + }) + } +} diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 80625fe06..58895b7c2 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -937,8 +937,22 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta { DisableFirewall: info.DisableFirewall, BlockLANAccess: info.BlockLANAccess, BlockInbound: info.BlockInbound, + DisableIPv6: info.DisableIPv6, LazyConnectionEnabled: info.LazyConnectionEnabled, }, + + Capabilities: 
peerCapabilities(*info), } } + +// peerCapabilities returns the capabilities this client supports. +func peerCapabilities(info system.Info) []proto.PeerCapability { + caps := []proto.PeerCapability{ + proto.PeerCapability_PeerCapabilitySourcePrefixes, + } + if !info.DisableIPv6 { + caps = append(caps, proto.PeerCapability_PeerCapabilityIPv6Overlay) + } + return caps +} diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 327e20614..8e6ee54cc 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -341,7 +341,11 @@ components: description: Allows to define a custom network range for the account in CIDR format type: string format: cidr - example: 100.64.0.0/16 + network_range_v6: + description: Allows to define a custom IPv6 network range for the account in CIDR format. + type: string + format: cidr + example: fd00:1234:5678::/64 peer_expose_enabled: description: Enables or disables peer expose. If enabled, peers can expose local services through the reverse proxy using the CLI. type: boolean @@ -377,6 +381,12 @@ components: type: boolean readOnly: true example: false + ipv6_enabled_groups: + description: List of group IDs whose peers receive IPv6 overlay addresses. Peers not in any of these groups will not be allocated an IPv6 address. New accounts default to the All group. + type: array + items: + type: string + example: ["ch8i4ug6lnn4g9hqv7m0"] required: - peer_login_expiration_enabled - peer_login_expiration @@ -776,6 +786,11 @@ components: type: string format: ipv4 example: 100.64.0.15 + ipv6: + description: Peer's IPv6 overlay address. Omitted if IPv6 is not enabled for the account. 
+ type: string + format: ipv6 + example: "fd00:4e42:ab12::1" required: - name - ssh_enabled @@ -795,6 +810,11 @@ components: description: Peer's IP address type: string example: 10.64.0.1 + ipv6: + description: Peer's IPv6 overlay address + type: string + format: ipv6 + example: "fd00:4e42:ab12::1" connection_ip: description: Peer's public connection IP address type: string @@ -1013,6 +1033,10 @@ components: description: Peer's IP address type: string example: 10.64.0.1 + ipv6: + description: Peer's IPv6 overlay address + type: string + example: "fd00:4e42:ab12::1" dns_label: description: Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud type: string diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index dc916f81a..f8ea07be7 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1381,6 +1381,9 @@ type AccessiblePeer struct { // Ip Peer's IP address Ip string `json:"ip"` + // Ipv6 Peer's IPv6 overlay address + Ipv6 *string `json:"ipv6,omitempty"` + // LastSeen Last time peer connected to Netbird's management service LastSeen time.Time `json:"last_seen"` @@ -1465,6 +1468,9 @@ type AccountSettings struct { // GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user GroupsPropagationEnabled *bool `json:"groups_propagation_enabled,omitempty"` + // Ipv6EnabledGroups List of group IDs whose peers receive IPv6 overlay addresses. Peers not in any of these groups will not be allocated an IPv6 address. New accounts default to the All group. 
+ Ipv6EnabledGroups *[]string `json:"ipv6_enabled_groups,omitempty"` + // JwtAllowGroups List of groups to which users are allowed access JwtAllowGroups *[]string `json:"jwt_allow_groups,omitempty"` @@ -1483,6 +1489,9 @@ type AccountSettings struct { // NetworkRange Allows to define a custom network range for the account in CIDR format NetworkRange *string `json:"network_range,omitempty"` + // NetworkRangeV6 Allows to define a custom IPv6 network range for the account in CIDR format. + NetworkRangeV6 *string `json:"network_range_v6,omitempty"` + // PeerExposeEnabled Enables or disables peer expose. If enabled, peers can expose local services through the reverse proxy using the CLI. PeerExposeEnabled bool `json:"peer_expose_enabled"` @@ -3141,6 +3150,9 @@ type Peer struct { // Ip Peer's IP address Ip string `json:"ip"` + // Ipv6 Peer's IPv6 overlay address + Ipv6 *string `json:"ipv6,omitempty"` + // KernelVersion Peer's operating system kernel version KernelVersion string `json:"kernel_version"` @@ -3232,6 +3244,9 @@ type PeerBatch struct { // Ip Peer's IP address Ip string `json:"ip"` + // Ipv6 Peer's IPv6 overlay address + Ipv6 *string `json:"ipv6,omitempty"` + // KernelVersion Peer's operating system kernel version KernelVersion string `json:"kernel_version"` @@ -3331,7 +3346,10 @@ type PeerRequest struct { InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` // Ip Peer's IP address - Ip *string `json:"ip,omitempty"` + Ip *string `json:"ip,omitempty"` + + // Ipv6 Peer's IPv6 overlay address. Omitted if IPv6 is not enabled for the account. 
+ Ipv6 *string `json:"ipv6,omitempty"` LoginExpirationEnabled bool `json:"login_expiration_enabled"` Name string `json:"name"` SshEnabled bool `json:"ssh_enabled"` diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 604f9c793..13f4fbc8d 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -71,6 +71,59 @@ func (JobStatus) EnumDescriptor() ([]byte, []int) { return file_management_proto_rawDescGZIP(), []int{0} } +// PeerCapability represents a feature the client binary supports. +// Reported in PeerSystemMeta.capabilities on every login/sync. +type PeerCapability int32 + +const ( + PeerCapability_PeerCapabilityUnknown PeerCapability = 0 + // Client reads SourcePrefixes instead of the deprecated PeerIP string. + PeerCapability_PeerCapabilitySourcePrefixes PeerCapability = 1 + // Client handles IPv6 overlay addresses and firewall rules. + PeerCapability_PeerCapabilityIPv6Overlay PeerCapability = 2 +) + +// Enum value maps for PeerCapability. +var ( + PeerCapability_name = map[int32]string{ + 0: "PeerCapabilityUnknown", + 1: "PeerCapabilitySourcePrefixes", + 2: "PeerCapabilityIPv6Overlay", + } + PeerCapability_value = map[string]int32{ + "PeerCapabilityUnknown": 0, + "PeerCapabilitySourcePrefixes": 1, + "PeerCapabilityIPv6Overlay": 2, + } +) + +func (x PeerCapability) Enum() *PeerCapability { + p := new(PeerCapability) + *p = x + return p +} + +func (x PeerCapability) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (PeerCapability) Descriptor() protoreflect.EnumDescriptor { + return file_management_proto_enumTypes[1].Descriptor() +} + +func (PeerCapability) Type() protoreflect.EnumType { + return &file_management_proto_enumTypes[1] +} + +func (x PeerCapability) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use PeerCapability.Descriptor instead. 
+func (PeerCapability) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{1} +} + type RuleProtocol int32 const ( @@ -113,11 +166,11 @@ func (x RuleProtocol) String() string { } func (RuleProtocol) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[1].Descriptor() + return file_management_proto_enumTypes[2].Descriptor() } func (RuleProtocol) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[1] + return &file_management_proto_enumTypes[2] } func (x RuleProtocol) Number() protoreflect.EnumNumber { @@ -126,7 +179,7 @@ func (x RuleProtocol) Number() protoreflect.EnumNumber { // Deprecated: Use RuleProtocol.Descriptor instead. func (RuleProtocol) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{1} + return file_management_proto_rawDescGZIP(), []int{2} } type RuleDirection int32 @@ -159,11 +212,11 @@ func (x RuleDirection) String() string { } func (RuleDirection) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[2].Descriptor() + return file_management_proto_enumTypes[3].Descriptor() } func (RuleDirection) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[2] + return &file_management_proto_enumTypes[3] } func (x RuleDirection) Number() protoreflect.EnumNumber { @@ -172,7 +225,7 @@ func (x RuleDirection) Number() protoreflect.EnumNumber { // Deprecated: Use RuleDirection.Descriptor instead. 
func (RuleDirection) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{2} + return file_management_proto_rawDescGZIP(), []int{3} } type RuleAction int32 @@ -205,11 +258,11 @@ func (x RuleAction) String() string { } func (RuleAction) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[3].Descriptor() + return file_management_proto_enumTypes[4].Descriptor() } func (RuleAction) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[3] + return &file_management_proto_enumTypes[4] } func (x RuleAction) Number() protoreflect.EnumNumber { @@ -218,7 +271,7 @@ func (x RuleAction) Number() protoreflect.EnumNumber { // Deprecated: Use RuleAction.Descriptor instead. func (RuleAction) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{3} + return file_management_proto_rawDescGZIP(), []int{4} } type ExposeProtocol int32 @@ -260,11 +313,11 @@ func (x ExposeProtocol) String() string { } func (ExposeProtocol) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[4].Descriptor() + return file_management_proto_enumTypes[5].Descriptor() } func (ExposeProtocol) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[4] + return &file_management_proto_enumTypes[5] } func (x ExposeProtocol) Number() protoreflect.EnumNumber { @@ -273,7 +326,7 @@ func (x ExposeProtocol) Number() protoreflect.EnumNumber { // Deprecated: Use ExposeProtocol.Descriptor instead. 
func (ExposeProtocol) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{4} + return file_management_proto_rawDescGZIP(), []int{5} } type HostConfig_Protocol int32 @@ -315,11 +368,11 @@ func (x HostConfig_Protocol) String() string { } func (HostConfig_Protocol) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[5].Descriptor() + return file_management_proto_enumTypes[6].Descriptor() } func (HostConfig_Protocol) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[5] + return &file_management_proto_enumTypes[6] } func (x HostConfig_Protocol) Number() protoreflect.EnumNumber { @@ -358,11 +411,11 @@ func (x DeviceAuthorizationFlowProvider) String() string { } func (DeviceAuthorizationFlowProvider) Descriptor() protoreflect.EnumDescriptor { - return file_management_proto_enumTypes[6].Descriptor() + return file_management_proto_enumTypes[7].Descriptor() } func (DeviceAuthorizationFlowProvider) Type() protoreflect.EnumType { - return &file_management_proto_enumTypes[6] + return &file_management_proto_enumTypes[7] } func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { @@ -1201,6 +1254,7 @@ type Flags struct { EnableSSHLocalPortForwarding bool `protobuf:"varint,13,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"` EnableSSHRemotePortForwarding bool `protobuf:"varint,14,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"` DisableSSHAuth bool `protobuf:"varint,15,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"` + DisableIPv6 bool `protobuf:"varint,16,opt,name=disableIPv6,proto3" json:"disableIPv6,omitempty"` } func (x *Flags) Reset() { @@ -1340,6 +1394,13 @@ func (x *Flags) GetDisableSSHAuth() bool { return false } +func (x *Flags) GetDisableIPv6() bool { + if x != nil { + return x.DisableIPv6 + } + return false +} + // PeerSystemMeta is machine meta data like OS and 
version. type PeerSystemMeta struct { state protoimpl.MessageState @@ -1363,6 +1424,7 @@ type PeerSystemMeta struct { Environment *Environment `protobuf:"bytes,15,opt,name=environment,proto3" json:"environment,omitempty"` Files []*File `protobuf:"bytes,16,rep,name=files,proto3" json:"files,omitempty"` Flags *Flags `protobuf:"bytes,17,opt,name=flags,proto3" json:"flags,omitempty"` + Capabilities []PeerCapability `protobuf:"varint,18,rep,packed,name=capabilities,proto3,enum=management.PeerCapability" json:"capabilities,omitempty"` } func (x *PeerSystemMeta) Reset() { @@ -1516,6 +1578,13 @@ func (x *PeerSystemMeta) GetFlags() *Flags { return nil } +func (x *PeerSystemMeta) GetCapabilities() []PeerCapability { + if x != nil { + return x.Capabilities + } + return nil +} + type LoginResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -2163,6 +2232,8 @@ type PeerConfig struct { Mtu int32 `protobuf:"varint,7,opt,name=mtu,proto3" json:"mtu,omitempty"` // Auto-update config AutoUpdate *AutoUpdateSettings `protobuf:"bytes,8,opt,name=autoUpdate,proto3" json:"autoUpdate,omitempty"` + // IPv6 overlay address as compact bytes: 16 bytes IP + 1 byte prefix length. + AddressV6 []byte `protobuf:"bytes,9,opt,name=address_v6,json=addressV6,proto3" json:"address_v6,omitempty"` } func (x *PeerConfig) Reset() { @@ -2253,6 +2324,13 @@ func (x *PeerConfig) GetAutoUpdate() *AutoUpdateSettings { return nil } +func (x *PeerConfig) GetAddressV6() []byte { + if x != nil { + return x.AddressV6 + } + return nil +} + type AutoUpdateSettings struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -3562,6 +3640,9 @@ type FirewallRule struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + // Use sourcePrefixes instead. + // + // Deprecated: Do not use. 
PeerIP string `protobuf:"bytes,1,opt,name=PeerIP,proto3" json:"PeerIP,omitempty"` Direction RuleDirection `protobuf:"varint,2,opt,name=Direction,proto3,enum=management.RuleDirection" json:"Direction,omitempty"` Action RuleAction `protobuf:"varint,3,opt,name=Action,proto3,enum=management.RuleAction" json:"Action,omitempty"` @@ -3570,6 +3651,11 @@ type FirewallRule struct { PortInfo *PortInfo `protobuf:"bytes,6,opt,name=PortInfo,proto3" json:"PortInfo,omitempty"` // PolicyID is the ID of the policy that this rule belongs to PolicyID []byte `protobuf:"bytes,7,opt,name=PolicyID,proto3" json:"PolicyID,omitempty"` + // CustomProtocol is a custom protocol ID when Protocol is CUSTOM. + CustomProtocol uint32 `protobuf:"varint,8,opt,name=customProtocol,proto3" json:"customProtocol,omitempty"` + // Compact source IP prefixes for this rule, supersedes PeerIP. + // Each entry is 5 bytes (v4) or 17 bytes (v6): [IP bytes][1 byte prefix_len]. + SourcePrefixes [][]byte `protobuf:"bytes,9,rep,name=sourcePrefixes,proto3" json:"sourcePrefixes,omitempty"` } func (x *FirewallRule) Reset() { @@ -3604,6 +3690,7 @@ func (*FirewallRule) Descriptor() ([]byte, []int) { return file_management_proto_rawDescGZIP(), []int{41} } +// Deprecated: Do not use. 
func (x *FirewallRule) GetPeerIP() string { if x != nil { return x.PeerIP @@ -3653,6 +3740,20 @@ func (x *FirewallRule) GetPolicyID() []byte { return nil } +func (x *FirewallRule) GetCustomProtocol() uint32 { + if x != nil { + return x.CustomProtocol + } + return 0 +} + +func (x *FirewallRule) GetSourcePrefixes() [][]byte { + if x != nil { + return x.SourcePrefixes + } + return nil +} + type NetworkAddress struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -4542,7 +4643,7 @@ var file_management_proto_rawDesc = []byte{ 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x78, 0x69, 0x73, 0x74, 0x12, 0x2a, 0x0a, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x49, 0x73, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x49, 0x73, 0x52, - 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xbf, 0x05, 0x0a, 0x05, 0x46, 0x6c, 0x61, 0x67, 0x73, + 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x22, 0xe1, 0x05, 0x0a, 0x05, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, @@ -4586,551 +4687,571 @@ var file_management_proto_rawDesc = []byte{ 0x74, 0x65, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x26, 0x0a, 0x0e, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, - 0x65, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x22, 0xf2, 0x04, 0x0a, 0x0e, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1a, 0x0a, 0x08, 0x68, - 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, - 0x6f, 0x73, 0x74, 0x6e, 0x61, 
0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x67, 0x6f, 0x4f, 0x53, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, 0x4f, 0x53, 0x12, 0x16, 0x0a, 0x06, 0x6b, - 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6b, 0x65, 0x72, - 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, - 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, - 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x02, 0x4f, 0x53, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, - 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x74, - 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x75, - 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, - 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x6b, 0x65, 0x72, - 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, - 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x0a, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x46, 0x0a, - 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, - 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, 
0x73, 0x53, 0x65, 0x72, 0x69, - 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, - 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x12, - 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, - 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, - 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, - 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, - 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, - 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x52, - 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x26, 0x0a, 0x05, - 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x05, 0x66, - 0x69, 0x6c, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x05, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x18, 0x11, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x52, 0x05, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x22, 0xb4, 0x01, - 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 
0x52, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2a, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, - 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x52, 0x06, 0x43, 0x68, - 0x65, 0x63, 0x6b, 0x73, 0x22, 0x79, 0x0a, 0x11, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, - 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x38, 0x0a, 0x09, 0x65, - 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, - 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, - 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, 0x70, 0x69, - 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, - 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xff, 0x01, 0x0a, 0x0d, 0x4e, 0x65, 0x74, - 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2c, 0x0a, 0x05, 0x73, 0x74, - 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, 0x35, 0x0a, 0x05, 0x74, 0x75, 0x72, 0x6e, - 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 
0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, - 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x74, 0x75, 0x72, 0x6e, 0x73, 0x12, - 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, - 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x12, - 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6c, 0x61, - 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x2a, - 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x22, 0x98, 0x01, 0x0a, 0x0a, 0x48, - 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x3b, 0x0a, 0x08, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, - 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x02, - 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, 0x12, 0x08, 
0x0a, 0x04, 0x44, - 0x54, 0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, - 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, - 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, - 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, - 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, - 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, 0x6f, 0x6b, - 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, 0x6f, 0x6b, - 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, - 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, - 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, - 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x18, 0x06, - 0x20, 0x01, 0x28, 
0x08, 0x52, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, 0x12, 0x2e, - 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x65, 0x78, 0x69, 0x74, - 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x24, - 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x22, 0xa3, 0x01, 0x0a, 0x09, 0x4a, 0x57, 0x54, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x75, - 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x75, - 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x6b, 0x65, 0x79, 0x73, 0x4c, 0x6f, - 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x6b, 0x65, - 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x6d, 0x61, - 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x0b, 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x12, 0x1c, 0x0a, 0x09, - 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, - 0x09, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x22, 0x7d, 0x0a, 0x13, 0x50, 0x72, - 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 
0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x68, - 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, - 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, - 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0xd3, 0x02, 0x0a, 0x0a, 0x50, 0x65, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, - 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, - 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, - 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, - 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, - 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, - 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, 0x6e, 0x6e, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x10, 
0x0a, - 0x03, 0x6d, 0x74, 0x75, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, 0x75, 0x12, - 0x3e, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x08, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x74, 0x74, 0x69, - 0x6e, 0x67, 0x73, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x22, - 0x52, 0x0a, 0x12, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x74, - 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, - 0x22, 0x0a, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, - 0x61, 0x74, 0x65, 0x22, 0xe8, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, - 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, - 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, - 0x72, 0x73, 0x12, 0x2e, 0x0a, 
0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, - 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, - 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, - 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, - 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, - 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, - 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x69, 0x73, 0x61, + 0x62, 0x6c, 0x65, 0x49, 0x50, 0x76, 0x36, 0x18, 0x10, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x64, + 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x49, 0x50, 0x76, 0x36, 0x22, 0xb2, 0x05, 0x0a, 0x0e, 0x50, + 0x65, 0x65, 0x72, 0x53, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1a, 0x0a, + 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x67, 0x6f, 0x4f, + 0x53, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x67, 0x6f, 0x4f, 0x53, 0x12, 0x16, 0x0a, + 0x06, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6b, + 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x1a, 
0x0a, 0x08, 0x70, 0x6c, 0x61, + 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x6c, 0x61, + 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x0e, 0x0a, 0x02, 0x4f, 0x53, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x02, 0x4f, 0x53, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, + 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, + 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, + 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x09, 0x75, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0d, 0x6b, + 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x09, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0d, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x0a, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4f, 0x53, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, + 0x46, 0x0a, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x65, 0x73, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x10, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, + 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0f, 0x73, 0x79, 0x73, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x4e, 0x75, 0x6d, 0x62, 0x65, + 0x72, 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, 0x6f, 0x64, 0x75, 0x63, 0x74, 0x4e, + 0x61, 0x6d, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x79, 0x73, 0x50, 0x72, + 0x6f, 
0x64, 0x75, 0x63, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x28, 0x0a, 0x0f, 0x73, 0x79, 0x73, + 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, 0x72, 0x65, 0x72, 0x18, 0x0e, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0f, 0x73, 0x79, 0x73, 0x4d, 0x61, 0x6e, 0x75, 0x66, 0x61, 0x63, 0x74, 0x75, + 0x72, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, + 0x6e, 0x74, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, + 0x74, 0x52, 0x0b, 0x65, 0x6e, 0x76, 0x69, 0x72, 0x6f, 0x6e, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x26, + 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x10, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x6c, 0x65, 0x52, + 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x05, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x18, + 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x52, 0x05, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x12, + 0x3e, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x18, + 0x12, 0x20, 0x03, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, + 0x79, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x22, + 0xb4, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x3f, 0x0a, 0x0d, 0x6e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x0d, 0x6e, 0x65, 
0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, + 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2a, 0x0a, 0x06, 0x43, 0x68, + 0x65, 0x63, 0x6b, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x52, 0x06, + 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x22, 0x79, 0x0a, 0x11, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, + 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x38, 0x0a, + 0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x65, 0x78, + 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x22, 0x07, 0x0a, 0x05, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0xff, 0x01, 0x0a, 0x0d, 0x4e, + 0x65, 0x74, 0x62, 0x69, 0x72, 0x64, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2c, 0x0a, 0x05, + 0x73, 0x74, 0x75, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x05, 0x73, 0x74, 0x75, 0x6e, 0x73, 0x12, 0x35, 0x0a, 0x05, 0x74, 0x75, + 0x72, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x6d, 
0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, + 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x74, 0x75, 0x72, 0x6e, + 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, + 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x06, 0x73, 0x69, 0x67, 0x6e, 0x61, + 0x6c, 0x12, 0x2d, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, + 0x6c, 0x61, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, + 0x12, 0x2a, 0x0a, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x6c, 0x6f, 0x77, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x6f, 0x77, 0x22, 0x98, 0x01, 0x0a, + 0x0a, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, + 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x3b, 0x0a, + 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, + 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x3b, 0x0a, 0x08, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x00, 0x12, + 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, + 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x03, 0x12, 0x08, 0x0a, + 0x04, 0x44, 0x54, 
0x4c, 0x53, 0x10, 0x04, 0x22, 0x6d, 0x0a, 0x0b, 0x52, 0x65, 0x6c, 0x61, 0x79, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x75, 0x72, 0x6c, 0x73, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, + 0x0a, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, + 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0xad, 0x02, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x22, 0x0a, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x74, + 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x74, + 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0e, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, + 0x75, 0x72, 0x65, 0x12, 0x35, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x63, 
0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x73, + 0x12, 0x2e, 0x0a, 0x12, 0x65, 0x78, 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x65, 0x78, + 0x69, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x64, 0x6e, 0x73, 0x43, 0x6f, 0x6c, 0x6c, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0xa3, 0x01, 0x0a, 0x09, 0x4a, 0x57, 0x54, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x16, 0x0a, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x69, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, + 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x6b, 0x65, 0x79, 0x73, + 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, + 0x6b, 0x65, 0x79, 0x73, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, + 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x0b, 0x6d, 0x61, 0x78, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x41, 0x67, 0x65, 0x12, 0x1c, + 0x0a, 0x09, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x09, 0x61, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x73, 0x22, 0x7d, 0x0a, 0x13, + 0x50, 0x72, 0x6f, 0x74, 0x65, 0x63, 0x74, 0x65, 0x64, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x12, 0x36, 0x0a, 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 
0x52, + 0x0a, 0x68, 0x6f, 0x73, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, + 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, + 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0xf2, 0x02, 0x0a, 0x0a, + 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x03, 0x64, 0x6e, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, + 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, + 0x48, 0x0a, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, + 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, + 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, + 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x4c, 0x61, 0x7a, + 0x79, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x15, 0x4c, 0x61, 0x7a, 0x79, 0x43, 0x6f, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, + 0x10, 0x0a, 0x03, 0x6d, 0x74, 
0x75, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x74, + 0x75, 0x12, 0x3e, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, + 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x74, + 0x74, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, + 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x76, 0x36, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x56, 0x36, + 0x22, 0x52, 0x0a, 0x12, 0x41, 0x75, 0x74, 0x6f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, + 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x12, 0x22, 0x0a, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x61, 0x6c, 0x77, 0x61, 0x79, 0x73, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x22, 0xe8, 0x05, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, + 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 
0x50, 0x65, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, - 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, - 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, - 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, - 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, - 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, - 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, - 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, - 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, + 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 
0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, + 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, + 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, + 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, + 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, - 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, 0x72, - 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 
0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0f, - 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, - 0x2d, 0x0a, 0x07, 0x73, 0x73, 0x68, 0x41, 0x75, 0x74, 0x68, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x13, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, - 0x48, 0x41, 0x75, 0x74, 0x68, 0x52, 0x07, 0x73, 0x73, 0x68, 0x41, 0x75, 0x74, 0x68, 0x22, 0x82, - 0x02, 0x0a, 0x07, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x12, 0x20, 0x0a, 0x0b, 0x55, 0x73, - 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0b, 0x55, 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x12, 0x28, 0x0a, 0x0f, - 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, - 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x12, 0x4a, 0x0a, 0x0d, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, - 0x65, 0x5f, 0x75, 0x73, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x41, 0x75, - 0x74, 0x68, 0x2e, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, 0x45, - 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, - 0x72, 0x73, 0x1a, 0x5f, 0x0a, 0x11, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, - 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x34, 0x0a, 0x05, 0x76, 0x61, 0x6c, - 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 
0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, - 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, - 0x02, 0x38, 0x01, 0x22, 0x2e, 0x0a, 0x12, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, - 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x69, 0x6e, 0x64, - 0x65, 0x78, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x69, 0x6e, 0x64, 0x65, - 0x78, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, - 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, - 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, - 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, - 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, - 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, 0x0a, - 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, - 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, - 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, - 0x0a, 0x09, 0x73, 0x73, 0x68, 
0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, 0x09, - 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, 0x54, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, - 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, - 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, - 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, - 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, - 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, - 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, - 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, - 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, - 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 
0x6c, 0x6f, 0x77, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, - 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x22, 0xbc, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, - 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, - 0x44, 0x12, 0x26, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, - 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0c, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, - 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, - 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 
0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, - 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, - 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, - 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, - 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, - 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, - 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, 0x67, - 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, - 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, - 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 
0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, - 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, - 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, - 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, - 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, - 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, - 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, - 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 
0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x28, - 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, - 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb8, 0x01, 0x0a, 0x0a, 0x43, 0x75, 0x73, - 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, - 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, - 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, - 0x72, 0x64, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x44, - 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x4e, 0x6f, 0x6e, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, 0x74, 0x69, 0x76, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x10, 0x4e, 0x6f, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, 0x74, - 0x69, 0x76, 0x65, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, - 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, - 0x6c, 0x61, 0x73, 
0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, - 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, - 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, - 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, - 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, - 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, - 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, - 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, - 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, - 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, - 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, - 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, - 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 
0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, - 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, - 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, - 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, - 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, - 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, - 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, - 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 
0x6d, - 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, - 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, - 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, - 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, - 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, - 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, - 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, - 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, - 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, - 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, - 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, - 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x09, 
0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, - 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, - 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, - 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, - 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, - 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, - 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, - 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, - 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 
0x6f, 0x63, 0x6f, 0x6c, 0x52, - 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, - 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, - 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, - 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, - 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, - 0x64, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x8b, 0x02, 0x0a, 0x14, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, - 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x6f, - 0x72, 0x74, 0x12, 0x36, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, - 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, - 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, - 0x70, 
0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x75, 0x73, 0x65, 0x72, - 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x75, - 0x73, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6e, 0x61, 0x6d, 0x65, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, - 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, 0x65, 0x50, 0x72, 0x65, 0x66, - 0x69, 0x78, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x5f, 0x70, 0x6f, 0x72, - 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, - 0x6f, 0x72, 0x74, 0x22, 0xa1, 0x01, 0x0a, 0x15, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, - 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, - 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x75, 0x72, 0x6c, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x55, 0x72, - 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2c, 0x0a, 0x12, 0x70, 0x6f, 0x72, - 0x74, 0x5f, 0x61, 0x75, 0x74, 0x6f, 0x5f, 0x61, 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x6f, 0x72, 0x74, 0x41, 0x75, 0x74, 0x6f, 0x41, - 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x22, 0x2c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x65, 0x77, - 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, - 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x15, 0x0a, 0x13, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, - 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x2b, 0x0a, 0x11, - 0x53, 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x14, 0x0a, 0x12, 0x53, 0x74, 0x6f, - 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, - 0x3a, 0x0a, 0x09, 0x4a, 0x6f, 0x62, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x12, 0x0a, 0x0e, - 0x75, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x10, 0x00, - 0x12, 0x0d, 0x0a, 0x09, 0x73, 0x75, 0x63, 0x63, 0x65, 0x65, 0x64, 0x65, 0x64, 0x10, 0x01, 0x12, - 0x0a, 0x0a, 0x06, 0x66, 0x61, 0x69, 0x6c, 0x65, 0x64, 0x10, 0x02, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, - 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, - 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, - 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, - 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, - 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, - 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, - 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, - 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, - 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x2a, - 0x63, 0x0a, 0x0e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, 
0x6f, 0x63, 0x6f, - 0x6c, 0x12, 0x0f, 0x0a, 0x0b, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, 0x50, - 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, 0x54, 0x54, - 0x50, 0x53, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, - 0x43, 0x50, 0x10, 0x02, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x55, - 0x44, 0x50, 0x10, 0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x54, - 0x4c, 0x53, 0x10, 0x04, 0x32, 0xfd, 0x06, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x09, 0x69, 0x73, 
0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, + 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, + 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, + 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, + 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x44, 0x0a, 0x0f, 0x66, 0x6f, + 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0c, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x52, + 0x0f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x73, + 0x12, 0x2d, 0x0a, 0x07, 0x73, 0x73, 0x68, 0x41, 0x75, 0x74, 0x68, 0x18, 0x0d, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x13, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, + 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x52, 0x07, 0x73, 0x73, 0x68, 0x41, 0x75, 0x74, 0x68, 0x22, + 0x82, 0x02, 0x0a, 0x07, 0x53, 0x53, 0x48, 0x41, 0x75, 0x74, 0x68, 0x12, 0x20, 0x0a, 0x0b, 0x55, + 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, 0x61, 0x69, 
0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0b, 0x55, 0x73, 0x65, 0x72, 0x49, 0x44, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x12, 0x28, 0x0a, + 0x0f, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, + 0x18, 0x02, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x65, 0x64, 0x55, 0x73, 0x65, 0x72, 0x73, 0x12, 0x4a, 0x0a, 0x0d, 0x6d, 0x61, 0x63, 0x68, 0x69, + 0x6e, 0x65, 0x5f, 0x75, 0x73, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x41, + 0x75, 0x74, 0x68, 0x2e, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, 0x65, 0x72, 0x73, + 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0c, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, + 0x65, 0x72, 0x73, 0x1a, 0x5f, 0x0a, 0x11, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, + 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x34, 0x0a, 0x05, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, 0x73, + 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x3a, 0x02, 0x38, 0x01, 0x22, 0x2e, 0x0a, 0x12, 0x4d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x55, + 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x65, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x69, 0x6e, + 0x64, 0x65, 0x78, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x69, 0x6e, 0x64, + 0x65, 0x78, 0x65, 0x73, 0x22, 0xbb, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, + 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 
0x50, + 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, + 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, + 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, + 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, + 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x22, + 0x0a, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x22, 0x7e, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, + 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x33, 0x0a, + 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4a, 0x57, + 0x54, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x6a, 0x77, 0x74, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 
0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, - 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, - 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, - 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, - 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, - 0x74, 0x79, 0x22, 0x00, 0x12, 0x47, 0x0a, 0x03, 0x4a, 0x6f, 0x62, 0x12, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 
0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x12, 0x4c, 0x0a, - 0x0c, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0b, 0x52, - 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4a, 0x0a, 0x0a, 0x53, 0x74, 0x6f, 0x70, - 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 
0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, + 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, + 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, + 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, + 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x22, 0xbc, 0x03, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x49, 0x44, 0x12, 0x26, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, + 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0c, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, + 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 
0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, + 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, + 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, + 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, + 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, + 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x55, 0x52, 0x4c, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, + 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x12, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x4c, + 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, 0x61, + 0x67, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x46, 0x6c, + 0x61, 0x67, 0x22, 0x93, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, + 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 
0x18, 0x0a, 0x07, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, + 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, + 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, + 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, + 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, + 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, + 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xde, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, + 0x18, 0x02, 0x20, 
0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, + 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, + 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, + 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, + 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, + 0x28, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, + 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xb8, 0x01, 0x0a, 0x0a, 0x43, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, + 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, + 0x6f, 0x72, 0x64, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x4e, 0x6f, 0x6e, 0x41, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, 0x74, 0x69, 0x76, 0x65, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x10, 0x4e, 0x6f, 0x6e, 0x41, 0x75, 
0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x61, + 0x74, 0x69, 0x76, 0x65, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, + 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, + 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, + 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, + 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, + 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, + 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 
0x0e, + 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, + 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, + 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xfb, 0x02, 0x0a, 0x0c, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x1a, 0x0a, 0x06, 0x50, + 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x02, 0x18, 0x01, 0x52, + 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, + 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, + 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, + 0x66, 0x6f, 0x52, 0x08, 0x50, 
0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, + 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, + 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, + 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x26, 0x0a, 0x0e, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, + 0x65, 0x73, 0x18, 0x09, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0e, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, + 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x65, 0x73, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, + 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, + 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, + 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, + 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, + 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, + 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, + 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, + 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, + 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 
0x03, 0x65, 0x6e, 0x64, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, + 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, + 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, + 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, + 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, + 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, + 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, + 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, + 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, + 0x64, 
0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, + 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, + 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, + 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, + 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, + 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, + 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, + 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, + 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, + 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, + 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, + 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 
0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, + 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x8b, 0x02, 0x0a, 0x14, 0x45, + 0x78, 0x70, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0d, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x36, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, + 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, + 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x1f, 0x0a, + 0x0b, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x05, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x16, + 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, + 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x6e, 0x61, 0x6d, 0x65, 0x5f, 0x70, + 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x61, 0x6d, + 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x69, 0x73, 0x74, 0x65, + 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x6c, 0x69, + 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa1, 0x01, 0x0a, 0x15, 0x45, 0x78, 0x70, + 0x6f, 0x73, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 
0x5f, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x55, 0x72, 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2c, + 0x0a, 0x12, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x61, 0x75, 0x74, 0x6f, 0x5f, 0x61, 0x73, 0x73, 0x69, + 0x67, 0x6e, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x70, 0x6f, 0x72, 0x74, + 0x41, 0x75, 0x74, 0x6f, 0x41, 0x73, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x22, 0x2c, 0x0a, 0x12, + 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x15, 0x0a, 0x13, 0x52, 0x65, + 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x2b, 0x0a, 0x11, 0x53, 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x14, + 0x0a, 0x12, 0x53, 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x3a, 0x0a, 0x09, 0x4a, 0x6f, 0x62, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x12, 0x12, 0x0a, 0x0e, 0x75, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x5f, 0x73, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x73, 0x75, 0x63, 0x63, 0x65, 0x65, 0x64, + 0x65, 0x64, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x66, 0x61, 0x69, 0x6c, 0x65, 0x64, 0x10, 0x02, + 0x2a, 0x6c, 0x0a, 
0x0e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, + 0x74, 0x79, 0x12, 0x19, 0x0a, 0x15, 0x50, 0x65, 0x65, 0x72, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, + 0x6c, 0x69, 0x74, 0x79, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x10, 0x00, 0x12, 0x20, 0x0a, + 0x1c, 0x50, 0x65, 0x65, 0x72, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x53, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x65, 0x73, 0x10, 0x01, 0x12, + 0x1d, 0x0a, 0x19, 0x50, 0x65, 0x65, 0x72, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, + 0x79, 0x49, 0x50, 0x76, 0x36, 0x4f, 0x76, 0x65, 0x72, 0x6c, 0x61, 0x79, 0x10, 0x02, 0x2a, 0x4c, + 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, + 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, + 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, + 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, + 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, + 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, + 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, + 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, + 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, + 0x10, 0x01, 0x2a, 0x63, 0x0a, 0x0e, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0f, 0x0a, 0x0b, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, 0x48, + 0x54, 0x54, 0x50, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x45, 0x58, 0x50, 0x4f, 0x53, 0x45, 0x5f, + 0x48, 0x54, 0x54, 0x50, 0x53, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, + 0x45, 0x5f, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x0e, 
0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, + 0x45, 0x5f, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x0e, 0x0a, 0x0a, 0x45, 0x58, 0x50, 0x4f, 0x53, + 0x45, 0x5f, 0x54, 0x4c, 0x53, 0x10, 0x04, 0x32, 0xfd, 0x06, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, + 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, + 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, + 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 
0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, + 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, + 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, + 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, + 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, + 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x47, 0x0a, 0x03, 0x4a, 0x6f, 0x62, 0x12, 0x1c, + 
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, + 0x12, 0x4c, 0x0a, 0x0c, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, + 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4b, + 0x0a, 0x0b, 0x52, 0x65, 0x6e, 0x65, 0x77, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x4a, 0x0a, 0x0a, 0x53, + 0x74, 0x6f, 0x70, 0x45, 0x78, 0x70, 0x6f, 0x73, 0x65, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -5145,166 +5266,168 @@ func file_management_proto_rawDescGZIP() []byte { 
return file_management_proto_rawDescData } -var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 7) +var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 8) var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 55) var file_management_proto_goTypes = []interface{}{ (JobStatus)(0), // 0: management.JobStatus - (RuleProtocol)(0), // 1: management.RuleProtocol - (RuleDirection)(0), // 2: management.RuleDirection - (RuleAction)(0), // 3: management.RuleAction - (ExposeProtocol)(0), // 4: management.ExposeProtocol - (HostConfig_Protocol)(0), // 5: management.HostConfig.Protocol - (DeviceAuthorizationFlowProvider)(0), // 6: management.DeviceAuthorizationFlow.provider - (*EncryptedMessage)(nil), // 7: management.EncryptedMessage - (*JobRequest)(nil), // 8: management.JobRequest - (*JobResponse)(nil), // 9: management.JobResponse - (*BundleParameters)(nil), // 10: management.BundleParameters - (*BundleResult)(nil), // 11: management.BundleResult - (*SyncRequest)(nil), // 12: management.SyncRequest - (*SyncResponse)(nil), // 13: management.SyncResponse - (*SyncMetaRequest)(nil), // 14: management.SyncMetaRequest - (*LoginRequest)(nil), // 15: management.LoginRequest - (*PeerKeys)(nil), // 16: management.PeerKeys - (*Environment)(nil), // 17: management.Environment - (*File)(nil), // 18: management.File - (*Flags)(nil), // 19: management.Flags - (*PeerSystemMeta)(nil), // 20: management.PeerSystemMeta - (*LoginResponse)(nil), // 21: management.LoginResponse - (*ServerKeyResponse)(nil), // 22: management.ServerKeyResponse - (*Empty)(nil), // 23: management.Empty - (*NetbirdConfig)(nil), // 24: management.NetbirdConfig - (*HostConfig)(nil), // 25: management.HostConfig - (*RelayConfig)(nil), // 26: management.RelayConfig - (*FlowConfig)(nil), // 27: management.FlowConfig - (*JWTConfig)(nil), // 28: management.JWTConfig - (*ProtectedHostConfig)(nil), // 29: management.ProtectedHostConfig - (*PeerConfig)(nil), // 30: management.PeerConfig - 
(*AutoUpdateSettings)(nil), // 31: management.AutoUpdateSettings - (*NetworkMap)(nil), // 32: management.NetworkMap - (*SSHAuth)(nil), // 33: management.SSHAuth - (*MachineUserIndexes)(nil), // 34: management.MachineUserIndexes - (*RemotePeerConfig)(nil), // 35: management.RemotePeerConfig - (*SSHConfig)(nil), // 36: management.SSHConfig - (*DeviceAuthorizationFlowRequest)(nil), // 37: management.DeviceAuthorizationFlowRequest - (*DeviceAuthorizationFlow)(nil), // 38: management.DeviceAuthorizationFlow - (*PKCEAuthorizationFlowRequest)(nil), // 39: management.PKCEAuthorizationFlowRequest - (*PKCEAuthorizationFlow)(nil), // 40: management.PKCEAuthorizationFlow - (*ProviderConfig)(nil), // 41: management.ProviderConfig - (*Route)(nil), // 42: management.Route - (*DNSConfig)(nil), // 43: management.DNSConfig - (*CustomZone)(nil), // 44: management.CustomZone - (*SimpleRecord)(nil), // 45: management.SimpleRecord - (*NameServerGroup)(nil), // 46: management.NameServerGroup - (*NameServer)(nil), // 47: management.NameServer - (*FirewallRule)(nil), // 48: management.FirewallRule - (*NetworkAddress)(nil), // 49: management.NetworkAddress - (*Checks)(nil), // 50: management.Checks - (*PortInfo)(nil), // 51: management.PortInfo - (*RouteFirewallRule)(nil), // 52: management.RouteFirewallRule - (*ForwardingRule)(nil), // 53: management.ForwardingRule - (*ExposeServiceRequest)(nil), // 54: management.ExposeServiceRequest - (*ExposeServiceResponse)(nil), // 55: management.ExposeServiceResponse - (*RenewExposeRequest)(nil), // 56: management.RenewExposeRequest - (*RenewExposeResponse)(nil), // 57: management.RenewExposeResponse - (*StopExposeRequest)(nil), // 58: management.StopExposeRequest - (*StopExposeResponse)(nil), // 59: management.StopExposeResponse - nil, // 60: management.SSHAuth.MachineUsersEntry - (*PortInfo_Range)(nil), // 61: management.PortInfo.Range - (*timestamppb.Timestamp)(nil), // 62: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 63: 
google.protobuf.Duration + (PeerCapability)(0), // 1: management.PeerCapability + (RuleProtocol)(0), // 2: management.RuleProtocol + (RuleDirection)(0), // 3: management.RuleDirection + (RuleAction)(0), // 4: management.RuleAction + (ExposeProtocol)(0), // 5: management.ExposeProtocol + (HostConfig_Protocol)(0), // 6: management.HostConfig.Protocol + (DeviceAuthorizationFlowProvider)(0), // 7: management.DeviceAuthorizationFlow.provider + (*EncryptedMessage)(nil), // 8: management.EncryptedMessage + (*JobRequest)(nil), // 9: management.JobRequest + (*JobResponse)(nil), // 10: management.JobResponse + (*BundleParameters)(nil), // 11: management.BundleParameters + (*BundleResult)(nil), // 12: management.BundleResult + (*SyncRequest)(nil), // 13: management.SyncRequest + (*SyncResponse)(nil), // 14: management.SyncResponse + (*SyncMetaRequest)(nil), // 15: management.SyncMetaRequest + (*LoginRequest)(nil), // 16: management.LoginRequest + (*PeerKeys)(nil), // 17: management.PeerKeys + (*Environment)(nil), // 18: management.Environment + (*File)(nil), // 19: management.File + (*Flags)(nil), // 20: management.Flags + (*PeerSystemMeta)(nil), // 21: management.PeerSystemMeta + (*LoginResponse)(nil), // 22: management.LoginResponse + (*ServerKeyResponse)(nil), // 23: management.ServerKeyResponse + (*Empty)(nil), // 24: management.Empty + (*NetbirdConfig)(nil), // 25: management.NetbirdConfig + (*HostConfig)(nil), // 26: management.HostConfig + (*RelayConfig)(nil), // 27: management.RelayConfig + (*FlowConfig)(nil), // 28: management.FlowConfig + (*JWTConfig)(nil), // 29: management.JWTConfig + (*ProtectedHostConfig)(nil), // 30: management.ProtectedHostConfig + (*PeerConfig)(nil), // 31: management.PeerConfig + (*AutoUpdateSettings)(nil), // 32: management.AutoUpdateSettings + (*NetworkMap)(nil), // 33: management.NetworkMap + (*SSHAuth)(nil), // 34: management.SSHAuth + (*MachineUserIndexes)(nil), // 35: management.MachineUserIndexes + (*RemotePeerConfig)(nil), // 36: 
management.RemotePeerConfig + (*SSHConfig)(nil), // 37: management.SSHConfig + (*DeviceAuthorizationFlowRequest)(nil), // 38: management.DeviceAuthorizationFlowRequest + (*DeviceAuthorizationFlow)(nil), // 39: management.DeviceAuthorizationFlow + (*PKCEAuthorizationFlowRequest)(nil), // 40: management.PKCEAuthorizationFlowRequest + (*PKCEAuthorizationFlow)(nil), // 41: management.PKCEAuthorizationFlow + (*ProviderConfig)(nil), // 42: management.ProviderConfig + (*Route)(nil), // 43: management.Route + (*DNSConfig)(nil), // 44: management.DNSConfig + (*CustomZone)(nil), // 45: management.CustomZone + (*SimpleRecord)(nil), // 46: management.SimpleRecord + (*NameServerGroup)(nil), // 47: management.NameServerGroup + (*NameServer)(nil), // 48: management.NameServer + (*FirewallRule)(nil), // 49: management.FirewallRule + (*NetworkAddress)(nil), // 50: management.NetworkAddress + (*Checks)(nil), // 51: management.Checks + (*PortInfo)(nil), // 52: management.PortInfo + (*RouteFirewallRule)(nil), // 53: management.RouteFirewallRule + (*ForwardingRule)(nil), // 54: management.ForwardingRule + (*ExposeServiceRequest)(nil), // 55: management.ExposeServiceRequest + (*ExposeServiceResponse)(nil), // 56: management.ExposeServiceResponse + (*RenewExposeRequest)(nil), // 57: management.RenewExposeRequest + (*RenewExposeResponse)(nil), // 58: management.RenewExposeResponse + (*StopExposeRequest)(nil), // 59: management.StopExposeRequest + (*StopExposeResponse)(nil), // 60: management.StopExposeResponse + nil, // 61: management.SSHAuth.MachineUsersEntry + (*PortInfo_Range)(nil), // 62: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 63: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 64: google.protobuf.Duration } var file_management_proto_depIdxs = []int32{ - 10, // 0: management.JobRequest.bundle:type_name -> management.BundleParameters + 11, // 0: management.JobRequest.bundle:type_name -> management.BundleParameters 0, // 1: 
management.JobResponse.status:type_name -> management.JobStatus - 11, // 2: management.JobResponse.bundle:type_name -> management.BundleResult - 20, // 3: management.SyncRequest.meta:type_name -> management.PeerSystemMeta - 24, // 4: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig - 30, // 5: management.SyncResponse.peerConfig:type_name -> management.PeerConfig - 35, // 6: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig - 32, // 7: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap - 50, // 8: management.SyncResponse.Checks:type_name -> management.Checks - 20, // 9: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta - 20, // 10: management.LoginRequest.meta:type_name -> management.PeerSystemMeta - 16, // 11: management.LoginRequest.peerKeys:type_name -> management.PeerKeys - 49, // 12: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress - 17, // 13: management.PeerSystemMeta.environment:type_name -> management.Environment - 18, // 14: management.PeerSystemMeta.files:type_name -> management.File - 19, // 15: management.PeerSystemMeta.flags:type_name -> management.Flags - 24, // 16: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig - 30, // 17: management.LoginResponse.peerConfig:type_name -> management.PeerConfig - 50, // 18: management.LoginResponse.Checks:type_name -> management.Checks - 62, // 19: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp - 25, // 20: management.NetbirdConfig.stuns:type_name -> management.HostConfig - 29, // 21: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig - 25, // 22: management.NetbirdConfig.signal:type_name -> management.HostConfig - 26, // 23: management.NetbirdConfig.relay:type_name -> management.RelayConfig - 27, // 24: management.NetbirdConfig.flow:type_name -> management.FlowConfig - 5, // 25: 
management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 63, // 26: management.FlowConfig.interval:type_name -> google.protobuf.Duration - 25, // 27: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 36, // 28: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 31, // 29: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings - 30, // 30: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 35, // 31: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 42, // 32: management.NetworkMap.Routes:type_name -> management.Route - 43, // 33: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 35, // 34: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 48, // 35: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 52, // 36: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule - 53, // 37: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule - 33, // 38: management.NetworkMap.sshAuth:type_name -> management.SSHAuth - 60, // 39: management.SSHAuth.machine_users:type_name -> management.SSHAuth.MachineUsersEntry - 36, // 40: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 28, // 41: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig - 6, // 42: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 41, // 43: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 41, // 44: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 46, // 45: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 44, // 46: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 45, // 47: management.CustomZone.Records:type_name -> management.SimpleRecord - 
47, // 48: management.NameServerGroup.NameServers:type_name -> management.NameServer - 2, // 49: management.FirewallRule.Direction:type_name -> management.RuleDirection - 3, // 50: management.FirewallRule.Action:type_name -> management.RuleAction - 1, // 51: management.FirewallRule.Protocol:type_name -> management.RuleProtocol - 51, // 52: management.FirewallRule.PortInfo:type_name -> management.PortInfo - 61, // 53: management.PortInfo.range:type_name -> management.PortInfo.Range - 3, // 54: management.RouteFirewallRule.action:type_name -> management.RuleAction - 1, // 55: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol - 51, // 56: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo - 1, // 57: management.ForwardingRule.protocol:type_name -> management.RuleProtocol - 51, // 58: management.ForwardingRule.destinationPort:type_name -> management.PortInfo - 51, // 59: management.ForwardingRule.translatedPort:type_name -> management.PortInfo - 4, // 60: management.ExposeServiceRequest.protocol:type_name -> management.ExposeProtocol - 34, // 61: management.SSHAuth.MachineUsersEntry.value:type_name -> management.MachineUserIndexes - 7, // 62: management.ManagementService.Login:input_type -> management.EncryptedMessage - 7, // 63: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 23, // 64: management.ManagementService.GetServerKey:input_type -> management.Empty - 23, // 65: management.ManagementService.isHealthy:input_type -> management.Empty - 7, // 66: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 7, // 67: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 7, // 68: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 7, // 69: management.ManagementService.Logout:input_type -> management.EncryptedMessage - 7, // 70: management.ManagementService.Job:input_type -> 
management.EncryptedMessage - 7, // 71: management.ManagementService.CreateExpose:input_type -> management.EncryptedMessage - 7, // 72: management.ManagementService.RenewExpose:input_type -> management.EncryptedMessage - 7, // 73: management.ManagementService.StopExpose:input_type -> management.EncryptedMessage - 7, // 74: management.ManagementService.Login:output_type -> management.EncryptedMessage - 7, // 75: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 22, // 76: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 23, // 77: management.ManagementService.isHealthy:output_type -> management.Empty - 7, // 78: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 7, // 79: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 23, // 80: management.ManagementService.SyncMeta:output_type -> management.Empty - 23, // 81: management.ManagementService.Logout:output_type -> management.Empty - 7, // 82: management.ManagementService.Job:output_type -> management.EncryptedMessage - 7, // 83: management.ManagementService.CreateExpose:output_type -> management.EncryptedMessage - 7, // 84: management.ManagementService.RenewExpose:output_type -> management.EncryptedMessage - 7, // 85: management.ManagementService.StopExpose:output_type -> management.EncryptedMessage - 74, // [74:86] is the sub-list for method output_type - 62, // [62:74] is the sub-list for method input_type - 62, // [62:62] is the sub-list for extension type_name - 62, // [62:62] is the sub-list for extension extendee - 0, // [0:62] is the sub-list for field type_name + 12, // 2: management.JobResponse.bundle:type_name -> management.BundleResult + 21, // 3: management.SyncRequest.meta:type_name -> management.PeerSystemMeta + 25, // 4: management.SyncResponse.netbirdConfig:type_name -> management.NetbirdConfig + 31, // 5: 
management.SyncResponse.peerConfig:type_name -> management.PeerConfig + 36, // 6: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig + 33, // 7: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap + 51, // 8: management.SyncResponse.Checks:type_name -> management.Checks + 21, // 9: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta + 21, // 10: management.LoginRequest.meta:type_name -> management.PeerSystemMeta + 17, // 11: management.LoginRequest.peerKeys:type_name -> management.PeerKeys + 50, // 12: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress + 18, // 13: management.PeerSystemMeta.environment:type_name -> management.Environment + 19, // 14: management.PeerSystemMeta.files:type_name -> management.File + 20, // 15: management.PeerSystemMeta.flags:type_name -> management.Flags + 1, // 16: management.PeerSystemMeta.capabilities:type_name -> management.PeerCapability + 25, // 17: management.LoginResponse.netbirdConfig:type_name -> management.NetbirdConfig + 31, // 18: management.LoginResponse.peerConfig:type_name -> management.PeerConfig + 51, // 19: management.LoginResponse.Checks:type_name -> management.Checks + 63, // 20: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 26, // 21: management.NetbirdConfig.stuns:type_name -> management.HostConfig + 30, // 22: management.NetbirdConfig.turns:type_name -> management.ProtectedHostConfig + 26, // 23: management.NetbirdConfig.signal:type_name -> management.HostConfig + 27, // 24: management.NetbirdConfig.relay:type_name -> management.RelayConfig + 28, // 25: management.NetbirdConfig.flow:type_name -> management.FlowConfig + 6, // 26: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 64, // 27: management.FlowConfig.interval:type_name -> google.protobuf.Duration + 26, // 28: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig + 37, // 29: 
management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 32, // 30: management.PeerConfig.autoUpdate:type_name -> management.AutoUpdateSettings + 31, // 31: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 36, // 32: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 43, // 33: management.NetworkMap.Routes:type_name -> management.Route + 44, // 34: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 36, // 35: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 49, // 36: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 53, // 37: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 54, // 38: management.NetworkMap.forwardingRules:type_name -> management.ForwardingRule + 34, // 39: management.NetworkMap.sshAuth:type_name -> management.SSHAuth + 61, // 40: management.SSHAuth.machine_users:type_name -> management.SSHAuth.MachineUsersEntry + 37, // 41: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 29, // 42: management.SSHConfig.jwtConfig:type_name -> management.JWTConfig + 7, // 43: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 42, // 44: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 42, // 45: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 47, // 46: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 45, // 47: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 46, // 48: management.CustomZone.Records:type_name -> management.SimpleRecord + 48, // 49: management.NameServerGroup.NameServers:type_name -> management.NameServer + 3, // 50: management.FirewallRule.Direction:type_name -> management.RuleDirection + 4, // 51: management.FirewallRule.Action:type_name -> management.RuleAction + 2, // 52: 
management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 52, // 53: management.FirewallRule.PortInfo:type_name -> management.PortInfo + 62, // 54: management.PortInfo.range:type_name -> management.PortInfo.Range + 4, // 55: management.RouteFirewallRule.action:type_name -> management.RuleAction + 2, // 56: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 52, // 57: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 2, // 58: management.ForwardingRule.protocol:type_name -> management.RuleProtocol + 52, // 59: management.ForwardingRule.destinationPort:type_name -> management.PortInfo + 52, // 60: management.ForwardingRule.translatedPort:type_name -> management.PortInfo + 5, // 61: management.ExposeServiceRequest.protocol:type_name -> management.ExposeProtocol + 35, // 62: management.SSHAuth.MachineUsersEntry.value:type_name -> management.MachineUserIndexes + 8, // 63: management.ManagementService.Login:input_type -> management.EncryptedMessage + 8, // 64: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 24, // 65: management.ManagementService.GetServerKey:input_type -> management.Empty + 24, // 66: management.ManagementService.isHealthy:input_type -> management.Empty + 8, // 67: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 8, // 68: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 8, // 69: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 8, // 70: management.ManagementService.Logout:input_type -> management.EncryptedMessage + 8, // 71: management.ManagementService.Job:input_type -> management.EncryptedMessage + 8, // 72: management.ManagementService.CreateExpose:input_type -> management.EncryptedMessage + 8, // 73: management.ManagementService.RenewExpose:input_type -> management.EncryptedMessage + 8, // 74: 
management.ManagementService.StopExpose:input_type -> management.EncryptedMessage + 8, // 75: management.ManagementService.Login:output_type -> management.EncryptedMessage + 8, // 76: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 23, // 77: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 24, // 78: management.ManagementService.isHealthy:output_type -> management.Empty + 8, // 79: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 8, // 80: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 24, // 81: management.ManagementService.SyncMeta:output_type -> management.Empty + 24, // 82: management.ManagementService.Logout:output_type -> management.Empty + 8, // 83: management.ManagementService.Job:output_type -> management.EncryptedMessage + 8, // 84: management.ManagementService.CreateExpose:output_type -> management.EncryptedMessage + 8, // 85: management.ManagementService.RenewExpose:output_type -> management.EncryptedMessage + 8, // 86: management.ManagementService.StopExpose:output_type -> management.EncryptedMessage + 75, // [75:87] is the sub-list for method output_type + 63, // [63:75] is the sub-list for method input_type + 63, // [63:63] is the sub-list for extension type_name + 63, // [63:63] is the sub-list for extension extendee + 0, // [0:63] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -5977,7 +6100,7 @@ func file_management_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, - NumEnums: 7, + NumEnums: 8, NumMessages: 55, NumExtensions: 0, NumServices: 1, diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 70a530679..461a614fe 100644 --- a/shared/management/proto/management.proto +++ 
b/shared/management/proto/management.proto @@ -200,6 +200,18 @@ message Flags { bool enableSSHLocalPortForwarding = 13; bool enableSSHRemotePortForwarding = 14; bool disableSSHAuth = 15; + + bool disableIPv6 = 16; +} + +// PeerCapability represents a feature the client binary supports. +// Reported in PeerSystemMeta.capabilities on every login/sync. +enum PeerCapability { + PeerCapabilityUnknown = 0; + // Client reads SourcePrefixes instead of the deprecated PeerIP string. + PeerCapabilitySourcePrefixes = 1; + // Client handles IPv6 overlay addresses and firewall rules. + PeerCapabilityIPv6Overlay = 2; } // PeerSystemMeta is machine meta data like OS and version. @@ -221,6 +233,8 @@ message PeerSystemMeta { Environment environment = 15; repeated File files = 16; Flags flags = 17; + + repeated PeerCapability capabilities = 18; } message LoginResponse { @@ -335,6 +349,9 @@ message PeerConfig { // Auto-update config AutoUpdateSettings autoUpdate = 8; + + // IPv6 overlay address as compact bytes: 16 bytes IP + 1 byte prefix length. + bytes address_v6 = 9; } message AutoUpdateSettings { @@ -567,7 +584,8 @@ enum RuleAction { // FirewallRule represents a firewall rule message FirewallRule { - string PeerIP = 1; + // Use sourcePrefixes instead. + string PeerIP = 1 [deprecated = true]; RuleDirection Direction = 2; RuleAction Action = 3; RuleProtocol Protocol = 4; @@ -576,6 +594,13 @@ message FirewallRule { // PolicyID is the ID of the policy that this rule belongs to bytes PolicyID = 7; + + // CustomProtocol is a custom protocol ID when Protocol is CUSTOM. + uint32 customProtocol = 8; + + // Compact source IP prefixes for this rule, supersedes PeerIP. + // Each entry is 5 bytes (v4) or 17 bytes (v6): [IP bytes][1 byte prefix_len]. 
+ repeated bytes sourcePrefixes = 9; } message NetworkAddress { diff --git a/shared/netiputil/compact.go b/shared/netiputil/compact.go new file mode 100644 index 000000000..0cd2b8a20 --- /dev/null +++ b/shared/netiputil/compact.go @@ -0,0 +1,78 @@ +// Package netiputil provides compact binary encoding for IP prefixes used in +// the management proto wire format. +// +// Format: [IP bytes][1 byte prefix_len] +// - IPv4: 5 bytes total (4 IP + 1 prefix_len, 0-32) +// - IPv6: 17 bytes total (16 IP + 1 prefix_len, 0-128) +// +// Address family is determined by length: 5 = v4, 17 = v6. +package netiputil + +import ( + "fmt" + "net/netip" +) + +// EncodePrefix encodes a netip.Prefix into compact bytes. +// The address is always unmapped before encoding. +func EncodePrefix(p netip.Prefix) ([]byte, error) { + addr := p.Addr().Unmap() + bits := p.Bits() + if addr.Is4() && bits > 32 { + return nil, fmt.Errorf("invalid prefix length %d for IPv4 address %s (max 32)", bits, addr) + } + return append(addr.AsSlice(), byte(bits)), nil +} + +// DecodePrefix decodes compact bytes into a netip.Prefix. 
+func DecodePrefix(b []byte) (netip.Prefix, error) { + switch len(b) { + case 5: + var ip4 [4]byte + copy(ip4[:], b) + bits := int(b[len(b)-1]) + if bits > 32 { + return netip.Prefix{}, fmt.Errorf("invalid IPv4 prefix length %d (max 32)", bits) + } + return netip.PrefixFrom(netip.AddrFrom4(ip4), bits), nil + case 17: + var ip6 [16]byte + copy(ip6[:], b) + addr := netip.AddrFrom16(ip6).Unmap() + bits := int(b[len(b)-1]) + if addr.Is4() { + if bits > 32 { + return netip.Prefix{}, fmt.Errorf("invalid prefix length %d for v4-mapped address (max 32)", bits) + } + } else if bits > 128 { + return netip.Prefix{}, fmt.Errorf("invalid IPv6 prefix length %d (max 128)", bits) + } + return netip.PrefixFrom(addr, bits), nil + default: + return netip.Prefix{}, fmt.Errorf("invalid compact prefix length %d (expected 5 or 17)", len(b)) + } +} + +// EncodeAddr encodes a netip.Addr into compact prefix bytes with a host prefix +// length (/32 for v4, /128 for v6). The address is always unmapped before encoding. +func EncodeAddr(a netip.Addr) []byte { + a = a.Unmap() + bits := 128 + if a.Is4() { + bits = 32 + } + // Host prefix lengths are always valid for the address family, so error is impossible. + b, _ := EncodePrefix(netip.PrefixFrom(a, bits)) + return b +} + +// DecodeAddr decodes compact prefix bytes and returns only the address, +// discarding the prefix length. Useful when the prefix length is implied +// (e.g. peer overlay IPs are always /32 or /128). 
+func DecodeAddr(b []byte) (netip.Addr, error) { + p, err := DecodePrefix(b) + if err != nil { + return netip.Addr{}, err + } + return p.Addr(), nil +} diff --git a/shared/netiputil/compact_test.go b/shared/netiputil/compact_test.go new file mode 100644 index 000000000..1e7c7ed82 --- /dev/null +++ b/shared/netiputil/compact_test.go @@ -0,0 +1,175 @@ +package netiputil + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncodeDecodePrefix(t *testing.T) { + tests := []struct { + name string + prefix string + size int + }{ + { + name: "v4 host", + prefix: "100.64.0.1/32", + size: 5, + }, + { + name: "v4 network", + prefix: "10.0.0.0/8", + size: 5, + }, + { + name: "v4 default", + prefix: "0.0.0.0/0", + size: 5, + }, + { + name: "v6 host", + prefix: "fd00::1/128", + size: 17, + }, + { + name: "v6 network", + prefix: "fd00:1234:5678::/48", + size: 17, + }, + { + name: "v6 default", + prefix: "::/0", + size: 17, + }, + { + name: "v4 /16 overlay", + prefix: "100.64.0.1/16", + size: 5, + }, + { + name: "v6 /64 overlay", + prefix: "fd00::abcd:1/64", + size: 17, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := netip.MustParsePrefix(tt.prefix) + b, err := EncodePrefix(p) + require.NoError(t, err) + assert.Equal(t, tt.size, len(b), "encoded size") + + decoded, err := DecodePrefix(b) + require.NoError(t, err) + assert.Equal(t, p, decoded) + }) + } +} + +func TestEncodePrefixUnmaps(t *testing.T) { + // v4-mapped v6 address should encode as v4 + mapped := netip.MustParsePrefix("::ffff:10.1.2.3/32") + b, err := EncodePrefix(mapped) + require.NoError(t, err) + assert.Equal(t, 5, len(b), "v4-mapped should encode as 5 bytes") + + decoded, err := DecodePrefix(b) + require.NoError(t, err) + assert.Equal(t, netip.MustParsePrefix("10.1.2.3/32"), decoded) +} + +func TestEncodePrefixUnmapsRejectsInvalidBits(t *testing.T) { + // v4-mapped v6 with bits > 32 should return an 
error + mapped128 := netip.MustParsePrefix("::ffff:10.1.2.3/128") + _, err := EncodePrefix(mapped128) + require.Error(t, err) + + // v4-mapped v6 with bits=96 should also return an error + mapped96 := netip.MustParsePrefix("::ffff:10.0.0.0/96") + _, err = EncodePrefix(mapped96) + require.Error(t, err) + + // v4-mapped v6 with bits=32 should succeed + mapped32 := netip.MustParsePrefix("::ffff:10.1.2.3/32") + b, err := EncodePrefix(mapped32) + require.NoError(t, err) + assert.Equal(t, 5, len(b), "v4-mapped should encode as 5 bytes") + + decoded, err := DecodePrefix(b) + require.NoError(t, err) + assert.Equal(t, netip.MustParsePrefix("10.1.2.3/32"), decoded) +} + +func TestDecodeAddr(t *testing.T) { + v4 := netip.MustParseAddr("100.64.0.5") + b := EncodeAddr(v4) + assert.Equal(t, 5, len(b)) + + got, err := DecodeAddr(b) + require.NoError(t, err) + assert.Equal(t, v4, got) + + v6 := netip.MustParseAddr("fd00::1") + b = EncodeAddr(v6) + assert.Equal(t, 17, len(b)) + + got, err = DecodeAddr(b) + require.NoError(t, err) + assert.Equal(t, v6, got) +} + +func TestDecodePrefixInvalidLength(t *testing.T) { + _, err := DecodePrefix([]byte{1, 2, 3}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid compact prefix length 3") + + _, err = DecodePrefix(nil) + assert.Error(t, err) + + _, err = DecodePrefix([]byte{}) + assert.Error(t, err) +} + +func TestDecodePrefixInvalidBits(t *testing.T) { + // v4 with bits > 32 + b := []byte{10, 0, 0, 1, 33} + _, err := DecodePrefix(b) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid IPv4 prefix length 33") + + // v6 with bits > 128 + b = make([]byte, 17) + b[0] = 0xfd + b[16] = 129 + _, err = DecodePrefix(b) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid IPv6 prefix length 129") +} + +func TestDecodePrefixUnmapsV6Input(t *testing.T) { + addr := netip.MustParseAddr("::ffff:192.168.1.1") + + // v4-mapped v6 with bits > 32 should return an error + raw := addr.As16() + bInvalid := make([]byte, 
17) + copy(bInvalid, raw[:]) + bInvalid[16] = 128 + + _, err := DecodePrefix(bInvalid) + require.Error(t, err, "v4-mapped address with /128 prefix should be rejected") + assert.Contains(t, err.Error(), "invalid prefix length") + + // v4-mapped v6 with valid /32 should decode and unmap correctly + bValid := make([]byte, 17) + copy(bValid, raw[:]) + bValid[16] = 32 + + decoded, err := DecodePrefix(bValid) + require.NoError(t, err) + assert.True(t, decoded.Addr().Is4(), "should be unmapped to v4") + assert.Equal(t, netip.MustParsePrefix("192.168.1.1/32"), decoded) +} diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 602803b19..86f6f178d 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -49,7 +49,7 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, InitialPacketSize: nbRelay.QUICInitialPacketSize, } - udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0}) if err != nil { log.Errorf("failed to listen on UDP: %s", err) return nil, err diff --git a/upload-server/server/s3_test.go b/upload-server/server/s3_test.go index 7ab1bb379..a72356409 100644 --- a/upload-server/server/s3_test.go +++ b/upload-server/server/s3_test.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "net" "net/http" "net/http/httptest" "runtime" @@ -52,7 +53,7 @@ func Test_S3HandlerGetUploadURL(t *testing.T) { hostIP, err := c.Host(ctx) require.NoError(t, err) - awsEndpoint := "http://" + hostIP + ":" + mappedPort.Port() + awsEndpoint := "http://" + net.JoinHostPort(hostIP, mappedPort.Port()) t.Setenv("AWS_REGION", awsRegion) t.Setenv("AWS_ENDPOINT_URL", awsEndpoint) diff --git a/util/capture/text.go b/util/capture/text.go index b44bd0cad..fbb26654e 100644 --- a/util/capture/text.go +++ b/util/capture/text.go @@ -4,7 +4,9 @@ import ( "encoding/binary" "fmt" "io" 
+ "net" "net/netip" + "strconv" "strings" "time" @@ -91,9 +93,9 @@ func (tw *TextWriter) writeTCP(timeStr string, dir Direction, info *packetInfo, } if !tw.verbose { - _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d [%s] length %d%s\n", + _, err := fmt.Fprintf(tw.w, "%s %s %s > %s [%s] length %d%s\n", timeStr, tag(dir, "TCP"), - info.srcIP, info.srcPort, info.dstIP, info.dstPort, + net.JoinHostPort(info.srcIP.String(), strconv.Itoa(int(info.srcPort))), net.JoinHostPort(info.dstIP.String(), strconv.Itoa(int(info.dstPort))), flags, plen, annotation) if err != nil { return err @@ -125,9 +127,9 @@ func (tw *TextWriter) writeTCP(timeStr string, dir Direction, info *packetInfo, verbose := tw.verboseIP(data, info.family) - _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d [%s]%s%s, win %d%s, length %d%s%s\n", + _, err := fmt.Fprintf(tw.w, "%s %s %s > %s [%s]%s%s, win %d%s, length %d%s%s\n", timeStr, tag(dir, "TCP"), - info.srcIP, info.srcPort, info.dstIP, info.dstPort, + net.JoinHostPort(info.srcIP.String(), strconv.Itoa(int(info.srcPort))), net.JoinHostPort(info.dstIP.String(), strconv.Itoa(int(info.dstPort))), flags, seqStr, ackStr, tcp.Window, opts, plen, annotation, verbose) if err != nil { return err @@ -153,9 +155,9 @@ func (tw *TextWriter) writeUDP(timeStr string, dir Direction, info *packetInfo, if tw.verbose { verbose = tw.verboseIP(data, info.family) } - _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d %s%s\n", + _, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s%s\n", timeStr, tag(dir, "UDP"), - info.srcIP, info.srcPort, info.dstIP, info.dstPort, + net.JoinHostPort(info.srcIP.String(), strconv.Itoa(int(info.srcPort))), net.JoinHostPort(info.dstIP.String(), strconv.Itoa(int(info.dstPort))), s, verbose) return err } @@ -165,9 +167,9 @@ func (tw *TextWriter) writeUDP(timeStr string, dir Direction, info *packetInfo, if tw.verbose { verbose = tw.verboseIP(data, info.family) } - _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d length %d%s\n", + _, err := fmt.Fprintf(tw.w, 
"%s %s %s > %s length %d%s\n", timeStr, tag(dir, "UDP"), - info.srcIP, info.srcPort, info.dstIP, info.dstPort, + net.JoinHostPort(info.srcIP.String(), strconv.Itoa(int(info.srcPort))), net.JoinHostPort(info.dstIP.String(), strconv.Itoa(int(info.dstPort))), plen, verbose) if err != nil { return err @@ -216,9 +218,9 @@ func (tw *TextWriter) writeICMPv6(timeStr string, dir Direction, info *packetInf } func (tw *TextWriter) writeFallback(timeStr string, dir Direction, proto string, info *packetInfo, data []byte) error { - _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d length %d\n", + _, err := fmt.Fprintf(tw.w, "%s %s %s > %s length %d\n", timeStr, tag(dir, proto), - info.srcIP, info.srcPort, info.dstIP, info.dstPort, + net.JoinHostPort(info.srcIP.String(), strconv.Itoa(int(info.srcPort))), net.JoinHostPort(info.dstIP.String(), strconv.Itoa(int(info.dstPort))), len(data)-info.hdrLen) return err } From 39eac377e425dc7efd6872eeb50e0e494c0f25d8 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 7 May 2026 15:55:59 +0200 Subject: [PATCH 09/27] [management] add update reason to buffered calls (#6103) --- .../controllers/network_map/controller/controller.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 36de950e9..590773dda 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -221,9 +221,13 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin return nil } -func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error { +func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { 
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName()) + if c.accountManagerMetrics != nil { + c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation)) + } + bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) b := bufUpd.(*bufferUpdate) @@ -570,7 +574,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t } func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error { - err := c.bufferSendUpdateAccountPeers(ctx, accountID) + err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate}) if err != nil { log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err) } @@ -580,7 +584,7 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs) - return c.bufferSendUpdateAccountPeers(ctx, accountID) + return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate}) } func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error { @@ -616,7 +620,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI c.peersUpdateManager.CloseChannel(ctx, peerID) } - return c.bufferSendUpdateAccountPeers(ctx, accountID) + return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete}) } // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) From 
7da94a4956af76f7187733aa488e9d20a0f62202 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 7 May 2026 16:16:48 +0200 Subject: [PATCH 10/27] [misc] Update CONTRIBUTING.md (#6076) --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index efc7d9460..960cd30e9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,7 +8,7 @@ There are many ways that you can contribute: - Sharing use cases in slack or Reddit - Bug fix or feature enhancement -If you haven't already, join our slack workspace [here](https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A), we would love to discuss topics that need community contribution and enhancements to existing features. +If you haven't already, join our slack workspace [here](https://docs.netbird.io/slack-url), we would love to discuss topics that need community contribution and enhancements to existing features. ## Contents From e89aad09f5c2ae2205720ec42f665ad948af8f66 Mon Sep 17 00:00:00 2001 From: Nicolas Frati Date: Fri, 8 May 2026 16:31:20 +0200 Subject: [PATCH 11/27] [management] Enable MFA for local users (#5804) * wip: totp for local users * fix providers not getting populated * polished UI and fix post_login_redirect_uri * fix: make sure logout is only prompted from oidc flow Signed-off-by: jnfrati * update templates Signed-off-by: jnfrati * deps: update dex dependency Signed-off-by: jnfrati * fix qube issues Signed-off-by: jnfrati * replace window with globalThis on home html Signed-off-by: jnfrati * fixed coderabbit comments Signed-off-by: jnfrati * debug * remove unused config and rename totp issuer * deps: update dex reference to latest * add dashboard post logout redirect uri to embedded config * implemented api for mfa configuration * update docs and config parsing * catch error on idp manager init mfa * fix tests * Add remember me for MFA * Add cookie encryption and session share between tabs * fixed logout showing 
non actionable error and session cookie encription key * fixed missing mfa settings on sql query for account * fix code index for mfa activity --------- Signed-off-by: jnfrati Co-authored-by: braginini --- combined/cmd/config.go | 36 ++-- combined/config.yaml.example | 10 + go.mod | 44 +++-- go.sum | 97 +++++++--- idp/dex/config.go | 161 ++++++++++++++++- idp/dex/provider.go | 79 +++++++- idp/dex/provider_test.go | 26 +++ idp/dex/web/templates/home.html | 12 ++ idp/dex/web/templates/logout.html | 14 ++ idp/dex/web/templates/password.html | 2 + idp/dex/web/templates/totp_verify.html | 44 +++++ idp/dex/web/templates/webauthn_verify.html | 12 ++ management/internals/server/modules.go | 32 +++- management/server/account.go | 26 +++ management/server/activity/codes.go | 8 + .../handlers/accounts/accounts_handler.go | 4 + .../accounts/accounts_handler_test.go | 6 + management/server/idp/embedded.go | 171 +++++++++++++++++- management/server/idp/embedded_test.go | 67 +++++++ management/server/store/sql_store.go | 15 +- management/server/types/settings.go | 5 + shared/management/http/api/openapi.yml | 4 + shared/management/http/api/types.gen.go | 3 + 23 files changed, 791 insertions(+), 87 deletions(-) create mode 100644 idp/dex/web/templates/home.html create mode 100644 idp/dex/web/templates/logout.html create mode 100644 idp/dex/web/templates/totp_verify.html create mode 100644 idp/dex/web/templates/webauthn_verify.html diff --git a/combined/cmd/config.go b/combined/cmd/config.go index 9959f7a56..fe350e52a 100644 --- a/combined/cmd/config.go +++ b/combined/cmd/config.go @@ -133,13 +133,18 @@ type ManagementConfig struct { // AuthConfig contains authentication/identity provider settings type AuthConfig struct { - Issuer string `yaml:"issuer"` - LocalAuthDisabled bool `yaml:"localAuthDisabled"` - SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"` - Storage AuthStorageConfig `yaml:"storage"` - DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"` - 
CLIRedirectURIs []string `yaml:"cliRedirectURIs"` - Owner *AuthOwnerConfig `yaml:"owner,omitempty"` + Issuer string `yaml:"issuer"` + LocalAuthDisabled bool `yaml:"localAuthDisabled"` + SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"` + MfaSessionMaxLifetime string `yaml:"mfaSessionMaxLifetime"` + MfaSessionIdleTimeout string `yaml:"mfaSessionIdleTimeout"` + MfaSessionRememberMe bool `yaml:"mfaSessionRememberMe"` + SessionCookieEncryptionKey string `yaml:"sessionCookieEncryptionKey"` + Storage AuthStorageConfig `yaml:"storage"` + DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"` + CLIRedirectURIs []string `yaml:"cliRedirectURIs"` + Owner *AuthOwnerConfig `yaml:"owner,omitempty"` + DashboardPostLogoutRedirectURIs []string `yaml:"dashboardPostLogoutRedirectURIs"` } // AuthStorageConfig contains auth storage settings @@ -581,10 +586,14 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb } cfg := &idp.EmbeddedIdPConfig{ - Enabled: true, - Issuer: mgmt.Auth.Issuer, - LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled, - SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled, + Enabled: true, + Issuer: mgmt.Auth.Issuer, + LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled, + SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled, + MfaSessionMaxLifetime: mgmt.Auth.MfaSessionMaxLifetime, + MfaSessionIdleTimeout: mgmt.Auth.MfaSessionIdleTimeout, + MfaSessionRememberMe: mgmt.Auth.MfaSessionRememberMe, + SessionCookieEncryptionKey: mgmt.Auth.SessionCookieEncryptionKey, Storage: idp.EmbeddedStorageConfig{ Type: authStorageType, Config: idp.EmbeddedStorageTypeConfig{ @@ -592,8 +601,9 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb DSN: authStorageDSN, }, }, - DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs, - CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs, + DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs, + CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs, + DashboardPostLogoutRedirectURIs: 
mgmt.Auth.DashboardPostLogoutRedirectURIs, } if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" { diff --git a/combined/config.yaml.example b/combined/config.yaml.example index af85b0477..66bc71703 100644 --- a/combined/config.yaml.example +++ b/combined/config.yaml.example @@ -86,6 +86,13 @@ server: issuer: "https://example.com/oauth2" localAuthDisabled: false signKeyRefreshEnabled: false + # MFA session settings (applies when TOTP is enabled for an account) + # mfaSessionMaxLifetime: "24h" # Max duration for an MFA session from creation + # mfaSessionIdleTimeout: "1h" # MFA session expires after this idle period + # mfaSessionRememberMe: false # Pre-check "remember me" on login so the MFA session persists across tabs/restarts + # Optional AES key for encrypting embedded IdP session cookies. Can also be set via NB_IDP_SESSION_COOKIE_ENCRYPTION_KEY. + # Must be 16/24/32 raw bytes or base64-encoded to one of those lengths (for example: openssl rand -hex 16). + # sessionCookieEncryptionKey: "" # OAuth2 redirect URIs for dashboard dashboardRedirectURIs: - "https://app.example.com/nb-auth" @@ -93,6 +100,9 @@ server: # OAuth2 redirect URIs for CLI cliRedirectURIs: - "http://localhost:53000/" + # OAuth2 post-logout redirect URIs for dashboard (RP-initiated logout) + # dashboardPostLogoutRedirectURIs: + # - "https://app.example.com/" # Optional initial admin user # owner: # email: "admin@example.com" diff --git a/go.mod b/go.mod index bc4e8af15..84aeab941 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/onsi/gomega v1.27.6 github.com/rs/cors v1.8.0 github.com/sirupsen/logrus v1.9.4 - github.com/spf13/cobra v1.10.1 + github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.9 github.com/vishvananda/netlink v1.3.1 golang.org/x/crypto v0.50.0 @@ -41,11 +41,11 @@ require ( github.com/cilium/ebpf v0.15.0 github.com/coder/websocket v1.8.14 github.com/coreos/go-iptables v0.7.0 - github.com/coreos/go-oidc/v3 v3.14.1 + github.com/coreos/go-oidc/v3 
v3.18.0 github.com/creack/pty v1.1.24 github.com/crowdsecurity/crowdsec v1.7.7 github.com/crowdsecurity/go-cs-bouncer v0.0.21 - github.com/dexidp/dex v0.0.0-00010101000000-000000000000 + github.com/dexidp/dex v2.13.0+incompatible github.com/dexidp/dex/api/v2 v2.4.0 github.com/ebitengine/purego v0.8.4 github.com/eko/gocache/lib/v4 v4.2.0 @@ -53,9 +53,9 @@ require ( github.com/eko/gocache/store/redis/v4 v4.2.2 github.com/fsnotify/fsnotify v1.9.0 github.com/gliderlabs/ssh v0.3.8 - github.com/go-jose/go-jose/v4 v4.1.3 + github.com/go-jose/go-jose/v4 v4.1.4 github.com/godbus/dbus/v5 v5.1.0 - github.com/golang-jwt/jwt/v5 v5.3.0 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.7.0 github.com/google/gopacket v1.1.19 @@ -113,7 +113,7 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.64.0 go.opentelemetry.io/otel/metric v1.43.0 go.opentelemetry.io/otel/sdk/metric v1.43.0 - go.uber.org/mock v0.5.2 + go.uber.org/mock v0.6.0 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b @@ -141,7 +141,7 @@ require ( filippo.io/edwards25519 v1.1.1 // indirect github.com/AppsFlyer/go-sundheit v0.6.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect - github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect + github.com/Azure/go-ntlmssp v0.1.0 // indirect github.com/BurntSushi/toml v1.5.0 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.3.0 // indirect @@ -168,6 +168,7 @@ require ( github.com/aws/smithy-go v1.23.0 // indirect github.com/beevik/etree v1.6.0 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -183,6 +184,7 @@ require ( 
github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fredbi/uri v1.1.1 // indirect + github.com/fxamacker/cbor/v2 v2.9.1 // indirect github.com/fyne-io/gl-js v0.2.0 // indirect github.com/fyne-io/glfw-js v0.3.0 // indirect github.com/fyne-io/image v0.1.1 // indirect @@ -190,7 +192,7 @@ require ( github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect - github.com/go-ldap/ldap/v3 v3.4.12 // indirect + github.com/go-ldap/ldap/v3 v3.4.13 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect @@ -206,11 +208,15 @@ require ( github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/go-text/render v0.2.0 // indirect github.com/go-text/typesetting v0.2.1 // indirect + github.com/go-viper/mapstructure/v2 v2.5.0 // indirect + github.com/go-webauthn/webauthn v0.16.4 // indirect + github.com/go-webauthn/x v0.2.3 // indirect github.com/goccy/go-yaml v1.18.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/go-tpm v0.9.8 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect github.com/googleapis/gax-go/v2 v2.21.0 // indirect @@ -218,7 +224,13 @@ require ( github.com/hack-pad/go-indexeddb v0.3.2 // indirect github.com/hack-pad/safejs v0.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-retryablehttp v0.7.8 // indirect + github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect + github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 
// indirect + github.com/hashicorp/go-sockaddr v1.0.7 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect + github.com/hashicorp/hcl v1.0.1-vault-7 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/huin/goupnp v1.2.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -238,13 +250,13 @@ require ( github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/koron/go-ssdp v0.0.4 // indirect github.com/kr/fs v0.1.0 // indirect - github.com/lib/pq v1.10.9 // indirect + github.com/lib/pq v1.12.3 // indirect github.com/libdns/libdns v0.2.2 // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/magiconair/properties v1.8.10 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect - github.com/mattn/go-sqlite3 v1.14.32 // indirect + github.com/mattn/go-sqlite3 v1.14.42 // indirect github.com/mdelapenya/tlscert v0.2.0 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect @@ -265,8 +277,10 @@ require ( github.com/nxadm/tail v1.4.11 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect + github.com/openbao/openbao/api/v2 v2.5.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/philhofer/fwd v1.2.0 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect github.com/pion/dtls/v3 v3.0.9 // indirect github.com/pion/mdns/v2 v2.0.7 // indirect @@ -275,11 +289,13 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect + github.com/pquerna/otp v1.5.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.5 // indirect 
github.com/prometheus/otlptranslator v1.0.0 // indirect github.com/prometheus/procfs v0.19.2 // indirect github.com/russellhaering/goxmldsig v1.6.0 // indirect + github.com/ryanuber/go-glob v1.0.0 // indirect github.com/rymdport/portal v0.4.2 // indirect github.com/shirou/gopsutil/v4 v4.25.8 // indirect github.com/shoenig/go-m1cpu v0.2.1 // indirect @@ -288,11 +304,13 @@ require ( github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect github.com/stretchr/objx v0.5.2 // indirect + github.com/tinylib/msgp v1.6.3 // indirect github.com/tklauser/go-sysconf v0.3.15 // indirect github.com/tklauser/numcpus v0.10.0 // indirect github.com/vishvananda/netns v0.0.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/wlynxg/anet v0.0.5 // indirect + github.com/x448/float16 v0.8.4 // indirect github.com/yuin/goldmark v1.7.8 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.mongodb.org/mongo-driver v1.17.9 // indirect @@ -319,10 +337,12 @@ replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-202 replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 -replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 +replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 -replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0 +replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.1-0.20260415145816-a0c6b40ff9f2 + +replace github.com/dexidp/dex/api/v2 => github.com/netbirdio/dex/api/v2 v2.0.0-20260415145816-a0c6b40ff9f2 replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0 diff --git a/go.sum b/go.sum index d54dc01e6..851d1ce66 100644 --- a/go.sum +++ b/go.sum @@ -5,6 
+5,8 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3R cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA= +codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw= cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw= cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M= dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= @@ -23,8 +25,8 @@ github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= -github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8= -github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= +github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A= +github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk= github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= @@ -91,6 +93,8 @@ github.com/beevik/etree v1.6.0/go.mod 
h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sL github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -117,8 +121,8 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw= github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8= github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= -github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk= -github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU= +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA= github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= @@ -130,14 +134,10 @@ github.com/crowdsecurity/go-cs-bouncer v0.0.21 h1:arPz0VtdVSaz+auOSfHythzkZVLyy1 github.com/crowdsecurity/go-cs-bouncer v0.0.21/go.mod h1:4JiH0XXA4KKnnWThItUpe5+heJHWzsLOSA2IWJqUDBA= 
github.com/crowdsecurity/go-cs-lib v0.0.25 h1:Ov6VPW9yV+OPsbAIQk1iTkEWhwkpaG0v3lrBzeqjzj4= github.com/crowdsecurity/go-cs-lib v0.0.25/go.mod h1:X0GMJY2CxdA1S09SpuqIKaWQsvRGxXmecUp9cP599dE= -github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0= -github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A= -github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= @@ -156,6 +156,8 @@ github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEM github.com/eko/gocache/store/go_cache/v4 v4.2.2/go.mod h1:T9zkHokzr8K9EiC7RfMbDg6HSwaV6rv3UdcNu13SGcA= github.com/eko/gocache/store/redis/v4 v4.2.2 h1:Thw31fzGuH3WzJywsdbMivOmP550D6JS7GDHhvCJPA0= github.com/eko/gocache/store/redis/v4 v4.2.2/go.mod h1:LaTxLKx9TG/YUEybQvPMij++D7PBTIJ4+pzvk0ykz0w= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= github.com/felixge/fgprof 
v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -171,6 +173,8 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4 github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ= +github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/fyne-io/gl-js v0.2.0 h1:+EXMLVEa18EfkXBVKhifYB6OGs3HwKO3lUElA0LlAjs= github.com/fyne-io/gl-js v0.2.0/go.mod h1:ZcepK8vmOYLu96JoxbCKJy2ybr+g1pTnaBDdl7c3ajI= github.com/fyne-io/glfw-js v0.3.0 h1:d8k2+Y7l+zy2pc7wlGRyPfTgZoqDf3AI4G+2zOWhWUk= @@ -189,10 +193,10 @@ github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 h1:5BVwOaUSBTlVZowGO6VZGw github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= -github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= -github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4= -github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/go-ldap/ldap/v3 v3.4.13 
h1:+x1nG9h+MZN7h/lUi5Q3UZ0fJ1GyDQYbPvbuH38baDQ= +github.com/go-ldap/ldap/v3 v3.4.13/go.mod h1:LxsGZV6vbaK0sIvYfsv47rfh4ca0JXokCoKjZxsszv0= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -229,12 +233,20 @@ github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI6 github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= +github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/go-text/render v0.2.0 h1:LBYoTmp5jYiJ4NPqDc2pz17MLmA3wHw1dZSVGcOdeAc= github.com/go-text/render v0.2.0/go.mod h1:CkiqfukRGKJA5vZZISkjSYrcdtgKQWRa2HIzvwNN5SU= github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg82V8= github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M= github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0= github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= +github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= +github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/go-webauthn/webauthn v0.16.4 h1:R9jqR/cYZa7hRquFF7Za/8qoH/K/TIs1/Q/4CyGN+1Q= +github.com/go-webauthn/webauthn v0.16.4/go.mod h1:SU2ljAgToTV/YLPI0C05QS4qn+e04WpB5g1RMfcZfS4= 
+github.com/go-webauthn/x v0.2.3 h1:8oArS+Rc1SWFLXhE17KZNx258Z4kUSyaDgsSncCO5RA= +github.com/go-webauthn/x v0.2.3/go.mod h1:tM04GF3V6VYq79AZMl7vbj4q6pz9r7L2criWRzbWhPk= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= @@ -243,8 +255,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= -github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -276,6 +288,10 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= +github.com/google/go-tpm v0.9.8 h1:slArAR9Ft+1ybZu0lBwpSmpwhRXaa85hWtMinMyRAWo= +github.com/google/go-tpm v0.9.8/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= +github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba 
h1:qJEJcuLzH5KDR0gKc0zcktin6KSAwL7+jWKBYceddTc= +github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba/go.mod h1:EFYHy8/1y2KfgTAsx7Luu7NGhoxtuVHnNo8jE7FikKc= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= @@ -308,15 +324,29 @@ github.com/hack-pad/safejs v0.1.0/go.mod h1:HdS+bKF1NrE72VoXZeWzxFOVQVUSqZJAG0xN github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= +github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= +github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM= +github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0/go.mod h1:Ll013mhdmsVDuoIXVfBtvgGJsXDYkTw1kooNcoCXuE0= 
+github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4= +github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9dbT+Fw= +github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I= +github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= @@ -387,8 +417,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= +github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/libdns/libdns v0.2.2 
h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= @@ -406,9 +436,13 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= -github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.42 h1:MigqEP4ZmHw3aIdIT7T+9TLa90Z6smwcthx+Azv4Cgo= +github.com/mattn/go-sqlite3 v1.14.42/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= @@ -451,8 +485,10 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= 
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U= -github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU= +github.com/netbirdio/dex v0.244.1-0.20260415145816-a0c6b40ff9f2 h1:AP7OM/JnTogod3rVcLsMuilSG94kWQCr3z6R4rfVXnc= +github.com/netbirdio/dex v0.244.1-0.20260415145816-a0c6b40ff9f2/go.mod h1:+trSlzHNmdJGvz0oLEyyiuaPstUeD7YO6B3Fx9nyziY= +github.com/netbirdio/dex/api/v2 v2.0.0-20260415145816-a0c6b40ff9f2 h1:HEEGJPsVw7/p7SEL3HWP4vaInxHo8OJSEaOkHpUAk+M= +github.com/netbirdio/dex/api/v2 v2.0.0-20260415145816-a0c6b40ff9f2/go.mod h1:awuTyT29CYALpEyET0S307EgNlPWrc7fFKRAyhsO45M= github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus= github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= @@ -489,6 +525,8 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= +github.com/openbao/openbao/api/v2 v2.5.1 h1:Br79D6L20SbAa5P7xqENxmvv8LyI4HoKosPy7klhn4o= +github.com/openbao/openbao/api/v2 v2.5.1/go.mod h1:Dh5un77tqGgMbmlVEqjqN+8/dMyUohnkaQVg/wXW0Ig= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -501,6 +539,8 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0 
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs= github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= +github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA= github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= @@ -542,6 +582,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= +github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= @@ -565,6 +607,8 @@ github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/russellhaering/goxmldsig v1.6.0 h1:8fdWXEPh2k/NZNQBPFNoVfS3JmzS4ZprY/sAOpKQLks= github.com/russellhaering/goxmldsig v1.6.0/go.mod h1:TrnaquDcYxWXfJrOjeMBTX4mLBeYAqaHEyUeWPxZlBM= github.com/russross/blackfriday/v2 v2.1.0/go.mod 
h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= +github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU= github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU= @@ -587,8 +631,8 @@ github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= -github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= @@ -628,6 +672,8 @@ github.com/ti-mo/conntrack v0.5.1 h1:opEwkFICnDbQc0BUXl73PHBK0h23jEIFVjXsqvF4GY0 github.com/ti-mo/conntrack v0.5.1/go.mod h1:T6NCbkMdVU4qEIgwL0njA6lw/iCAbzchlnwm1Sa314o= github.com/ti-mo/netfilter v0.5.2 h1:CTjOwFuNNeZ9QPdRXt1MZFLFUf84cKtiQutNauHWd40= github.com/ti-mo/netfilter v0.5.2/go.mod h1:Btx3AtFiOVdHReTDmP9AE+hlkOcvIy403u7BXXbWZKo= +github.com/tinylib/msgp v1.6.3 h1:bCSxiTz386UTgyT1i0MSCvdbWjVW+8sG3PjkGsZQt4s= +github.com/tinylib/msgp v1.6.3/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA= 
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4= github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4= @@ -646,6 +692,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= @@ -690,14 +738,15 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= -go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= +go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= +go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod 
h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4= goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/idp/dex/config.go b/idp/dex/config.go index e686233ad..56ed998c2 100644 --- a/idp/dex/config.go +++ b/idp/dex/config.go @@ -51,6 +51,70 @@ type YAMLConfig struct { // StaticPasswords cause the server use this list of passwords rather than // querying the storage. StaticPasswords []Password `yaml:"staticPasswords" json:"staticPasswords"` + + // Sessions holds authentication session configuration. + // Requires DEX_SESSIONS_ENABLED=true feature flag. + Sessions *Sessions `yaml:"sessions" json:"sessions"` + + // MFA holds multi-factor authentication configuration. + MFA MFAConfig `yaml:"mfa" json:"mfa"` +} + +type Sessions struct { + // CookieName is the name of the session cookie. Defaults to "dex_session". + CookieName string `yaml:"cookieName" json:"cookieName"` + // AbsoluteLifetime is the maximum session lifetime from creation. Defaults to "24h". + AbsoluteLifetime string `yaml:"absoluteLifetime" json:"absoluteLifetime"` + // ValidIfNotUsedFor is the idle timeout. Defaults to "1h". + ValidIfNotUsedFor string `yaml:"validIfNotUsedFor" json:"validIfNotUsedFor"` + // RememberMeCheckedByDefault controls the default state of the "remember me" checkbox. + RememberMeCheckedByDefault *bool `yaml:"rememberMeCheckedByDefault" json:"rememberMeCheckedByDefault"` + // CookieEncryptionKey is the AES key for encrypting session cookies. + // Must be 16, 24, or 32 bytes for AES-128, AES-192, or AES-256. 
+ // If empty, cookies are not encrypted. + CookieEncryptionKey string `yaml:"cookieEncryptionKey" json:"cookieEncryptionKey"` + // SSOSharedWithDefault is the default SSO sharing policy for clients without explicit ssoSharedWith. + // "all" = share with all clients, "none" = share with no one (default: "none"). + SSOSharedWithDefault string `yaml:"ssoSharedWithDefault" json:"ssoSharedWithDefault"` +} + +type MFAConfig struct { + Authenticators []MFAAuthenticator `yaml:"authenticators" json:"authenticators"` +} + +type MFAAuthenticator struct { + ID string `yaml:"id" json:"id"` + Type string `yaml:"type" json:"type"` + Config map[string]interface{} `yaml:"config" json:"config"` + + ConnectorTypes []string `yaml:"connectorTypes" json:"connectorTypes"` +} + +type TOTPConfig struct { + Issuer string `yaml:"issuer" json:"issuer"` +} + +// WebAuthnConfig holds configuration for a WebAuthn authenticator. +type WebAuthnConfig struct { + // RPDisplayName is the human-readable relying party name shown in the browser + // dialog during key registration and authentication (e.g., "My Company SSO"). + RPDisplayName string `yaml:"rpDisplayName" json:"rpDisplayName"` + // RPID is the relying party identifier — must match the domain in the browser + // address bar. If empty, derived from the issuer URL hostname. + // Example: "auth.example.com" + RPID string `yaml:"rpID" json:"rpID"` + // RPOrigins is the list of allowed origins for WebAuthn ceremonies. + // If empty, derived from the issuer URL (scheme + host). 
+ // Example: ["https://auth.example.com"] + RPOrigins []string `yaml:"rpOrigins" json:"rpOrigins"` + // AttestationPreference controls what attestation data the authenticator should provide: + // "none" — don't request attestation (simpler, more private) + // "indirect" — authenticator may anonymize attestation (default) + // "direct" — request full attestation (for enterprise key model verification) + AttestationPreference string `yaml:"attestationPreference" json:"attestationPreference"` + // Timeout is the duration allowed for the browser WebAuthn ceremony + // (registration or login). Defaults to "60s". + Timeout string `yaml:"timeout" json:"timeout"` } // Web is the config format for the HTTP server. @@ -116,7 +180,6 @@ type Storage struct { Config map[string]interface{} `yaml:"config" json:"config"` } -// Password represents a static user configuration type Password storage.Password func (p *Password) UnmarshalYAML(node *yaml.Node) error { @@ -429,9 +492,98 @@ func (c *YAMLConfig) Validate() error { if !c.EnablePasswordDB && len(c.StaticPasswords) != 0 { return fmt.Errorf("cannot specify static passwords without enabling password db") } + return nil } +func buildTotpConfig(auth MFAAuthenticator) (*server.TOTPProvider, error) { + data, err := json.Marshal(auth.Config) + if err != nil { + return nil, fmt.Errorf("failed to marshal TOTP config id: %s - %w", auth.ID, err) + } + + var cfg TOTPConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse TOTP config id: %s - %w", auth.ID, err) + } + + return server.NewTOTPProvider(cfg.Issuer, auth.ConnectorTypes), nil +} + +func buildWebAuthnConfig(auth MFAAuthenticator, issuerURL string) (*server.WebAuthnProvider, error) { + data, err := json.Marshal(auth.Config) + if err != nil { + return nil, fmt.Errorf("failed to marshal WebAuthn config id: %s - %w", auth.ID, err) + } + + var cfg WebAuthnConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, 
fmt.Errorf("failed to parse WebAuthn config id: %s - %w", auth.ID, err) + } + + provider, err := server.NewWebAuthnProvider(cfg.RPDisplayName, cfg.RPID, cfg.RPOrigins, + cfg.AttestationPreference, cfg.Timeout, issuerURL, auth.ConnectorTypes) + if err != nil { + return nil, fmt.Errorf("failed to create WebAuthn provider id: %s - err: %w", auth.ID, err) + } + + return provider, nil +} + +func buildMFAProviders(authenticators []MFAAuthenticator, issuerURL string, logger *slog.Logger) map[string]server.MFAProvider { + if len(authenticators) == 0 { + return nil + } + + providers := make(map[string]server.MFAProvider, len(authenticators)) + for _, auth := range authenticators { + switch auth.Type { + case "TOTP": + provider, err := buildTotpConfig(auth) + if err != nil { + logger.Error("failed to parse TOTP config", "id", auth.ID, "err", err) + continue + } + providers[auth.ID] = provider + logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type) + case "WebAuthn": + provider, err := buildWebAuthnConfig(auth, issuerURL) + if err != nil { + logger.Error("failed to parse WebAuthn config", "id", auth.ID, "err", err) + continue + } + providers[auth.ID] = provider + logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type) + default: + logger.Error("unknown MFA authenticator type, skipping", "id", auth.ID, "type", auth.Type) + } + } + return providers +} + +func buildSessionsConfig(sessions *Sessions) *server.SessionConfig { + if sessions == nil { + return nil + } + + if sessions.RememberMeCheckedByDefault == nil { + defaultRememberMeCheckedByDefault := false + sessions.RememberMeCheckedByDefault = &defaultRememberMeCheckedByDefault + } + + absoluteLifetime, _ := parseDuration(sessions.AbsoluteLifetime) + validIfNotUsedFor, _ := parseDuration(sessions.ValidIfNotUsedFor) + + return &server.SessionConfig{ + CookieEncryptionKey: []byte(sessions.CookieEncryptionKey), + CookieName: sessions.CookieName, + AbsoluteLifetime: absoluteLifetime, 
+ ValidIfNotUsedFor: validIfNotUsedFor, + RememberMeCheckedByDefault: *sessions.RememberMeCheckedByDefault, + SSOSharedWithDefault: sessions.SSOSharedWithDefault, + } +} + // ToServerConfig converts YAMLConfig to dex server.Config func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) server.Config { cfg := server.Config{ @@ -448,6 +600,8 @@ func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) s Dir: c.Frontend.Dir, Extra: c.Frontend.Extra, }, + SessionConfig: buildSessionsConfig(c.Sessions), + MFAProviders: buildMFAProviders(c.MFA.Authenticators, c.Issuer, logger), } // Use embedded NetBird-styled templates if no custom dir specified @@ -460,11 +614,6 @@ func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) s } // Apply expiry settings - if c.Expiry.SigningKeys != "" { - if d, err := parseDuration(c.Expiry.SigningKeys); err == nil { - cfg.RotateKeysAfter = d - } - } if c.Expiry.IDTokens != "" { if d, err := parseDuration(c.Expiry.IDTokens); err == nil { cfg.IDTokensValidFor = d diff --git a/idp/dex/provider.go b/idp/dex/provider.go index 24aed1b99..526d6a17a 100644 --- a/idp/dex/provider.go +++ b/idp/dex/provider.go @@ -18,6 +18,7 @@ import ( dexapi "github.com/dexidp/dex/api/v2" "github.com/dexidp/dex/server" + "github.com/dexidp/dex/server/signer" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/sql" jose "github.com/go-jose/go-jose/v4" @@ -70,7 +71,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) { logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) // Ensure data directory exists - if err := os.MkdirAll(config.DataDir, 0700); err != nil { + if err := os.MkdirAll(config.DataDir, 0o700); err != nil { return nil, fmt.Errorf("failed to create data directory: %w", err) } @@ -101,6 +102,15 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) { return nil, fmt.Errorf("failed to create refresh token policy: %w", err) } + 
localSignerConfig := signer.LocalConfig{ + KeysRotationPeriod: "6h", + } + + localSigner, err := localSignerConfig.Open(ctx, stor, 24*time.Hour, time.Now, logger) + if err != nil { + return nil, fmt.Errorf("failed to create local signer: %w", err) + } + // Build Dex server config - use Dex's types directly dexConfig := server.Config{ Issuer: issuer, @@ -110,12 +120,12 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) { ContinueOnConnectorFailure: true, Logger: logger, PrometheusRegistry: prometheus.NewRegistry(), - RotateKeysAfter: 6 * time.Hour, IDTokensValidFor: 24 * time.Hour, RefreshTokenPolicy: refreshPolicy, Web: server.WebConfig{ Issuer: "NetBird", }, + Signer: localSigner, } dexSrv, err := server.NewServer(ctx, dexConfig) @@ -167,6 +177,14 @@ func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider return nil, fmt.Errorf("failed to create refresh token policy: %w", err) } + localSigner, err := getSigner(ctx, stor, yamlConfig, logger) + if err != nil { + stor.Close() + return nil, fmt.Errorf("failed to create local signer: %w", err) + } + + dexConfig.Signer = localSigner + dexSrv, err := server.NewServer(ctx, dexConfig) if err != nil { stor.Close() @@ -182,6 +200,32 @@ func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider }, nil } +func getSigner(ctx context.Context, stor storage.Storage, yamlConfig *YAMLConfig, logger *slog.Logger) (signer.Signer, error) { + // Parse expiry durations + idTokensValidFor := 24 * time.Hour // default + if yamlConfig.Expiry.IDTokens != "" { + var err error + idTokensValidFor, err = parseDuration(yamlConfig.Expiry.IDTokens) + if err != nil { + return nil, fmt.Errorf("invalid config value %q for id token expiry: %v", yamlConfig.Expiry.IDTokens, err) + } + } + + localSignerConfig := &signer.LocalConfig{ + KeysRotationPeriod: "720h", // 30 Days + } + + if yamlConfig.Expiry.SigningKeys != "" { + if _, err := parseDuration(yamlConfig.Expiry.SigningKeys); err 
!= nil { + return nil, fmt.Errorf("invalid config value %q for signing key expiry: %v", yamlConfig.Expiry.SigningKeys, err) + } + + localSignerConfig.KeysRotationPeriod = yamlConfig.Expiry.SigningKeys + } + + return localSignerConfig.Open(ctx, stor, idTokensValidFor, time.Now, logger) +} + // initializeStorage sets up connectors, passwords, and clients in storage func initializeStorage(ctx context.Context, stor storage.Storage, cfg *YAMLConfig) error { if cfg.EnablePasswordDB { @@ -241,6 +285,8 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st old.RedirectURIs = client.RedirectURIs old.Name = client.Name old.Public = client.Public + old.PostLogoutRedirectURIs = client.PostLogoutRedirectURIs + old.MFAChain = client.MFAChain return old, nil }); err != nil { return fmt.Errorf("failed to update client %s: %w", client.ID, err) @@ -253,9 +299,6 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config { cfg := yamlConfig.ToServerConfig(stor, logger) cfg.PrometheusRegistry = prometheus.NewRegistry() - if cfg.RotateKeysAfter == 0 { - cfg.RotateKeysAfter = 24 * 30 * time.Hour - } if cfg.IDTokensValidFor == 0 { cfg.IDTokensValidFor = 24 * time.Hour } @@ -450,10 +493,34 @@ func (p *Provider) Storage() storage.Storage { return p.storage } +// SetClientsMFAChain updates the MFAChain field on the dashboard and CLI OAuth2 clients. +// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it. 
+func (p *Provider) SetClientsMFAChain(ctx context.Context, clientIDs []string, mfaChain []string) error { + for _, clientID := range clientIDs { + if err := p.storage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) { + old.MFAChain = mfaChain + return old, nil + }); err != nil { + return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err) + } + } + return nil +} + // Handler returns the Dex server as an http.Handler for embedding in another server. // The handler expects requests with path prefix "/oauth2/". func (p *Provider) Handler() http.Handler { - return p.dexServer + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Dex's /logout endpoint requires id_token_hint for RP-initiated logout with + // post_logout_redirect_uri. If the dashboard calls logout without one, avoid + // rendering Dex's non-actionable Bad Request page and send the user home. + if strings.HasSuffix(r.URL.Path, "/logout") && r.FormValue("id_token_hint") == "" { + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + + p.dexServer.ServeHTTP(w, r) + }) } // CreateUser creates a new user with the given email, username, and password. 
diff --git a/idp/dex/provider_test.go b/idp/dex/provider_test.go index 4ed89fd2e..88828fbbb 100644 --- a/idp/dex/provider_test.go +++ b/idp/dex/provider_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "log/slog" + "net/http" + "net/http/httptest" "os" "path/filepath" "testing" @@ -144,6 +146,30 @@ func TestEncodeDexUserID_MatchesDexFormat(t *testing.T) { assert.Equal(t, knownEncodedID, reEncoded) } +func TestHandlerRedirectsLogoutWithoutIDTokenHint(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "dex-logout-handler-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + provider, err := NewProvider(ctx, &Config{ + Issuer: "http://localhost:5556/oauth2", + Port: 5556, + DataDir: tmpDir, + }) + require.NoError(t, err) + defer func() { _ = provider.Stop(ctx) }() + + req := httptest.NewRequest(http.MethodGet, "/oauth2/logout?post_logout_redirect_uri=https://example.com", nil) + rec := httptest.NewRecorder() + + provider.Handler().ServeHTTP(rec, req) + + require.Equal(t, http.StatusSeeOther, rec.Code) + require.Equal(t, "/", rec.Header().Get("Location")) +} + func TestCreateUserInTempDB(t *testing.T) { ctx := context.Background() diff --git a/idp/dex/web/templates/home.html b/idp/dex/web/templates/home.html new file mode 100644 index 000000000..be7c938ae --- /dev/null +++ b/idp/dex/web/templates/home.html @@ -0,0 +1,12 @@ +{{ template "header.html" . }} + + + + +{{ template "footer.html" . }} diff --git a/idp/dex/web/templates/logout.html b/idp/dex/web/templates/logout.html new file mode 100644 index 000000000..b623d35af --- /dev/null +++ b/idp/dex/web/templates/logout.html @@ -0,0 +1,14 @@ +{{ template "header.html" . }} + +
+

Logged Out

+

You have been successfully logged out.

+ + {{ if .BackURL }} + + {{ end }} +
+ +{{ template "footer.html" . }} diff --git a/idp/dex/web/templates/password.html b/idp/dex/web/templates/password.html index 1d1b8282e..e1bfa7258 100755 --- a/idp/dex/web/templates/password.html +++ b/idp/dex/web/templates/password.html @@ -18,6 +18,7 @@ id="login" name="login" class="nb-input" + autocomplete="username" placeholder="Enter your {{ .UsernamePrompt | lower }}" {{ if .Username }}value="{{ .Username }}"{{ else }}autofocus{{ end }} required @@ -31,6 +32,7 @@ id="password" name="password" class="nb-input" + autocomplete="current-password" placeholder="Enter your password" {{ if .Invalid }}autofocus{{ end }} required diff --git a/idp/dex/web/templates/totp_verify.html b/idp/dex/web/templates/totp_verify.html new file mode 100644 index 000000000..8286418f0 --- /dev/null +++ b/idp/dex/web/templates/totp_verify.html @@ -0,0 +1,44 @@ +{{ template "header.html" . }} + +
+

Two-factor authentication

+ {{ if not (eq .QRCode "") }} +

Scan the QR code below using your authenticator app, then enter the code.

+
+ QR code +
+ {{ else }} +

Enter the code from your authenticator app.

+ {{ end }} + +
+ {{ if .Invalid }} +
+ Invalid code. Please try again. +
+ {{ end }} + +
+ + +
+ + +
+
+ +{{ template "footer.html" . }} diff --git a/idp/dex/web/templates/webauthn_verify.html b/idp/dex/web/templates/webauthn_verify.html new file mode 100644 index 000000000..be7c938ae --- /dev/null +++ b/idp/dex/web/templates/webauthn_verify.html @@ -0,0 +1,12 @@ +{{ template "header.html" . }} + + + + +{{ template "footer.html" . }} diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 9b2ec2989..ea94245d5 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -26,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -113,30 +114,47 @@ func (s *BaseServer) AccountManager() account.Manager { }) } +func isMFAEnabledForAccount(accounts []*types.Account) bool { + if len(accounts) != 1 { + return false + } + + settings := accounts[0].Settings + return settings != nil && settings.LocalMfaEnabled +} + func (s *BaseServer) IdpManager() idp.Manager { return Create(s, func() idp.Manager { - var idpManager idp.Manager - var err error - // Use embedded IdP service if embedded Dex is configured and enabled. // Legacy IdpManager won't be used anymore even if configured. 
embeddedEnabled := s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled if embeddedEnabled { - idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics()) + embeddedMgr, err := idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics()) if err != nil { log.Fatalf("failed to create embedded IDP service: %v", err) } - return idpManager + + if val := isMFAEnabledForAccount(s.Store().GetAllAccounts(context.Background())); val { + if err := embeddedMgr.SetMFAEnabled(context.Background(), val); err != nil { + log.Errorf("failed to set MFA enabled on embedded IDP: %v", err) + } + } + + return embeddedMgr } // Fall back to external IdP service if s.Config.IdpManagerConfig != nil { - idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics()) + idpManager, err := idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics()) if err != nil { log.Fatalf("failed to create IDP service: %v", err) } + + return idpManager } - return idpManager + + + return nil }) } diff --git a/management/server/account.go b/management/server/account.go index 45b99839f..364c0c37b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -386,6 +386,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil { return nil, err } + if err = am.handleLocalMfaSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil { + return nil, err + } if oldSettings.DNSDomain != newSettings.DNSDomain { eventMeta := map[string]any{ "old_dns_domain": oldSettings.DNSDomain, @@ -602,6 +605,29 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context. 
return nil } +func (am *DefaultAccountManager) handleLocalMfaSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { + if oldSettings.LocalMfaEnabled == newSettings.LocalMfaEnabled { + return nil + } + + embeddedIdp, ok := am.idpManager.(*idp.EmbeddedIdPManager) + if !ok { + return nil + } + + if err := embeddedIdp.SetMFAEnabled(ctx, newSettings.LocalMfaEnabled); err != nil { + return fmt.Errorf("failed to toggle MFA: %w", err) + } + + if newSettings.LocalMfaEnabled { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLocalMfaEnabled, nil) + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLocalMfaDisabled, nil) + } + + return nil +} + func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { //nolint diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 2388115ff..6c781a952 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -236,6 +236,11 @@ const ( // AccountIPv6Disabled indicates that a user disabled IPv6 overlay for the account AccountIPv6Disabled Activity = 122 + // AccountLocalMfaEnabled indicates that a user enabled TOTP MFA for local users + AccountLocalMfaEnabled Activity = 123 + // AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users + AccountLocalMfaDisabled Activity = 124 + AccountDeleted Activity = 99999 ) @@ -386,6 +391,9 @@ var activityMap = map[Activity]Code{ AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"}, AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"}, + AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"}, + AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"}, + 
DomainAdded: {"Domain added", "domain.add"}, DomainDeleted: {"Domain deleted", "domain.delete"}, DomainValidated: {"Domain validated", "domain.validate"}, diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 31820b9fb..209d593bd 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -277,6 +277,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS if req.Settings.AutoUpdateAlways != nil { returnSettings.AutoUpdateAlways = *req.Settings.AutoUpdateAlways } + if req.Settings.LocalMfaEnabled != nil { + returnSettings.LocalMfaEnabled = *req.Settings.LocalMfaEnabled + } if req.Settings.Ipv6EnabledGroups != nil { returnSettings.IPv6EnabledGroups = *req.Settings.Ipv6EnabledGroups } @@ -412,6 +415,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A Ipv6EnabledGroups: &settings.IPv6EnabledGroups, EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled, LocalAuthDisabled: &settings.LocalAuthDisabled, + LocalMfaEnabled: &settings.LocalMfaEnabled, } if settings.NetworkRange.IsValid() { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index fc1517a30..8db76719c 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -131,6 +131,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), + LocalMfaEnabled: br(false), }, expectedArray: true, expectedID: accountID, @@ -157,6 +158,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), + LocalMfaEnabled: br(false), }, expectedArray: false, 
expectedID: accountID, @@ -183,6 +185,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { AutoUpdateVersion: sr("latest"), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), + LocalMfaEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -209,6 +212,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), + LocalMfaEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -235,6 +239,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), + LocalMfaEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -261,6 +266,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), LocalAuthDisabled: br(false), + LocalMfaEnabled: br(false), }, expectedArray: false, expectedID: accountID, diff --git a/management/server/idp/embedded.go b/management/server/idp/embedded.go index 48d3221cc..a1852a8bc 100644 --- a/management/server/idp/embedded.go +++ b/management/server/idp/embedded.go @@ -2,9 +2,11 @@ package idp import ( "context" + "encoding/base64" "errors" "fmt" "net/http" + "os" "strings" "github.com/dexidp/dex/storage" @@ -17,12 +19,13 @@ import ( ) const ( - staticClientDashboard = "netbird-dashboard" - staticClientCLI = "netbird-cli" - defaultCLIRedirectURL1 = "http://localhost:53000/" - defaultCLIRedirectURL2 = "http://localhost:54000/" - defaultScopes = "openid profile email groups" - defaultUserIDClaim = "sub" + staticClientDashboard = "netbird-dashboard" + staticClientCLI = "netbird-cli" + defaultCLIRedirectURL1 = "http://localhost:53000/" + defaultCLIRedirectURL2 = "http://localhost:54000/" + defaultScopes = "openid profile email groups" + defaultUserIDClaim = "sub" + sessionCookieEncryptionKeyEnv = "NB_IDP_SESSION_COOKIE_ENCRYPTION_KEY" ) // EmbeddedIdPConfig contains configuration 
for the embedded Dex OIDC identity provider @@ -49,6 +52,26 @@ type EmbeddedIdPConfig struct { // Existing local users are preserved and will be able to login again if re-enabled. // Cannot be enabled if no external identity provider connectors are configured. LocalAuthDisabled bool + // MfaSessionMaxLifetime is the maximum MFA session duration from creation (e.g. "24h"). + // Defaults to "24h" if empty. + MfaSessionMaxLifetime string + // MfaSessionIdleTimeout is the idle timeout after which the MFA session expires (e.g. "1h"). + // Defaults to "1h" if empty. + MfaSessionIdleTimeout string + // MfaSessionRememberMe controls the default state of the "remember me" checkbox on the + // login screen. When true, the session cookie persists across browser tabs/restarts so + // MFA is not re-prompted until the session expires. Defaults to false. + MfaSessionRememberMe bool + // SessionCookieEncryptionKey is the optional AES key used to encrypt embedded IdP session cookies. + // It can also be set with NB_IDP_SESSION_COOKIE_ENCRYPTION_KEY. The value must be 16, 24, or 32 + // bytes when provided as a raw string, or base64-encoded to one of those lengths. 
+ SessionCookieEncryptionKey string + // Dashboard Post logout redirect URIs, these are required to tell + // Dex what to allow when an RP-Initiated logout is started by the frontend + // at least one of these must match the dashboard base URL or the dashboard + // DASHBOARD_POST_LOGOUT_URL environment variable + // WARNING: Dex only uses exact match, not wildcards + DashboardPostLogoutRedirectURIs []string // StaticConnectors are additional connectors to seed during initialization StaticConnectors []dex.Connector } @@ -126,6 +149,11 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { // todo: resolve import cycle dashboardRedirectURIs = append(dashboardRedirectURIs, baseURL+"/api/reverse-proxy/callback") + dashboardPostLogoutRedirectURIs := c.DashboardPostLogoutRedirectURIs + // It is safe to assume that most installations will share the location of the + // MGMT api and the dashboard, adding baseURL means less configuration for the instance admin + dashboardPostLogoutRedirectURIs = append(dashboardPostLogoutRedirectURIs, baseURL) + cfg := &dex.YAMLConfig{ Issuer: c.Issuer, Storage: dex.Storage{ @@ -148,10 +176,11 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { EnablePasswordDB: true, StaticClients: []storage.Client{ { - ID: staticClientDashboard, - Name: "NetBird Dashboard", - Public: true, - RedirectURIs: dashboardRedirectURIs, + ID: staticClientDashboard, + Name: "NetBird Dashboard", + Public: true, + RedirectURIs: dashboardRedirectURIs, + PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs), }, { ID: staticClientCLI, @@ -163,6 +192,12 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { StaticConnectors: c.StaticConnectors, } + // Always initialize MFA providers and sessions so TOTP can be toggled at runtime. + // MFAChain on clients is NOT set here — it's synced from the DB setting on startup. 
+ if err := configureMFA(cfg, c.MfaSessionMaxLifetime, c.MfaSessionIdleTimeout, c.MfaSessionRememberMe, c.SessionCookieEncryptionKey); err != nil { + return nil, err + } + // Add owner user if provided if c.Owner != nil && c.Owner.Email != "" && c.Owner.Hash != "" { username := c.Owner.Username @@ -182,6 +217,100 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { return cfg, nil } +// Due to how the frontend generates the logout, sometimes it appends a trailing slash +// and because Dex only allows exact matches, we need to make sure we always have both +// versions of each provided uri +func sanitizePostLogoutRedirectURIs(uris []string) []string { + result := make([]string, 0) + for _, uri := range uris { + if strings.HasSuffix(uri, "/") { + result = append(result, uri) + result = append(result, strings.TrimSuffix(uri, "/")) + } else { + result = append(result, uri) + result = append(result, uri+"/") + } + } + + return result +} + +func configureMFA(cfg *dex.YAMLConfig, sessionMaxLifetime, sessionIdleTimeout string, rememberMe bool, sessionCookieEncryptionKey string) error { + cfg.MFA.Authenticators = []dex.MFAAuthenticator{{ + ID: "default-totp", + // Has to be caps otherwise it will fail + Type: "TOTP", + Config: map[string]interface{}{ + "issuer": "NetBird", + }, + ConnectorTypes: []string{"local"}, + }} + + if sessionMaxLifetime == "" { + sessionMaxLifetime = "24h" + } + if sessionIdleTimeout == "" { + sessionIdleTimeout = "1h" + } + + cookieEncryptionKey, err := resolveSessionCookieEncryptionKey(sessionCookieEncryptionKey) + if err != nil { + return err + } + + cfg.Sessions = &dex.Sessions{ + CookieName: "netbird-session", + AbsoluteLifetime: sessionMaxLifetime, + ValidIfNotUsedFor: sessionIdleTimeout, + RememberMeCheckedByDefault: &rememberMe, + SSOSharedWithDefault: "all", + CookieEncryptionKey: cookieEncryptionKey, + } + // Absolutely required, otherwise the dex server will omit the MFA configuration entirely + 
os.Setenv("DEX_SESSIONS_ENABLED", "true") + + // Note: MFAChain on clients is NOT set here. + // It is toggled at runtime via SetMFAEnabled() based on the account settings DB value. + return nil +} + +func resolveSessionCookieEncryptionKey(configuredKey string) (string, error) { + key := strings.TrimSpace(configuredKey) + if key == "" { + key = strings.TrimSpace(os.Getenv(sessionCookieEncryptionKeyEnv)) + } + if key == "" { + return "", nil + } + + if validSessionCookieEncryptionKeyLength(len([]byte(key))) { + return key, nil + } + + for _, encoding := range []*base64.Encoding{ + base64.StdEncoding, + base64.RawStdEncoding, + base64.URLEncoding, + base64.RawURLEncoding, + } { + decoded, err := encoding.DecodeString(key) + if err == nil && validSessionCookieEncryptionKeyLength(len(decoded)) { + return string(decoded), nil + } + } + + return "", fmt.Errorf("invalid embedded IdP session cookie encryption key: %s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key))) +} + +func validSessionCookieEncryptionKeyLength(length int) bool { + switch length { + case 16, 24, 32: + return true + default: + return false + } +} + // Compile-time check that EmbeddedIdPManager implements Manager interface var _ Manager = (*EmbeddedIdPManager)(nil) @@ -215,6 +344,7 @@ type EmbeddedIdPManager struct { provider *dex.Provider appMetrics telemetry.AppMetrics config EmbeddedIdPConfig + mfaEnabled bool } // NewEmbeddedIdPManager creates a new instance of EmbeddedIdPManager from a configuration. @@ -641,6 +771,27 @@ func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool { return m.config.LocalAuthDisabled } +// SetMFAEnabled enables or disables TOTP MFA for local users by updating the MFAChain on OAuth2 clients. 
+func (m *EmbeddedIdPManager) SetMFAEnabled(ctx context.Context, enabled bool) error { + var mfaChain []string + if enabled { + mfaChain = []string{"default-totp"} + } + if err := m.provider.SetClientsMFAChain(ctx, []string{ + staticClientCLI, + staticClientDashboard, + }, mfaChain); err != nil { + return fmt.Errorf("failed to set MFA enabled=%v: %w", enabled, err) + } + m.mfaEnabled = enabled + return nil +} + +// IsMFAEnabled returns whether TOTP MFA is currently enabled for local users. +func (m *EmbeddedIdPManager) IsMFAEnabled() bool { + return m.mfaEnabled +} + // HasNonLocalConnectors checks if there are any identity provider connectors other than local. func (m *EmbeddedIdPManager) HasNonLocalConnectors(ctx context.Context) (bool, error) { return m.provider.HasNonLocalConnectors(ctx) diff --git a/management/server/idp/embedded_test.go b/management/server/idp/embedded_test.go index 4dda483fb..09dc67614 100644 --- a/management/server/idp/embedded_test.go +++ b/management/server/idp/embedded_test.go @@ -2,6 +2,7 @@ package idp import ( "context" + "encoding/base64" "os" "path/filepath" "testing" @@ -313,6 +314,72 @@ func TestEmbeddedIdPManager_UpdateUserPassword(t *testing.T) { }) } +func TestEmbeddedIdPConfig_ToYAMLConfig_SessionCookieEncryptionKey(t *testing.T) { + t.Setenv(sessionCookieEncryptionKeyEnv, "") + + rawKey := "0123456789abcdef0123456789abcdef" + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + SessionCookieEncryptionKey: base64.StdEncoding.EncodeToString([]byte(rawKey)), + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(t.TempDir(), "dex.db"), + }, + }, + } + + yamlConfig, err := config.ToYAMLConfig() + require.NoError(t, err) + require.NotNil(t, yamlConfig.Sessions) + assert.Equal(t, rawKey, yamlConfig.Sessions.CookieEncryptionKey) +} + +func TestResolveSessionCookieEncryptionKey(t *testing.T) { + rawKey := "0123456789abcdef0123456789abcdef" + + 
t.Run("uses raw configured key", func(t *testing.T) { + t.Setenv(sessionCookieEncryptionKeyEnv, "") + + key, err := resolveSessionCookieEncryptionKey(rawKey) + require.NoError(t, err) + assert.Equal(t, rawKey, key) + }) + + t.Run("uses base64 configured key", func(t *testing.T) { + t.Setenv(sessionCookieEncryptionKeyEnv, "") + + key, err := resolveSessionCookieEncryptionKey(base64.StdEncoding.EncodeToString([]byte(rawKey))) + require.NoError(t, err) + assert.Equal(t, rawKey, key) + }) + + t.Run("falls back to env var", func(t *testing.T) { + t.Setenv(sessionCookieEncryptionKeyEnv, rawKey) + + key, err := resolveSessionCookieEncryptionKey("") + require.NoError(t, err) + assert.Equal(t, rawKey, key) + }) + + t.Run("empty key disables encryption", func(t *testing.T) { + t.Setenv(sessionCookieEncryptionKeyEnv, "") + + key, err := resolveSessionCookieEncryptionKey("") + require.NoError(t, err) + assert.Empty(t, key) + }) + + t.Run("rejects invalid key length", func(t *testing.T) { + t.Setenv(sessionCookieEncryptionKeyEnv, "") + + _, err := resolveSessionCookieEncryptionKey("32") + require.Error(t, err) + assert.Contains(t, err.Error(), sessionCookieEncryptionKeyEnv) + }) +} + func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) { ctx := context.Background() diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 973101ce3..065a0d306 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -236,7 +236,6 @@ func (s *SqlStore) GetPeerJobs(ctx context.Context, accountID, peerID string) ([ Where(accountAndPeerIDQueryCondition, accountID, peerID). Order("created_at DESC"). Find(&jobs).Error - if err != nil { log.WithContext(ctx).Errorf("failed to fetch jobs from store: %s", err) return nil, err @@ -463,7 +462,6 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. 
return nil }) - if err != nil { return err } @@ -1514,6 +1512,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, settings_network_range_v6, settings_ipv6_enabled_groups, settings_lazy_connection_enabled, + settings_local_mfa_enabled, -- Embedded ExtraSettings settings_extra_peer_approval_enabled, settings_extra_user_approval_required, settings_extra_integrated_validator, settings_extra_integrated_validator_groups @@ -1535,6 +1534,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc sNetworkRangeV6 sql.NullString sIPv6EnabledGroups sql.NullString sLazyConnectionEnabled sql.NullBool + sLocalMFAEnabled sql.NullBool sExtraPeerApprovalEnabled sql.NullBool sExtraUserApprovalRequired sql.NullBool sExtraIntegratedValidator sql.NullString @@ -1557,6 +1557,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, &sNetworkRangeV6, &sIPv6EnabledGroups, &sLazyConnectionEnabled, + &sLocalMFAEnabled, &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, ) @@ -1619,6 +1620,9 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc if sLazyConnectionEnabled.Valid { account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool } + if sLocalMFAEnabled.Valid { + account.Settings.LocalMfaEnabled = sLocalMFAEnabled.Bool + } if sJWTAllowGroups.Valid { _ = json.Unmarshal([]byte(sJWTAllowGroups.String), &account.Settings.JWTAllowGroups) } @@ -3061,7 +3065,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer GroupID: groupID, PeerID: peerID, }).Error - if err != nil { 
return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err) } @@ -3081,7 +3084,6 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}}, DoNothing: true, }).Create(peer).Error - if err != nil { log.WithContext(ctx).Errorf("failed to add peer %s to group %s for account %s: %v", peerID, groupID, accountID, err) return status.Errorf(status.Internal, "failed to add peer to group") @@ -3094,7 +3096,6 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error { err := s.db. Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error - if err != nil { log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err) return status.Errorf(status.Internal, "failed to remove peer from group") @@ -3107,7 +3108,6 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error { err := s.db. Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error - if err != nil { log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err) return status.Errorf(status.Internal, "failed to remove peer from all groups") @@ -4964,7 +4964,6 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service return nil }) - if err != nil { log.WithContext(ctx).Errorf("failed to update service to store: %v", err) return status.Errorf(status.Internal, "failed to update service to store") @@ -5620,7 +5619,6 @@ func (s *SqlStore) getClusterUnanimousCapability(ctx context.Context, clusterAdd Where("cluster_address = ? AND status = ? AND last_seen > ?", clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)). 
Scan(&result).Error - if err != nil { log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err) return nil @@ -5662,7 +5660,6 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column Where("cluster_address = ? AND status = ? AND last_seen > ?", clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)). Scan(&result).Error - if err != nil { log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err) return nil diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 264a018d4..97ffa5e76 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -80,6 +80,10 @@ type Settings struct { // LocalAuthDisabled indicates if local (email/password) authentication is disabled. // This is a runtime-only field, not stored in the database. LocalAuthDisabled bool `gorm:"-"` + + // LocalMfaEnabled indicates if TOTP MFA is enabled for local users. + // Only applicable when the embedded IDP is enabled. + LocalMfaEnabled bool } // Copy copies the Settings struct @@ -108,6 +112,7 @@ func (s *Settings) Copy() *Settings { IPv6EnabledGroups: slices.Clone(s.IPv6EnabledGroups), EmbeddedIdpEnabled: s.EmbeddedIdpEnabled, LocalAuthDisabled: s.LocalAuthDisabled, + LocalMfaEnabled: s.LocalMfaEnabled, } if s.Extra != nil { settings.Extra = s.Extra.Copy() diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 8e6ee54cc..82fca0782 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -381,6 +381,10 @@ components: type: boolean readOnly: true example: false + local_mfa_enabled: + description: Enables or disables TOTP multi-factor authentication for local users. Only applicable when the embedded identity provider is enabled. 
+ type: boolean + example: false ipv6_enabled_groups: description: List of group IDs whose peers receive IPv6 overlay addresses. Peers not in any of these groups will not be allocated an IPv6 address. New accounts default to the All group. type: array diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index f8ea07be7..4b94ea01c 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1486,6 +1486,9 @@ type AccountSettings struct { // LocalAuthDisabled Indicates whether local (email/password) authentication is disabled. When true, users can only authenticate via external identity providers. This is a read-only field. LocalAuthDisabled *bool `json:"local_auth_disabled,omitempty"` + // LocalMfaEnabled Enables or disables TOTP multi-factor authentication for local users. Only applicable when the embedded identity provider is enabled. + LocalMfaEnabled *bool `json:"local_mfa_enabled,omitempty"` + // NetworkRange Allows to define a custom network range for the account in CIDR format NetworkRange *string `json:"network_range,omitempty"` From afb83b3049c3c0b5b578f051e277562c6b8e513d Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 11 May 2026 16:58:49 +0900 Subject: [PATCH 12/27] [client] Use unique temp file and clean up on failure when writing ssh config (#6064) --- client/ssh/config/manager.go | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index 01822ead6..b58bf2233 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -229,18 +229,31 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string { func (m *Manager) writeSSHConfig(sshConfig string) error { sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) - sshConfigPathTmp := sshConfigPath + ".tmp" if err := os.MkdirAll(m.sshConfigDir, 
0755); err != nil { return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err) } - if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil { - return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err) + tmp, err := os.CreateTemp(m.sshConfigDir, m.sshConfigFile+".*.tmp") + if err != nil { + return fmt.Errorf("create temp SSH config: %w", err) + } + tmpPath := tmp.Name() + defer func() { + if err := os.Remove(tmpPath); err != nil && !os.IsNotExist(err) { + log.Debugf("remove temp SSH config %s: %v", tmpPath, err) + } + }() + if err := tmp.Close(); err != nil { + return fmt.Errorf("close temp SSH config %s: %w", tmpPath, err) } - if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil { - return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err) + if err := writeFileWithTimeout(tmpPath, []byte(sshConfig), 0644); err != nil { + return fmt.Errorf("write SSH config file %s: %w", tmpPath, err) + } + + if err := os.Rename(tmpPath, sshConfigPath); err != nil { + return fmt.Errorf("rename SSH config %s -> %s: %w", tmpPath, sshConfigPath, err) } log.Infof("Created NetBird SSH client config: %s", sshConfigPath) From a852b3bd34f82c3d548874db4eb2b1b186320225 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 11 May 2026 16:59:13 +0900 Subject: [PATCH 13/27] [client, proxy] Harden uspfilter conntrack and share TCP relay (#5936) --- .../firewall/uspfilter/conntrack/cap_test.go | 125 ++++++ client/firewall/uspfilter/conntrack/common.go | 46 ++ .../uspfilter/conntrack/defaults_desktop.go | 11 + .../uspfilter/conntrack/defaults_mobile.go | 13 + client/firewall/uspfilter/conntrack/icmp.go | 50 ++- client/firewall/uspfilter/conntrack/tcp.go | 398 +++++++++++++----- .../uspfilter/conntrack/tcp_rst_bugs_test.go | 100 +++++ .../conntrack/tcp_state_bugs_test.go | 235 +++++++++++ client/firewall/uspfilter/conntrack/udp.go | 52 ++- 
client/firewall/uspfilter/filter.go | 42 +- client/firewall/uspfilter/forwarder/icmp.go | 31 +- client/firewall/uspfilter/forwarder/tcp.go | 75 +--- client/firewall/uspfilter/forwarder/udp.go | 16 +- client/firewall/uspfilter/log/log.go | 137 +++--- client/firewall/uspfilter/nat.go | 21 +- client/ssh/client/client.go | 26 +- client/ssh/common.go | 60 --- client/ssh/proxy/proxy.go | 5 +- client/ssh/server/port_forwarding.go | 4 +- client/ssh/server/server.go | 33 +- proxy/internal/tcp/peekedconn.go | 6 + proxy/internal/tcp/relay.go | 156 ------- proxy/internal/tcp/relay_test.go | 15 +- proxy/internal/tcp/router.go | 8 +- util/netrelay/relay.go | 238 +++++++++++ util/netrelay/relay_test.go | 221 ++++++++++ 26 files changed, 1629 insertions(+), 495 deletions(-) create mode 100644 client/firewall/uspfilter/conntrack/cap_test.go create mode 100644 client/firewall/uspfilter/conntrack/defaults_desktop.go create mode 100644 client/firewall/uspfilter/conntrack/defaults_mobile.go create mode 100644 client/firewall/uspfilter/conntrack/tcp_rst_bugs_test.go create mode 100644 client/firewall/uspfilter/conntrack/tcp_state_bugs_test.go delete mode 100644 proxy/internal/tcp/relay.go create mode 100644 util/netrelay/relay.go create mode 100644 util/netrelay/relay_test.go diff --git a/client/firewall/uspfilter/conntrack/cap_test.go b/client/firewall/uspfilter/conntrack/cap_test.go new file mode 100644 index 000000000..ee6b72e7f --- /dev/null +++ b/client/firewall/uspfilter/conntrack/cap_test.go @@ -0,0 +1,125 @@ +package conntrack + +import ( + "net/netip" + "testing" + + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" +) + +func TestTCPCapEvicts(t *testing.T) { + t.Setenv(EnvTCPMaxEntries, "4") + + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + require.Equal(t, 4, tracker.maxEntries) + + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + for i := 0; i < 10; i++ { + 
tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0) + } + require.LessOrEqual(t, len(tracker.connections), 4, + "TCP table must not exceed the configured cap") + require.Greater(t, len(tracker.connections), 0, + "some entries must remain after eviction") + + // The most recently admitted flow must be present: eviction must make + // room for new entries, not silently drop them. + require.Contains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80}, + "newest TCP flow must be admitted after eviction") + // A pre-cap flow must have been evicted to fit the last one. + require.NotContains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80}, + "oldest TCP flow should have been evicted") +} + +func TestTCPCapPrefersTombstonedForEviction(t *testing.T) { + t.Setenv(EnvTCPMaxEntries, "3") + + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + // Fill to cap with 3 live connections. + for i := 0; i < 3; i++ { + tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0) + } + require.Len(t, tracker.connections, 3) + + // Tombstone one by sending RST through IsValidInbound. + tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80} + require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0)) + require.True(t, tracker.connections[tombstonedKey].IsTombstone()) + + // Another live connection forces eviction. The tombstone must go first. 
+ tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0) + + _, tombstonedStillPresent := tracker.connections[tombstonedKey] + require.False(t, tombstonedStillPresent, + "tombstoned entry should be evicted before live entries") + require.LessOrEqual(t, len(tracker.connections), 3) + + // Both live pre-cap entries must survive: eviction must prefer the + // tombstone, not just satisfy the size bound by dropping any entry. + require.Contains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20000), DstPort: 80}, + "live entries must not be evicted while a tombstone exists") + require.Contains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20002), DstPort: 80}, + "live entries must not be evicted while a tombstone exists") +} + +func TestUDPCapEvicts(t *testing.T) { + t.Setenv(EnvUDPMaxEntries, "5") + + tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) + defer tracker.Close() + require.Equal(t, 5, tracker.maxEntries) + + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + for i := 0; i < 12; i++ { + tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0) + } + require.LessOrEqual(t, len(tracker.connections), 5) + require.Greater(t, len(tracker.connections), 0) + + require.Contains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53}, + "newest UDP flow must be admitted after eviction") + require.NotContains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53}, + "oldest UDP flow should have been evicted") +} + +func TestICMPCapEvicts(t *testing.T) { + t.Setenv(EnvICMPMaxEntries, "3") + + tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) + defer tracker.Close() + require.Equal(t, 3, tracker.maxEntries) + + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + echoReq := 
layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0) + for i := 0; i < 8; i++ { + tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64) + } + require.LessOrEqual(t, len(tracker.connections), 3) + require.Greater(t, len(tracker.connections), 0) + + require.Contains(t, tracker.connections, + ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)}, + "newest ICMP flow must be admitted after eviction") + require.NotContains(t, tracker.connections, + ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)}, + "oldest ICMP flow should have been evicted") +} diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 88e90317c..e497a0bff 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -3,15 +3,61 @@ package conntrack import ( "net" "net/netip" + "os" "strconv" "sync/atomic" "time" "github.com/google/uuid" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) +// evictSampleSize bounds how many map entries we scan per eviction call. +// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU +// heuristic is good enough for a conntrack table that only overflows under +// abuse. +const evictSampleSize = 8 + +// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to +// def on empty or invalid; logs a warning on invalid. +func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration { + v := os.Getenv(name) + if v == "" { + return def + } + d, err := time.ParseDuration(v) + if err != nil { + logger.Warn3("invalid %s=%q: %v, using default", name, v, err) + return def + } + if d <= 0 { + logger.Warn2("invalid %s=%q: must be positive, using default", name, v) + return def + } + return d +} + +// envInt parses an os.Getenv(name) as an int. Falls back to def on empty, +// invalid, or non-positive. 
Logs a warning on invalid input. +func envInt(logger *nblog.Logger, name string, def int) int { + v := os.Getenv(name) + if v == "" { + return def + } + n, err := strconv.Atoi(v) + switch { + case err != nil: + logger.Warn3("invalid %s=%q: %v, using default", name, v, err) + return def + case n <= 0: + logger.Warn2("invalid %s=%q: must be positive, using default", name, v) + return def + } + return n +} + // BaseConnTrack provides common fields and locking for all connection types type BaseConnTrack struct { FlowId uuid.UUID diff --git a/client/firewall/uspfilter/conntrack/defaults_desktop.go b/client/firewall/uspfilter/conntrack/defaults_desktop.go new file mode 100644 index 000000000..2f07f5984 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/defaults_desktop.go @@ -0,0 +1,11 @@ +//go:build !ios && !android + +package conntrack + +// Default per-tracker entry caps on desktop/server platforms. These mirror +// typical Linux netfilter nf_conntrack_max territory with ample headroom. +const ( + DefaultMaxTCPEntries = 65536 + DefaultMaxUDPEntries = 16384 + DefaultMaxICMPEntries = 2048 +) diff --git a/client/firewall/uspfilter/conntrack/defaults_mobile.go b/client/firewall/uspfilter/conntrack/defaults_mobile.go new file mode 100644 index 000000000..c9e05d229 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/defaults_mobile.go @@ -0,0 +1,13 @@ +//go:build ios || android + +package conntrack + +// Default per-tracker entry caps on mobile platforms. iOS network extensions +// are capped at ~50 MB; Android runs under aggressive memory pressure. These +// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack +// is ~200 B plus map overhead). 
+const ( + DefaultMaxTCPEntries = 4096 + DefaultMaxUDPEntries = 2048 + DefaultMaxICMPEntries = 512 +) diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index a48215ca9..3c96548b5 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -50,6 +50,9 @@ type ICMPConnTrack struct { ICMPCode uint8 } +// EnvICMPMaxEntries caps the ICMP conntrack table size. +const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX" + // ICMPTracker manages ICMP connection states type ICMPTracker struct { logger *nblog.Logger @@ -58,6 +61,7 @@ type ICMPTracker struct { cleanupTicker *time.Ticker tickerCancel context.CancelFunc mutex sync.RWMutex + maxEntries int flowLogger nftypes.FlowLogger } @@ -171,6 +175,7 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty timeout: timeout, cleanupTicker: time.NewTicker(ICMPCleanupInterval), tickerCancel: cancel, + maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries), flowLogger: flowLogger, } @@ -257,7 +262,9 @@ func (t *ICMPTracker) track( // non echo requests don't need tracking if typ != uint8(layers.ICMPv4TypeEchoRequest) { - t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) + } t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size) return } @@ -276,10 +283,15 @@ func (t *ICMPTracker) track( conn.UpdateCounters(direction, size) t.mutex.Lock() + if t.maxEntries > 0 && len(t.connections) >= t.maxEntries { + t.evictOneLocked() + } t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) + } t.sendEvent(nftypes.TypeStart, conn, ruleId) } @@ -323,6 +335,34 @@ func (t 
*ICMPTracker) cleanupRoutine(ctx context.Context) { } } +// evictOneLocked removes one entry to make room. Caller must hold t.mutex. +// Bounded sample scan: picks the oldest among up to evictSampleSize entries. +func (t *ICMPTracker) evictOneLocked() { + var candKey ICMPConnKey + var candSeen int64 + haveCand := false + sampled := 0 + + for k, c := range t.connections { + seen := c.lastSeen.Load() + if !haveCand || seen < candSeen { + candKey = k + candSeen = seen + haveCand = true + } + sampled++ + if sampled >= evictSampleSize { + break + } + } + if haveCand { + if evicted := t.connections[candKey]; evicted != nil { + t.sendEvent(nftypes.TypeEnd, evicted, nil) + } + delete(t.connections, candKey) + } +} + func (t *ICMPTracker) cleanup() { t.mutex.Lock() defer t.mutex.Unlock() @@ -331,8 +371,10 @@ func (t *ICMPTracker) cleanup() { if conn.timeoutExceeded(t.timeout) { delete(t.connections, key) - t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", - key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", + key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + } t.sendEvent(nftypes.TypeEnd, conn, nil) } } diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 335a3abab..9edc9af22 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -38,6 +38,27 @@ const ( TCPHandshakeTimeout = 60 * time.Second // TCPCleanupInterval is how often we check for stale connections TCPCleanupInterval = 5 * time.Minute + // FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states. + // Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait. + FinWaitTimeout = 60 * time.Second + // CloseWaitTimeout bounds CLOSE_WAIT. 
Matches Linux default; apps + // holding CloseWait longer than this should bump the env var. + CloseWaitTimeout = 60 * time.Second + // LastAckTimeout bounds LAST_ACK. Matches Linux default. + LastAckTimeout = 30 * time.Second +) + +// Env vars to override per-state teardown timeouts. Values parsed by +// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the +// defaults above with a warning. +const ( + EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT" + EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT" + EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT" + + // EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries + // (tombstones first) are evicted when the cap is reached. + EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX" ) // TCPState represents the state of a TCP connection @@ -133,14 +154,18 @@ func (t *TCPConnTrack) SetTombstone() { // TCPTracker manages TCP connection states type TCPTracker struct { - logger *nblog.Logger - connections map[ConnKey]*TCPConnTrack - mutex sync.RWMutex - cleanupTicker *time.Ticker - tickerCancel context.CancelFunc - timeout time.Duration - waitTimeout time.Duration - flowLogger nftypes.FlowLogger + logger *nblog.Logger + connections map[ConnKey]*TCPConnTrack + mutex sync.RWMutex + cleanupTicker *time.Ticker + tickerCancel context.CancelFunc + timeout time.Duration + waitTimeout time.Duration + finWaitTimeout time.Duration + closeWaitTimeout time.Duration + lastAckTimeout time.Duration + maxEntries int + flowLogger nftypes.FlowLogger } // NewTCPTracker creates a new TCP connection tracker @@ -155,13 +180,17 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp ctx, cancel := context.WithCancel(context.Background()) tracker := &TCPTracker{ - logger: logger, - connections: make(map[ConnKey]*TCPConnTrack), - cleanupTicker: time.NewTicker(TCPCleanupInterval), - tickerCancel: cancel, - timeout: timeout, - waitTimeout: waitTimeout, - flowLogger: 
flowLogger, + logger: logger, + connections: make(map[ConnKey]*TCPConnTrack), + cleanupTicker: time.NewTicker(TCPCleanupInterval), + tickerCancel: cancel, + timeout: timeout, + waitTimeout: waitTimeout, + finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout), + closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout), + lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout), + maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries), + flowLogger: flowLogger, } go tracker.cleanupRoutine(ctx) @@ -209,6 +238,12 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla if exists || flags&TCPSyn == 0 { return } + // Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't + // create spurious conntrack entries. Not mandated by RFC 9293 but a + // common hardening (Linux netfilter/nftables rejects these too). + if !isValidFlagCombination(flags) { + return + } conn := &TCPConnTrack{ BaseConnTrack: BaseConnTrack{ @@ -225,20 +260,65 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla conn.state.Store(int32(TCPStateNew)) conn.DNATOrigPort.Store(uint32(origPort)) - if origPort != 0 { - t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) - } else { - t.logger.Trace2("New %s TCP connection: %s", direction, key) + if t.logger.Enabled(nblog.LevelTrace) { + if origPort != 0 { + t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s TCP connection: %s", direction, key) + } } t.updateState(key, conn, flags, direction, size) t.mutex.Lock() + if t.maxEntries > 0 && len(t.connections) >= t.maxEntries { + t.evictOneLocked() + } t.connections[key] = conn t.mutex.Unlock() t.sendEvent(nftypes.TypeStart, conn, ruleID) } +// evictOneLocked removes one entry to make room. Caller must hold t.mutex. 
+// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map +// iteration order is randomized), preferring a tombstone. If no tombstone +// found in the sample, evicts the oldest among the sampled entries. O(1) +// worst case — cheap enough to run on every insert at cap during abuse. +func (t *TCPTracker) evictOneLocked() { + var candKey ConnKey + var candSeen int64 + haveCand := false + sampled := 0 + + for k, c := range t.connections { + if c.IsTombstone() { + delete(t.connections, k) + return + } + seen := c.lastSeen.Load() + if !haveCand || seen < candSeen { + candKey = k + candSeen = seen + haveCand = true + } + sampled++ + if sampled >= evictSampleSize { + break + } + } + if haveCand { + if evicted := t.connections[candKey]; evicted != nil { + // TypeEnd is already emitted at the state transition to + // TimeWait and when a connection is tombstoned. Only emit + // here when we're reaping a still-active flow. + if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() { + t.sendEvent(nftypes.TypeEnd, evicted, nil) + } + } + delete(t.connections, candKey) + } +} + // IsValidInbound checks if an inbound TCP packet matches a tracked connection func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool { key := ConnKey{ @@ -256,12 +336,19 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui return false } + // Reject illegal flag combinations regardless of state. These never belong + // to a legitimate flow and must not advance or tear down state. 
+ if !isValidFlagCombination(flags) { + if t.logger.Enabled(nblog.LevelWarn) { + t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState()) + } + return false + } + currentState := conn.GetState() if !t.isValidStateForFlags(currentState, flags) { - t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key) - // allow all flags for established for now - if currentState == TCPStateEstablished { - return true + if t.logger.Enabled(nblog.LevelWarn) { + t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key) } return false } @@ -270,116 +357,208 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui return true } -// updateState updates the TCP connection state based on flags +// updateState updates the TCP connection state based on flags. func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) { - conn.UpdateLastSeen() conn.UpdateCounters(packetDir, size) + // Malformed flag combinations must not refresh lastSeen or drive state, + // otherwise spoofed packets keep a dead flow alive past its timeout. + if !isValidFlagCombination(flags) { + return + } + + conn.UpdateLastSeen() + currentState := conn.GetState() if flags&TCPRst != 0 { - if conn.CompareAndSwapState(currentState, TCPStateClosed) { - conn.SetTombstone() - t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]", - key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) - t.sendEvent(nftypes.TypeEnd, conn, nil) - } + // Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we + // cannot apply the RFC 5961 in-window RST check, so we conservatively + // reject RSTs that the spec would accept (TIME-WAIT with in-window + // SEQ, SynSent from same direction as own SYN, etc.). 
+ t.handleRst(key, conn, currentState, packetDir) return } - var newState TCPState - switch currentState { - case TCPStateNew: - if flags&TCPSyn != 0 && flags&TCPAck == 0 { - if conn.Direction == nftypes.Egress { - newState = TCPStateSynSent - } else { - newState = TCPStateSynReceived - } - } + newState := nextState(currentState, conn.Direction, packetDir, flags) + if newState == 0 || !conn.CompareAndSwapState(currentState, newState) { + return + } + t.onTransition(key, conn, currentState, newState, packetDir) +} - case TCPStateSynSent: - if flags&TCPSyn != 0 && flags&TCPAck != 0 { - if packetDir != conn.Direction { - newState = TCPStateEstablished - } else { - // Simultaneous open - newState = TCPStateSynReceived - } - } +// handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs +// from the SYN direction are ignored; otherwise the flow is tombstoned. +func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) { + // TimeWait exists to absorb late segments; don't let a late RST + // tombstone the entry and break same-4-tuple reuse. + if currentState == TCPStateTimeWait { + return + } + // A RST from the same direction as the SYN cannot be a legitimate + // response and must not tear down a half-open connection. + if currentState == TCPStateSynSent && packetDir == conn.Direction { + return + } + if !conn.CompareAndSwapState(currentState, TCPStateClosed) { + return + } + conn.SetTombstone() + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]", + key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + } + t.sendEvent(nftypes.TypeEnd, conn, nil) +} - case TCPStateSynReceived: - if flags&TCPAck != 0 && flags&TCPSyn == 0 { - if packetDir == conn.Direction { - newState = TCPStateEstablished - } - } +// stateTransition describes one state's transition logic. 
It receives the +// packet's flags plus whether the packet direction matches the connection's +// origin direction (same=true means same side as the SYN initiator). Return 0 +// for no transition. +type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState - case TCPStateEstablished: - if flags&TCPFin != 0 { - if packetDir == conn.Direction { - newState = TCPStateFinWait1 - } else { - newState = TCPStateCloseWait - } - } +// stateTable maps each state to its transition function. Centralized here so +// nextState stays trivial and each rule is easy to read in isolation. +var stateTable = map[TCPState]stateTransition{ + TCPStateNew: transNew, + TCPStateSynSent: transSynSent, + TCPStateSynReceived: transSynReceived, + TCPStateEstablished: transEstablished, + TCPStateFinWait1: transFinWait1, + TCPStateFinWait2: transFinWait2, + TCPStateClosing: transClosing, + TCPStateCloseWait: transCloseWait, + TCPStateLastAck: transLastAck, +} - case TCPStateFinWait1: - if packetDir != conn.Direction { - switch { - case flags&TCPFin != 0 && flags&TCPAck != 0: - newState = TCPStateClosing - case flags&TCPFin != 0: - newState = TCPStateClosing - case flags&TCPAck != 0: - newState = TCPStateFinWait2 - } - } +// nextState returns the target TCP state for the given current state and +// packet, or 0 if the packet does not trigger a transition. 
+func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState { + fn, ok := stateTable[currentState] + if !ok { + return 0 + } + return fn(flags, connDir, packetDir == connDir) +} - case TCPStateFinWait2: - if flags&TCPFin != 0 { - newState = TCPStateTimeWait +func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState { + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + if connDir == nftypes.Egress { + return TCPStateSynSent } + return TCPStateSynReceived + } + return 0 +} - case TCPStateClosing: - if flags&TCPAck != 0 { - newState = TCPStateTimeWait +func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPSyn != 0 && flags&TCPAck != 0 { + if same { + return TCPStateSynReceived // simultaneous open } + return TCPStateEstablished + } + return 0 +} - case TCPStateCloseWait: - if flags&TCPFin != 0 { - newState = TCPStateLastAck - } +func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPAck != 0 && flags&TCPSyn == 0 && same { + return TCPStateEstablished + } + return 0 +} - case TCPStateLastAck: - if flags&TCPAck != 0 { - newState = TCPStateClosed - } +func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPFin == 0 { + return 0 + } + if same { + return TCPStateFinWait1 + } + return TCPStateCloseWait +} + +// transFinWait1 handles the active-close peer response. A FIN carrying our +// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1: +// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves +// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2. 
+func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState { + if same { + return 0 + } + if flags&TCPFin != 0 && flags&TCPAck != 0 { + return TCPStateTimeWait + } + switch { + case flags&TCPFin != 0: + return TCPStateClosing + case flags&TCPAck != 0: + return TCPStateFinWait2 + } + return 0 +} + +// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances. +func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPFin != 0 && !same { + return TCPStateTimeWait + } + return 0 +} + +// transClosing completes a simultaneous close on the peer's ACK. +func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPAck != 0 && !same { + return TCPStateTimeWait + } + return 0 +} + +// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits. +func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPFin != 0 && same { + return TCPStateLastAck + } + return 0 +} + +// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits). +func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPAck != 0 && !same { + return TCPStateClosed + } + return 0 +} + +// onTransition handles logging and flow-event emission after a successful +// state transition. TimeWait and Closed are terminal for flow accounting. 
+func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) { + traceOn := t.logger.Enabled(nblog.LevelTrace) + if traceOn { + t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir) } - if newState != 0 && conn.CompareAndSwapState(currentState, newState) { - t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir) - - switch newState { - case TCPStateTimeWait: + switch to { + case TCPStateTimeWait: + if traceOn { t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) - t.sendEvent(nftypes.TypeEnd, conn, nil) - - case TCPStateClosed: - conn.SetTombstone() + } + t.sendEvent(nftypes.TypeEnd, conn, nil) + case TCPStateClosed: + conn.SetTombstone() + if traceOn { t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) - t.sendEvent(nftypes.TypeEnd, conn, nil) } + t.sendEvent(nftypes.TypeEnd, conn, nil) } } -// isValidStateForFlags checks if the TCP flags are valid for the current connection state +// isValidStateForFlags checks if the TCP flags are valid for the current +// connection state. Caller must have already verified the flag combination is +// legal via isValidFlagCombination. 
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { - if !isValidFlagCombination(flags) { - return false - } if flags&TCPRst != 0 { if state == TCPStateSynSent { return flags&TCPAck != 0 @@ -449,15 +628,24 @@ func (t *TCPTracker) cleanup() { timeout = t.waitTimeout case TCPStateEstablished: timeout = t.timeout + case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing: + timeout = t.finWaitTimeout + case TCPStateCloseWait: + timeout = t.closeWaitTimeout + case TCPStateLastAck: + timeout = t.lastAckTimeout default: + // SynSent / SynReceived / New timeout = TCPHandshakeTimeout } if conn.timeoutExceeded(timeout) { delete(t.connections, key) - t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]", - key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]", + key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + } // event already handled by state change if currentState != TCPStateTimeWait { diff --git a/client/firewall/uspfilter/conntrack/tcp_rst_bugs_test.go b/client/firewall/uspfilter/conntrack/tcp_rst_bugs_test.go new file mode 100644 index 000000000..81d4f5710 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp_rst_bugs_test.go @@ -0,0 +1,100 @@ +package conntrack + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/require" +) + +// RST hygiene tests: the tracker currently closes the flow on any RST that +// matches the 4-tuple, regardless of direction or state. These tests cover +// the minimum checks we want (no SEQ tracking). 
+ +func TestTCPRstInSynSentWrongDirection(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) + conn := tracker.connections[key] + require.Equal(t, TCPStateSynSent, conn.GetState()) + + // A RST arriving in the same direction as the SYN (i.e. TrackOutbound) + // cannot be a legitimate response. It must not close the connection. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0) + require.Equal(t, TCPStateSynSent, conn.GetState(), + "RST in same direction as SYN must not close connection") + require.False(t, conn.IsTombstone()) +} + +func TestTCPRstInTimeWaitIgnored(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + // Drive to TIME-WAIT via active close. + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + + conn := tracker.connections[key] + require.Equal(t, TCPStateTimeWait, conn.GetState()) + require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned") + + // Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT + // exists to absorb late segments). 
+ tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) + require.Equal(t, TCPStateTimeWait, conn.GetState(), + "RST in TIME-WAIT must not transition state") + require.False(t, conn.IsTombstone(), + "RST in TIME-WAIT must not tombstone the entry") +} + +func TestTCPIllegalFlagCombos(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + conn := tracker.connections[key] + + // Illegal combos must be rejected and must not change state. + combos := []struct { + name string + flags uint8 + }{ + {"SYN+RST", TCPSyn | TCPRst}, + {"FIN+RST", TCPFin | TCPRst}, + {"SYN+FIN", TCPSyn | TCPFin}, + {"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst}, + } + + for _, c := range combos { + t.Run(c.name, func(t *testing.T) { + before := conn.GetState() + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0) + require.False(t, valid, "illegal flag combo must be rejected: %s", c.name) + require.Equal(t, before, conn.GetState(), + "illegal flag combo must not change state") + require.False(t, conn.IsTombstone()) + }) + } +} diff --git a/client/firewall/uspfilter/conntrack/tcp_state_bugs_test.go b/client/firewall/uspfilter/conntrack/tcp_state_bugs_test.go new file mode 100644 index 000000000..32112cd58 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp_state_bugs_test.go @@ -0,0 +1,235 @@ +package conntrack + +import ( + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// These tests exercise cases where the TCP state machine currently advances +// on retransmitted or wrong-direction segments and tears the flow down +// prematurely. 
They are expected to fail until the direction checks are added. + +func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Peer sends FIN -> CloseWait (our app has not yet closed). + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + require.True(t, valid) + conn := tracker.connections[key] + require.Equal(t, TCPStateCloseWait, conn.GetState()) + + // Peer retransmits their FIN (ACK may have been delayed). We have NOT + // sent our FIN yet, so state must remain CloseWait. + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + require.True(t, valid, "retransmitted peer FIN must still be accepted") + require.Equal(t, TCPStateCloseWait, conn.GetState(), + "retransmitted peer FIN must not advance CloseWait to LastAck") + + // Our app finally closes -> LastAck. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.Equal(t, TCPStateLastAck, conn.GetState()) + + // Peer ACK closes. + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateClosed, conn.GetState()) +} + +func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // We initiate close. 
+ tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + require.True(t, valid) + conn := tracker.connections[key] + require.Equal(t, TCPStateFinWait2, conn.GetState()) + + // Stray retransmit of our own FIN (same direction as originator) must + // NOT advance FinWait2 to TimeWait; only the peer's FIN should. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.Equal(t, TCPStateFinWait2, conn.GetState(), + "own FIN retransmit must not advance FinWait2 to TimeWait") + + // Peer FIN -> TimeWait. + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateTimeWait, conn.GetState()) +} + +func TestTCPLastAckDirectionCheck(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck. + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + conn := tracker.connections[key] + require.Equal(t, TCPStateLastAck, conn.GetState()) + + // Our own ACK retransmit (same direction as originator) must NOT close. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + require.Equal(t, TCPStateLastAck, conn.GetState(), + "own ACK retransmit in LastAck must not transition to Closed") + + // Peer's ACK -> Closed. 
+ require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) + require.Equal(t, TCPStateClosed, conn.GetState()) +} + +func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + conn := tracker.connections[key] + require.Equal(t, TCPStateFinWait1, conn.GetState()) + + // Our own ACK retransmit (same direction as originator) must not advance. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + require.Equal(t, TCPStateFinWait1, conn.GetState(), + "own ACK in FinWait1 must not advance to FinWait2") +} + +func TestTCPPerStateTeardownTimeouts(t *testing.T) { + // Verify cleanup reaps entries in each teardown state at the configured + // per-state timeout, not at the single handshake timeout. + t.Setenv(EnvTCPFinWaitTimeout, "50ms") + t.Setenv(EnvTCPCloseWaitTimeout, "80ms") + t.Setenv(EnvTCPLastAckTimeout, "30ms") + + dstIP := netip.MustParseAddr("100.64.0.2") + dstPort := uint16(80) + + // Drives a connection to the target state, forces its lastSeen well + // beyond the configured timeout, runs cleanup, and asserts reaping. + cases := []struct { + name string + // drive takes a fresh tracker and returns the conn key after + // transitioning the flow into the intended teardown state. 
+ drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) + }{ + { + name: "FinWait1", + drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) { + establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort) + tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1 + return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1 + }, + }, + { + name: "FinWait2", + drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) { + establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort) + tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1 + require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2 + return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2 + }, + }, + { + name: "CloseWait", + drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) { + establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort) + require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait + return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait + }, + }, + { + name: "LastAck", + drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) { + establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort) + require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait + tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck + return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck + }, + }, + } + + // Use a unique source port per subtest so nothing aliases. 
+ port := uint16(12345) + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout) + require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout) + require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout) + + srcIP := netip.MustParseAddr("100.64.0.1") + port++ + key, wantState := c.drive(t, tracker, srcIP, port) + conn := tracker.connections[key] + require.NotNil(t, conn) + require.Equal(t, wantState, conn.GetState()) + + // Age the entry past the largest per-state timeout. + conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano()) + tracker.cleanup() + _, exists := tracker.connections[key] + require.False(t, exists, "%s entry should be reaped", c.name) + }) + } +} + +func TestTCPEstablishedPSHACKInFinStates(t *testing.T) { + // Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN + // teardown states, which some stacks emit during close. + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Peer FIN -> CloseWait. + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) + + // Peer pushes trailing data + FIN|PSH|ACK (legal). + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100), + "FIN|PSH|ACK in CloseWait must be accepted") + + // Bare ACK keepalive from peer in CloseWait must be accepted. 
+ require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0), + "bare ACK in CloseWait must be accepted") +} diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index a3b6a418b..335c5832a 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -17,6 +17,9 @@ const ( DefaultUDPTimeout = 30 * time.Second // UDPCleanupInterval is how often we check for stale connections UDPCleanupInterval = 15 * time.Second + + // EnvUDPMaxEntries caps the UDP conntrack table size. + EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX" ) // UDPConnTrack represents a UDP connection state @@ -34,6 +37,7 @@ type UDPTracker struct { cleanupTicker *time.Ticker tickerCancel context.CancelFunc mutex sync.RWMutex + maxEntries int flowLogger nftypes.FlowLogger } @@ -51,6 +55,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp timeout: timeout, cleanupTicker: time.NewTicker(UDPCleanupInterval), tickerCancel: cancel, + maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries), flowLogger: flowLogger, } @@ -117,13 +122,18 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d conn.UpdateCounters(direction, size) t.mutex.Lock() + if t.maxEntries > 0 && len(t.connections) >= t.maxEntries { + t.evictOneLocked() + } t.connections[key] = conn t.mutex.Unlock() - if origPort != 0 { - t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) - } else { - t.logger.Trace2("New %s UDP connection: %s", direction, key) + if t.logger.Enabled(nblog.LevelTrace) { + if origPort != 0 { + t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s UDP connection: %s", direction, key) + } } t.sendEvent(nftypes.TypeStart, conn, ruleID) } @@ -151,6 +161,34 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP 
netip.Addr, srcPort return true } +// evictOneLocked removes one entry to make room. Caller must hold t.mutex. +// Bounded sample: picks the oldest among up to evictSampleSize entries. +func (t *UDPTracker) evictOneLocked() { + var candKey ConnKey + var candSeen int64 + haveCand := false + sampled := 0 + + for k, c := range t.connections { + seen := c.lastSeen.Load() + if !haveCand || seen < candSeen { + candKey = k + candSeen = seen + haveCand = true + } + sampled++ + if sampled >= evictSampleSize { + break + } + } + if haveCand { + if evicted := t.connections[candKey]; evicted != nil { + t.sendEvent(nftypes.TypeEnd, evicted, nil) + } + delete(t.connections, candKey) + } +} + // cleanupRoutine periodically removes stale connections func (t *UDPTracker) cleanupRoutine(ctx context.Context) { defer t.cleanupTicker.Stop() @@ -173,8 +211,10 @@ func (t *UDPTracker) cleanup() { if conn.timeoutExceeded(t.timeout) { delete(t.connections, key) - t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]", - key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]", + key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + } t.sendEvent(nftypes.TypeEnd, conn, nil) } } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 5ecd08abf..91866dcab 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -787,7 +787,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { srcIP, dstIP := m.extractIPs(d) if !srcIP.IsValid() { - m.logger.Error1("Unknown network layer: %v", d.decoded[0]) + if m.logger.Enabled(nblog.LevelError) { + m.logger.Error1("Unknown network layer: %v", d.decoded[0]) + } return false } @@ -901,7 +903,9 @@ func (m *Manager) 
clampTCPMSS(packetData []byte, d *decoder) bool { return false } - m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue) + } return true } @@ -1044,11 +1048,13 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { // TODO: pass fragments of routed packets to forwarder if fragment { - if d.decoded[0] == layers.LayerTypeIPv4 { - m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", - srcIP, dstIP, d.ip4.Id, d.ip4.Flags) - } else { - m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP) + if m.logger.Enabled(nblog.LevelTrace) { + if d.decoded[0] == layers.LayerTypeIPv4 { + m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", + srcIP, dstIP, d.ip4.Id, d.ip4.Flags) + } else { + m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP) + } } return false } @@ -1091,8 +1097,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) - m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", - ruleID, pnum, srcIP, srcPort, dstIP, dstPort) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", + ruleID, pnum, srcIP, srcPort, dstIP, dstPort) + } m.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: uuid.New(), @@ -1142,8 +1150,10 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool { func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { // Drop if routing is disabled if !m.routingEnabled.Load() { - m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s", - srcIP, dstIP) + if m.logger.Enabled(nblog.LevelTrace) { + 
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s", + srcIP, dstIP) + } return true } @@ -1160,8 +1170,10 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe if !pass { proto := getProtocolFromPacket(d) - m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", - ruleID, proto, srcIP, srcPort, dstIP, dstPort) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", + ruleID, proto, srcIP, srcPort, dstIP, dstPort) + } m.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: uuid.New(), @@ -1287,7 +1299,9 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) { // It returns true, true if the packet is a fragment and valid. func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { if err := d.decodePacket(packetData); err != nil { - m.logger.Trace1("couldn't decode packet, err: %s", err) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace1("couldn't decode packet, err: %s", err) + } return false, false } diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index 3922c2052..d6d4e705e 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -13,6 +13,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) @@ -97,8 +98,10 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by return nil, fmt.Errorf("write ICMP packet: %w", err) } - f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v", - epID(id), icmpType, icmpCode) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v", + epID(id), 
icmpType, icmpCode) + } return conn, nil } @@ -121,12 +124,14 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp txBytes := f.handleEchoResponse(conn, id, v6) rtt := time.Since(sendTime).Round(10 * time.Microsecond) - proto := "ICMP" - if v6 { - proto = "ICMPv6" + if f.logger.Enabled(nblog.LevelTrace) { + proto := "ICMP" + if v6 { + proto = "ICMPv6" + } + f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)", + proto, epID(id), icmpType, icmpCode, rtt) } - f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)", - proto, epID(id), icmpType, icmpCode, rtt) f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } @@ -224,13 +229,17 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi } rtt := time.Since(pingStart).Round(10 * time.Microsecond) - f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v", - epID(id), icmpType, icmpCode) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v", + epID(id), icmpType, icmpCode) + } txBytes := f.synthesizeEchoReply(id, icmpData) - f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)", - epID(id), icmpType, icmpCode, rtt) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)", + epID(id), icmpType, icmpCode, rtt) + } f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 8844463f5..c65ebcde0 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -1,11 +1,8 @@ package forwarder import ( - "context" - "io" "net" "strconv" - "sync" 
"github.com/google/uuid" @@ -15,7 +12,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" + "github.com/netbirdio/netbird/util/netrelay" ) // handleTCP is called by the TCP forwarder for new connections. @@ -37,7 +36,9 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) if err != nil { r.Complete(true) - f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err) + } return } @@ -60,64 +61,22 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { inConn := gonet.NewTCPConn(&wq, ep) success = true - f.logger.Trace1("forwarder: established TCP connection %v", epID(id)) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace1("forwarder: established TCP connection %v", epID(id)) + } go f.proxyTCP(id, inConn, outConn, ep, flowID) } func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) { + // netrelay.Relay copies bidirectionally with proper half-close propagation + // and fully closes both conns before returning. + bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{ + Logger: f.logger, + }) - ctx, cancel := context.WithCancel(f.ctx) - defer cancel() - - go func() { - <-ctx.Done() - // Close connections and endpoint. 
- if err := inConn.Close(); err != nil && !isClosedError(err) { - f.logger.Debug1("forwarder: inConn close error: %v", err) - } - if err := outConn.Close(); err != nil && !isClosedError(err) { - f.logger.Debug1("forwarder: outConn close error: %v", err) - } - - ep.Close() - }() - - var wg sync.WaitGroup - wg.Add(2) - - var ( - bytesFromInToOut int64 // bytes from client to server (tx for client) - bytesFromOutToIn int64 // bytes from server to client (rx for client) - errInToOut error - errOutToIn error - ) - - go func() { - bytesFromInToOut, errInToOut = io.Copy(outConn, inConn) - cancel() - wg.Done() - }() - - go func() { - - bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn) - cancel() - wg.Done() - }() - - wg.Wait() - - if errInToOut != nil { - if !isClosedError(errInToOut) { - f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut) - } - } - if errOutToIn != nil { - if !isClosedError(errOutToIn) { - f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn) - } - } + // Close the netstack endpoint after both conns are drained. 
+ ep.Close() var rxPackets, txPackets uint64 if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { @@ -126,7 +85,9 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn txPackets = tcpStats.SegmentsReceived.Value() } - f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) + } f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets) } diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index c92fa1f32..d840ef06b 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -125,7 +125,9 @@ func (f *udpForwarder) cleanup() { delete(f.conns, idle.id) f.Unlock() - f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id)) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id)) + } } } } @@ -144,7 +146,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool { _, exists := f.udpForwarder.conns[id] f.udpForwarder.RUnlock() if exists { - f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id)) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id)) + } return true } @@ -206,7 +210,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool { f.udpForwarder.Unlock() success = true - f.logger.Trace1("forwarder: established UDP connection %v", epID(id)) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace1("forwarder: established UDP connection %v", epID(id)) + } go f.proxyUDP(connCtx, pConn, id, ep) return true 
@@ -265,7 +271,9 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack txPackets = udpStats.PacketsReceived.Value() } - f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) + } f.udpForwarder.Lock() delete(f.udpForwarder.conns, id) diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go index c6ca55e70..03e7d4809 100644 --- a/client/firewall/uspfilter/log/log.go +++ b/client/firewall/uspfilter/log/log.go @@ -53,16 +53,17 @@ var levelStrings = map[Level]string{ } type logMessage struct { - level Level - format string - arg1 any - arg2 any - arg3 any - arg4 any - arg5 any - arg6 any - arg7 any - arg8 any + level Level + argCount uint8 + format string + arg1 any + arg2 any + arg3 any + arg4 any + arg5 any + arg6 any + arg7 any + arg8 any } // Logger is a high-performance, non-blocking logger @@ -107,6 +108,13 @@ func (l *Logger) SetLevel(level Level) { log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) } +// Enabled reports whether the given level is currently logged. Callers on the +// hot path should guard log sites with this to avoid boxing arguments into +// any when the level is off. 
+func (l *Logger) Enabled(level Level) bool { + return l.level.Load() >= uint32(level) +} + func (l *Logger) Error(format string) { if l.level.Load() >= uint32(LevelError) { select { @@ -155,7 +163,7 @@ func (l *Logger) Trace(format string) { func (l *Logger) Error1(format string, arg1 any) { if l.level.Load() >= uint32(LevelError) { select { - case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}: + case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}: default: } } @@ -164,7 +172,16 @@ func (l *Logger) Error1(format string, arg1 any) { func (l *Logger) Error2(format string, arg1, arg2 any) { if l.level.Load() >= uint32(LevelError) { select { - case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}: + case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}: + default: + } + } +} + +func (l *Logger) Warn2(format string, arg1, arg2 any) { + if l.level.Load() >= uint32(LevelWarn) { + select { + case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}: default: } } @@ -173,7 +190,7 @@ func (l *Logger) Error2(format string, arg1, arg2 any) { func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) { if l.level.Load() >= uint32(LevelWarn) { select { - case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: default: } } @@ -182,7 +199,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) { func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) { if l.level.Load() >= uint32(LevelWarn) { select { - case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: + case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: 
arg1, arg2: arg2, arg3: arg3, arg4: arg4}: default: } } @@ -191,7 +208,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) { func (l *Logger) Debug1(format string, arg1 any) { if l.level.Load() >= uint32(LevelDebug) { select { - case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}: + case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}: default: } } @@ -200,7 +217,7 @@ func (l *Logger) Debug1(format string, arg1 any) { func (l *Logger) Debug2(format string, arg1, arg2 any) { if l.level.Load() >= uint32(LevelDebug) { select { - case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}: + case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}: default: } } @@ -209,16 +226,59 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) { func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) { if l.level.Load() >= uint32(LevelDebug) { select { - case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: default: } } } +// Debugf is the variadic shape. Dispatches to Debug/Debug1/Debug2/Debug3 +// to avoid allocating an args slice on the fast path when the arg count is +// known (0-3). Args beyond 3 land on the general variadic path; callers on +// the hot path should prefer DebugN for known counts. 
+func (l *Logger) Debugf(format string, args ...any) { + if l.level.Load() < uint32(LevelDebug) { + return + } + switch len(args) { + case 0: + l.Debug(format) + case 1: + l.Debug1(format, args[0]) + case 2: + l.Debug2(format, args[0], args[1]) + case 3: + l.Debug3(format, args[0], args[1], args[2]) + default: + l.sendVariadic(LevelDebug, format, args) + } +} + +// sendVariadic packs a slice of arguments into a logMessage and non-blocking +// enqueues it. Used for arg counts beyond the fixed-arity fast paths. Args +// beyond the 8-arg slot limit are dropped so callers don't produce silently +// empty log lines via uint8 wraparound in argCount. +func (l *Logger) sendVariadic(level Level, format string, args []any) { + const maxArgs = 8 + n := len(args) + if n > maxArgs { + n = maxArgs + } + msg := logMessage{level: level, argCount: uint8(n), format: format} + slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8} + for i := 0; i < n; i++ { + *slots[i] = args[i] + } + select { + case l.msgChannel <- msg: + default: + } +} + func (l *Logger) Trace1(format string, arg1 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}: default: } } @@ -227,7 +287,7 @@ func (l *Logger) Trace1(format string, arg1 any) { func (l *Logger) Trace2(format string, arg1, arg2 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}: default: } } @@ -236,7 +296,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) { func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- 
logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: default: } } @@ -245,7 +305,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) { func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: default: } } @@ -254,7 +314,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) { func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}: default: } } @@ -263,7 +323,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) { func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}: default: } } @@ -273,7 +333,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- 
logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: default: } } @@ -286,35 +346,8 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { *buf = append(*buf, levelStrings[msg.level]...) *buf = append(*buf, ' ') - // Count non-nil arguments for switch - argCount := 0 - if msg.arg1 != nil { - argCount++ - if msg.arg2 != nil { - argCount++ - if msg.arg3 != nil { - argCount++ - if msg.arg4 != nil { - argCount++ - if msg.arg5 != nil { - argCount++ - if msg.arg6 != nil { - argCount++ - if msg.arg7 != nil { - argCount++ - if msg.arg8 != nil { - argCount++ - } - } - } - } - } - } - } - } - var formatted string - switch argCount { + switch msg.argCount { case 0: formatted = msg.format case 1: diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 0d411c21e..5d51c1538 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -11,6 +11,7 @@ import ( "github.com/google/gopacket/layers" firewall "github.com/netbirdio/netbird/client/firewall/manager" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) var ( @@ -262,11 +263,15 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { } if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil { - m.logger.Error1("failed to rewrite packet destination: %v", err) + if m.logger.Enabled(nblog.LevelError) { + m.logger.Error1("failed to rewrite packet destination: %v", err) + } return false } - m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP) + } return true } @@ -283,11 +288,15 @@ func (m *Manager) translateInboundReverse(packetData 
[]byte, d *decoder) bool { } if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil { - m.logger.Error1("failed to rewrite packet source: %v", err) + if m.logger.Enabled(nblog.LevelError) { + m.logger.Error1("failed to rewrite packet source: %v", err) + } return false } - m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP) + } return true } @@ -612,7 +621,9 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti } if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil { - m.logger.Error1("failed to rewrite port: %v", err) + if m.logger.Enabled(nblog.LevelError) { + m.logger.Error1("failed to rewrite port: %v", err) + } return false } d.dnatOrigPort = rule.origPort diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 7f72a72cf..ebf8eb794 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -25,6 +25,7 @@ import ( nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/netrelay" ) const ( @@ -536,7 +537,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str continue } - go c.handleLocalForward(localConn, remoteAddr) + go c.handleLocalForward(ctx, localConn, remoteAddr) } }() @@ -548,7 +549,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str } // handleLocalForward handles a single local port forwarding connection -func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { +func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, remoteAddr string) { defer func() { if err := localConn.Close(); err != nil { log.Debugf("local port forwarding: close local connection: %v", err) @@ -571,7 +572,7 @@ func (c *Client) 
handleLocalForward(localConn net.Conn, remoteAddr string) { } }() - nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel) + netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())}) } // RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr @@ -653,16 +654,19 @@ func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr stri select { case <-ctx.Done(): return - case newChan := <-channelRequests: + case newChan, ok := <-channelRequests: + if !ok { + return + } if newChan != nil { - go c.handleRemoteForwardChannel(newChan, localAddr) + go c.handleRemoteForwardChannel(ctx, newChan, localAddr) } } } } // handleRemoteForwardChannel handles a single forwarded-tcpip channel -func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) { +func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.NewChannel, localAddr string) { channel, reqs, err := newChan.Accept() if err != nil { return @@ -675,8 +679,14 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st go ssh.DiscardRequests(reqs) - localConn, err := net.Dial("tcp", localAddr) + // Bound the dial so a black-holed localAddr can't pin the accepted SSH + // channel open indefinitely; the relay itself runs under the outer ctx. 
+ dialCtx, cancelDial := context.WithTimeout(ctx, 10*time.Second) + var dialer net.Dialer + localConn, err := dialer.DialContext(dialCtx, "tcp", localAddr) + cancelDial() if err != nil { + log.Debugf("remote port forwarding: dial %s: %v", localAddr, err) return } defer func() { @@ -685,7 +695,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st } }() - nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel) + netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())}) } // tcpipForwardMsg represents the structure for tcpip-forward requests diff --git a/client/ssh/common.go b/client/ssh/common.go index f6aec5f9c..92e647b7d 100644 --- a/client/ssh/common.go +++ b/client/ssh/common.go @@ -194,63 +194,3 @@ func buildAddressList(hostname string, remote net.Addr) []string { return addresses } -// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections. -// It waits for both directions to complete before returning. -// The caller is responsible for closing the connections. -func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) { - done := make(chan struct{}, 2) - - go func() { - if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) { - logger.Debugf("copy error (1->2): %v", err) - } - done <- struct{}{} - }() - - go func() { - if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) { - logger.Debugf("copy error (2->1): %v", err) - } - done <- struct{}{} - }() - - <-done - <-done -} - -func isExpectedCopyError(err error) bool { - return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) -} - -// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections. -// It waits for both directions to complete or for context cancellation before returning. -// Both connections are closed when the function returns. 
-func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) { - done := make(chan struct{}, 2) - - go func() { - if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) { - logger.Debugf("copy error (1->2): %v", err) - } - done <- struct{}{} - }() - - go func() { - if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) { - logger.Debugf("copy error (2->1): %v", err) - } - done <- struct{}{} - }() - - select { - case <-ctx.Done(): - case <-done: - select { - case <-ctx.Done(): - case <-done: - } - } - - _ = conn1.Close() - _ = conn2.Close() -} diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index eb659fe21..73b50122c 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/client/proto" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/util/netrelay" "github.com/netbirdio/netbird/version" ) @@ -352,7 +353,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne } go cryptossh.DiscardRequests(clientReqs) - nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan) + netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())}) } func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) { @@ -591,7 +592,7 @@ func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh } go cryptossh.DiscardRequests(clientReqs) - nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan) + netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())}) } func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) { diff --git 
a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go index f5ac66fca..a47fdb48a 100644 --- a/client/ssh/server/port_forwarding.go +++ b/client/ssh/server/port_forwarding.go @@ -17,7 +17,7 @@ import ( log "github.com/sirupsen/logrus" cryptossh "golang.org/x/crypto/ssh" - nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/util/netrelay" ) const privilegedPortThreshold = 1024 @@ -357,7 +357,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h return } - nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel) + netrelay.Relay(ctx, conn, channel, netrelay.Options{Logger: logger}) } // openForwardChannel creates an SSH forwarded-tcpip channel diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index de40d3091..6735e0f3b 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -8,9 +8,9 @@ import ( "fmt" "io" "net" - "strconv" "net/netip" "slices" + "strconv" "strings" "sync" "time" @@ -27,6 +27,7 @@ import ( "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth/jwt" + "github.com/netbirdio/netbird/util/netrelay" "github.com/netbirdio/netbird/version" ) @@ -53,6 +54,10 @@ const ( DefaultJWTMaxTokenAge = 10 * 60 ) +// directTCPIPDialTimeout bounds how long relayDirectTCPIP waits on a dial to +// the forwarded destination before rejecting the SSH channel. 
+const directTCPIPDialTimeout = 30 * time.Second + var ( ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled) ErrUserNotFound = errors.New("user not found") @@ -933,5 +938,29 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr) logger.Infof("local port forwarding: %s", hostPort) - ssh.DirectTCPIPHandler(srv, conn, newChan, ctx) + s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger) +} + +// relayDirectTCPIP is a netrelay-based replacement for gliderlabs' +// DirectTCPIPHandler. The upstream handler closes both sides on the first +// EOF; netrelay.Relay propagates CloseWrite so each direction drains on its +// own terms. +func (s *Server) relayDirectTCPIP(ctx ssh.Context, newChan cryptossh.NewChannel, host string, port int, logger *log.Entry) { + dest := net.JoinHostPort(host, strconv.Itoa(port)) + + dialer := net.Dialer{Timeout: directTCPIPDialTimeout} + dconn, err := dialer.DialContext(ctx, "tcp", dest) + if err != nil { + _ = newChan.Reject(cryptossh.ConnectionFailed, err.Error()) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + _ = dconn.Close() + return + } + go cryptossh.DiscardRequests(reqs) + + netrelay.Relay(ctx, dconn, ch, netrelay.Options{Logger: logger}) } diff --git a/proxy/internal/tcp/peekedconn.go b/proxy/internal/tcp/peekedconn.go index 26f3e5c7c..23a348352 100644 --- a/proxy/internal/tcp/peekedconn.go +++ b/proxy/internal/tcp/peekedconn.go @@ -25,6 +25,12 @@ func (c *peekedConn) Read(b []byte) (int, error) { return c.reader.Read(b) } +// halfCloser matches connections that support shutting down the write +// side while keeping the read side open (e.g. *net.TCPConn). +type halfCloser interface { + CloseWrite() error +} + // CloseWrite delegates to the underlying connection if it supports // half-close (e.g. *net.TCPConn). 
Without this, embedding net.Conn // as an interface hides the concrete type's CloseWrite method, making diff --git a/proxy/internal/tcp/relay.go b/proxy/internal/tcp/relay.go deleted file mode 100644 index 39949818d..000000000 --- a/proxy/internal/tcp/relay.go +++ /dev/null @@ -1,156 +0,0 @@ -package tcp - -import ( - "context" - "errors" - "io" - "net" - "sync" - "time" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/proxy/internal/netutil" -) - -// errIdleTimeout is returned when a relay connection is closed due to inactivity. -var errIdleTimeout = errors.New("idle timeout") - -// DefaultIdleTimeout is the default idle timeout for TCP relay connections. -// A zero value disables idle timeout checking. -const DefaultIdleTimeout = 5 * time.Minute - -// halfCloser is implemented by connections that support half-close -// (e.g. *net.TCPConn). When one copy direction finishes, we signal -// EOF to the remote by closing the write side while keeping the read -// side open so the other direction can drain. -type halfCloser interface { - CloseWrite() error -} - -// copyBufPool avoids allocating a new 32KB buffer per io.Copy call. -var copyBufPool = sync.Pool{ - New: func() any { - buf := make([]byte, 32*1024) - return &buf - }, -} - -// Relay copies data bidirectionally between src and dst until both -// sides are done or the context is canceled. When idleTimeout is -// non-zero, each direction's read is deadline-guarded; if no data -// flows within the timeout the connection is torn down. When one -// direction finishes, it half-closes the write side of the -// destination (if supported) to signal EOF, allowing the other -// direction to drain gracefully before the full connection teardown. 
-func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - go func() { - <-ctx.Done() - _ = src.Close() - _ = dst.Close() - }() - - var wg sync.WaitGroup - wg.Add(2) - - var errSrcToDst, errDstToSrc error - - go func() { - defer wg.Done() - srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout) - halfClose(dst) - cancel() - }() - - go func() { - defer wg.Done() - dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout) - halfClose(src) - cancel() - }() - - wg.Wait() - - if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) { - logger.Debug("relay closed due to idle timeout") - } - if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) { - logger.Debugf("relay copy error (src→dst): %v", errSrcToDst) - } - if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) { - logger.Debugf("relay copy error (dst→src): %v", errDstToSrc) - } - - return srcToDst, dstToSrc -} - -// copyWithIdleTimeout copies from src to dst using a pooled buffer. -// When idleTimeout > 0 it sets a read deadline on src before each -// read and treats a timeout as an idle-triggered close. 
-func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) { - bufp := copyBufPool.Get().(*[]byte) - defer copyBufPool.Put(bufp) - - if idleTimeout <= 0 { - return io.CopyBuffer(dst, src, *bufp) - } - - conn, ok := src.(net.Conn) - if !ok { - return io.CopyBuffer(dst, src, *bufp) - } - - buf := *bufp - var total int64 - for { - if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil { - return total, err - } - nr, readErr := src.Read(buf) - if nr > 0 { - n, err := checkedWrite(dst, buf[:nr]) - total += n - if err != nil { - return total, err - } - } - if readErr != nil { - if netutil.IsTimeout(readErr) { - return total, errIdleTimeout - } - return total, readErr - } - } -} - -// checkedWrite writes buf to dst and returns the number of bytes written. -// It guards against short writes and negative counts per io.Copy convention. -func checkedWrite(dst io.Writer, buf []byte) (int64, error) { - nw, err := dst.Write(buf) - if nw < 0 || nw > len(buf) { - nw = 0 - } - if err != nil { - return int64(nw), err - } - if nw != len(buf) { - return int64(nw), io.ErrShortWrite - } - return int64(nw), nil -} - -func isExpectedCopyError(err error) bool { - return errors.Is(err, errIdleTimeout) || netutil.IsExpectedError(err) -} - -// halfClose attempts to half-close the write side of the connection. -// If the connection does not support half-close, this is a no-op. -func halfClose(conn net.Conn) { - if hc, ok := conn.(halfCloser); ok { - // Best-effort; the full close will follow shortly. 
- _ = hc.CloseWrite() - } -} diff --git a/proxy/internal/tcp/relay_test.go b/proxy/internal/tcp/relay_test.go index e42d65b9d..f83a0d155 100644 --- a/proxy/internal/tcp/relay_test.go +++ b/proxy/internal/tcp/relay_test.go @@ -13,8 +13,13 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/proxy/internal/netutil" + "github.com/netbirdio/netbird/util/netrelay" ) +func testRelay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (int64, int64) { + return netrelay.Relay(ctx, src, dst, netrelay.Options{IdleTimeout: idleTimeout, Logger: logger}) +} + func TestRelay_BidirectionalCopy(t *testing.T) { srcClient, srcServer := net.Pipe() dstClient, dstServer := net.Pipe() @@ -41,7 +46,7 @@ func TestRelay_BidirectionalCopy(t *testing.T) { srcClient.Close() }() - s2d, d2s := Relay(ctx, logger, srcServer, dstServer, 0) + s2d, d2s := testRelay(ctx, logger, srcServer, dstServer, 0) assert.Equal(t, int64(len(srcData)), s2d, "bytes src→dst") assert.Equal(t, int64(len(dstData)), d2s, "bytes dst→src") @@ -58,7 +63,7 @@ func TestRelay_ContextCancellation(t *testing.T) { done := make(chan struct{}) go func() { - Relay(ctx, logger, srcServer, dstServer, 0) + testRelay(ctx, logger, srcServer, dstServer, 0) close(done) }() @@ -85,7 +90,7 @@ func TestRelay_OneSideClosed(t *testing.T) { done := make(chan struct{}) go func() { - Relay(ctx, logger, srcServer, dstServer, 0) + testRelay(ctx, logger, srcServer, dstServer, 0) close(done) }() @@ -129,7 +134,7 @@ func TestRelay_LargeTransfer(t *testing.T) { dstClient.Close() }() - s2d, _ := Relay(ctx, logger, srcServer, dstServer, 0) + s2d, _ := testRelay(ctx, logger, srcServer, dstServer, 0) assert.Equal(t, int64(len(data)), s2d, "should transfer all bytes") require.NoError(t, <-errCh) } @@ -182,7 +187,7 @@ func TestRelay_IdleTimeout(t *testing.T) { done := make(chan struct{}) var s2d, d2s int64 go func() { - s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond) + 
s2d, d2s = testRelay(ctx, logger, srcServer, dstServer, 200*time.Millisecond) close(done) }() diff --git a/proxy/internal/tcp/router.go b/proxy/internal/tcp/router.go index 9f8660aeb..05beb658b 100644 --- a/proxy/internal/tcp/router.go +++ b/proxy/internal/tcp/router.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/proxy/internal/accesslog" "github.com/netbirdio/netbird/proxy/internal/restrict" "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/util/netrelay" ) // defaultDialTimeout is the fallback dial timeout when no per-route @@ -528,11 +529,14 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route idleTimeout := route.SessionIdleTimeout if idleTimeout <= 0 { - idleTimeout = DefaultIdleTimeout + idleTimeout = netrelay.DefaultIdleTimeout } start := time.Now() - s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout) + s2d, d2s := netrelay.Relay(svcCtx, conn, backend, netrelay.Options{ + IdleTimeout: idleTimeout, + Logger: entry, + }) elapsed := time.Since(start) if obs != nil { diff --git a/util/netrelay/relay.go b/util/netrelay/relay.go new file mode 100644 index 000000000..de44d5bcd --- /dev/null +++ b/util/netrelay/relay.go @@ -0,0 +1,238 @@ +// Package netrelay provides a bidirectional byte-copy helper for TCP-like +// connections with correct half-close propagation. +// +// When one direction reads EOF, the write side of the opposite connection is +// half-closed (CloseWrite) so the peer sees FIN, then the second direction is +// allowed to drain to its own EOF before both connections are fully closed. +// This preserves TCP half-close semantics (e.g. shutdown(SHUT_WR)) that the +// naive "cancel-both-on-first-EOF" pattern breaks. +package netrelay + +import ( + "context" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "syscall" + "time" +) + +// DebugLogger is the minimal interface netrelay uses to surface teardown +// errors. 
Both *logrus.Entry and *nblog.Logger (via its Debugf method) +// satisfy it, so callers can pass whichever they already use without an +// adapter. Debugf is the only required method; callers with richer +// loggers just expose this one shape here. +type DebugLogger interface { + Debugf(format string, args ...any) +} + +// DefaultIdleTimeout is a reasonable default for Options.IdleTimeout. Callers +// that want an idle timeout but have no specific preference can use this. +const DefaultIdleTimeout = 5 * time.Minute + +// halfCloser is implemented by connections that support half-close +// (e.g. *net.TCPConn, *gonet.TCPConn). +type halfCloser interface { + CloseWrite() error +} + +var copyBufPool = sync.Pool{ + New: func() any { + buf := make([]byte, 32*1024) + return &buf + }, +} + +// Options configures Relay behavior. The zero value is valid: no idle timeout, +// no logging. +type Options struct { + // IdleTimeout tears down the session if no bytes flow in either + // direction within this window. It is a connection-wide watchdog, so a + // long unidirectional transfer on one side keeps the other side alive. + // Zero disables idle tracking. + IdleTimeout time.Duration + // Logger receives debug-level copy/idle errors. Nil suppresses logging. + // Any logger with Debug/Debugf methods is accepted (logrus.Entry, + // uspfilter's nblog.Logger, etc.). + Logger DebugLogger +} + +// Relay copies bytes in both directions between a and b until both directions +// EOF or ctx is canceled. On each direction's EOF it half-closes the +// opposite conn's write side (best effort) so the peer sees FIN while the +// other direction drains. Both conns are fully closed when Relay returns. +// +// a and b only need to implement io.ReadWriteCloser; connections that also +// implement CloseWrite (e.g. *net.TCPConn, ssh.Channel) get proper half-close +// propagation. Options.IdleTimeout, when set, is enforced by a connection-wide +// watchdog that tracks reads in either direction. 
+// +// Return values are byte counts: aToB (a.Read → b.Write) and bToA (b.Read → +// a.Write). Errors are logged via Options.Logger when set; they are not +// returned because a relay always terminates on some kind of EOF/cancel. +func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bToA int64) { + ctx, cancel := context.WithCancel(ctx) + closeDone := make(chan struct{}) + defer func() { + cancel() + <-closeDone + }() + + go func() { + <-ctx.Done() + _ = a.Close() + _ = b.Close() + close(closeDone) + }() + + // Both sides must support CloseWrite to propagate half-close. If either + // doesn't, a direction's EOF can't be signaled to the peer and the other + // direction would block forever waiting for data; in that case we fall + // back to the cancel-both-on-first-EOF behavior. + _, aHC := a.(halfCloser) + _, bHC := b.(halfCloser) + halfCloseSupported := aHC && bHC + + var ( + lastActivity atomic.Int64 + idleHit atomic.Bool + ) + lastActivity.Store(time.Now().UnixNano()) + + if opts.IdleTimeout > 0 { + go watchdog(ctx, cancel, &lastActivity, &idleHit, opts.IdleTimeout) + } + + var wg sync.WaitGroup + wg.Add(2) + + var errAToB, errBToA error + + go func() { + defer wg.Done() + aToB, errAToB = copyTracked(b, a, &lastActivity) + if halfCloseSupported && isCleanEOF(errAToB) { + halfClose(b) + } else { + cancel() + } + }() + + go func() { + defer wg.Done() + bToA, errBToA = copyTracked(a, b, &lastActivity) + if halfCloseSupported && isCleanEOF(errBToA) { + halfClose(a) + } else { + cancel() + } + }() + + wg.Wait() + + if opts.Logger != nil { + if idleHit.Load() { + opts.Logger.Debugf("relay closed due to idle timeout") + } + if errAToB != nil && !isExpectedCopyError(errAToB) { + opts.Logger.Debugf("relay copy error (a→b): %v", errAToB) + } + if errBToA != nil && !isExpectedCopyError(errBToA) { + opts.Logger.Debugf("relay copy error (b→a): %v", errBToA) + } + } + + return aToB, bToA +} + +// watchdog enforces a connection-wide idle timeout. 
It cancels ctx when no +// activity has been seen on either direction for idle. It exits as soon as +// ctx is canceled so it doesn't outlive the relay. +func watchdog(ctx context.Context, cancel context.CancelFunc, lastActivity *atomic.Int64, idleHit *atomic.Bool, idle time.Duration) { + // Cap the tick at 50ms so detection latency stays bounded regardless of + // how large idle is, and fall back to idle/2 when that is smaller so + // very short timeouts (mainly in tests) are still caught promptly. + tick := min(idle/2, 50*time.Millisecond) + if tick <= 0 { + tick = time.Millisecond + } + t := time.NewTicker(tick) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + last := time.Unix(0, lastActivity.Load()) + if time.Since(last) >= idle { + idleHit.Store(true) + cancel() + return + } + } + } +} + +// copyTracked copies from src to dst using a pooled buffer, updating +// lastActivity on every successful read so a shared watchdog can enforce a +// connection-wide idle timeout. +func copyTracked(dst io.Writer, src io.Reader, lastActivity *atomic.Int64) (int64, error) { + bufp := copyBufPool.Get().(*[]byte) + defer copyBufPool.Put(bufp) + + buf := *bufp + var total int64 + for { + nr, readErr := src.Read(buf) + if nr > 0 { + lastActivity.Store(time.Now().UnixNano()) + n, werr := checkedWrite(dst, buf[:nr]) + total += n + if werr != nil { + return total, werr + } + } + if readErr != nil { + return total, readErr + } + } +} + +func checkedWrite(dst io.Writer, buf []byte) (int64, error) { + nw, err := dst.Write(buf) + if nw < 0 || nw > len(buf) { + nw = 0 + } + if err != nil { + return int64(nw), err + } + if nw != len(buf) { + return int64(nw), io.ErrShortWrite + } + return int64(nw), nil +} + +func halfClose(conn io.ReadWriteCloser) { + if hc, ok := conn.(halfCloser); ok { + _ = hc.CloseWrite() + } +} + +// isCleanEOF reports whether a copy terminated on a graceful end-of-stream. 
+// Only in that case is it correct to propagate the EOF via CloseWrite on the +// peer; any other error means the flow is broken and both directions should +// tear down. +func isCleanEOF(err error) bool { + return err == nil || errors.Is(err, io.EOF) +} + +func isExpectedCopyError(err error) bool { + return errors.Is(err, net.ErrClosed) || + errors.Is(err, context.Canceled) || + errors.Is(err, io.EOF) || + errors.Is(err, syscall.ECONNRESET) || + errors.Is(err, syscall.EPIPE) || + errors.Is(err, syscall.ECONNABORTED) +} diff --git a/util/netrelay/relay_test.go b/util/netrelay/relay_test.go new file mode 100644 index 000000000..0cb86eb0d --- /dev/null +++ b/util/netrelay/relay_test.go @@ -0,0 +1,221 @@ +package netrelay + +import ( + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// tcpPair returns two connected loopback TCP conns. +func tcpPair(t *testing.T) (*net.TCPConn, *net.TCPConn) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + type result struct { + c *net.TCPConn + err error + } + ch := make(chan result, 1) + go func() { + c, err := ln.Accept() + if err != nil { + ch <- result{nil, err} + return + } + ch <- result{c.(*net.TCPConn), nil} + }() + + dial, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + + r := <-ch + require.NoError(t, r.err) + return dial.(*net.TCPConn), r.c +} + +// TestRelayHalfClose exercises the shutdown(SHUT_WR) scenario that the naive +// cancel-both-on-first-EOF pattern breaks. Client A shuts down its write +// side; B must still be able to write a full response and A must receive +// all of it before its read returns EOF. +func TestRelayHalfClose(t *testing.T) { + // Real peer pairs for each side of the relay. We relay between relayA + // and relayB. Peer A talks through relayA; peer B talks through relayB. 
+ peerA, relayA := tcpPair(t) + relayB, peerB := tcpPair(t) + + defer peerA.Close() + defer peerB.Close() + + // Bound blocking reads/writes so a broken relay fails the test instead of + // hanging the test process. + deadline := time.Now().Add(5 * time.Second) + require.NoError(t, peerA.SetDeadline(deadline)) + require.NoError(t, peerB.SetDeadline(deadline)) + + ctx := t.Context() + + done := make(chan struct{}) + go func() { + Relay(ctx, relayA, relayB, Options{}) + close(done) + }() + + // Peer A sends a request, then half-closes its write side. + req := []byte("request-payload") + _, err := peerA.Write(req) + require.NoError(t, err) + require.NoError(t, peerA.CloseWrite()) + + // Peer B reads the request to EOF (FIN must have propagated). + got, err := io.ReadAll(peerB) + require.NoError(t, err) + require.Equal(t, req, got) + + // Peer B writes its response; peer A must receive all of it even though + // peer A's write side is already closed. + resp := make([]byte, 64*1024) + for i := range resp { + resp[i] = byte(i) + } + _, err = peerB.Write(resp) + require.NoError(t, err) + require.NoError(t, peerB.Close()) + + gotResp, err := io.ReadAll(peerA) + require.NoError(t, err) + require.Equal(t, resp, gotResp) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("relay did not return") + } +} + +// TestRelayFullDuplex verifies bidirectional copy in the simple case. +func TestRelayFullDuplex(t *testing.T) { + peerA, relayA := tcpPair(t) + relayB, peerB := tcpPair(t) + defer peerA.Close() + defer peerB.Close() + + // Bound blocking reads/writes so a broken relay fails the test instead of + // hanging the test process. 
+ deadline := time.Now().Add(5 * time.Second) + require.NoError(t, peerA.SetDeadline(deadline)) + require.NoError(t, peerB.SetDeadline(deadline)) + + ctx := t.Context() + + done := make(chan struct{}) + go func() { + Relay(ctx, relayA, relayB, Options{}) + close(done) + }() + + type result struct { + got []byte + err error + } + resA := make(chan result, 1) + resB := make(chan result, 1) + + msgAB := []byte("hello-from-a") + msgBA := []byte("hello-from-b") + + go func() { + if _, err := peerA.Write(msgAB); err != nil { + resA <- result{err: err} + return + } + buf := make([]byte, len(msgBA)) + _, err := io.ReadFull(peerA, buf) + resA <- result{got: buf, err: err} + _ = peerA.Close() + }() + + go func() { + if _, err := peerB.Write(msgBA); err != nil { + resB <- result{err: err} + return + } + buf := make([]byte, len(msgAB)) + _, err := io.ReadFull(peerB, buf) + resB <- result{got: buf, err: err} + _ = peerB.Close() + }() + + a, b := <-resA, <-resB + require.NoError(t, a.err) + require.Equal(t, msgBA, a.got) + require.NoError(t, b.err) + require.Equal(t, msgAB, b.got) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("relay did not return") + } +} + +// TestRelayNoHalfCloseFallback ensures Relay terminates when the underlying +// conns don't support CloseWrite (e.g. net.Pipe). Without the fallback to +// cancel-both-on-first-EOF, the second direction would block forever. +func TestRelayNoHalfCloseFallback(t *testing.T) { + a1, a2 := net.Pipe() + b1, b2 := net.Pipe() + defer a1.Close() + defer b1.Close() + + ctx := t.Context() + done := make(chan struct{}) + go func() { + Relay(ctx, a2, b2, Options{}) + close(done) + }() + + // Close peer A's side; a2's Read will return EOF. + require.NoError(t, a1.Close()) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("relay did not terminate when half-close is unsupported") + } +} + +// TestRelayIdleTimeout ensures the idle watchdog tears down a silent flow. 
+func TestRelayIdleTimeout(t *testing.T) { + peerA, relayA := tcpPair(t) + relayB, peerB := tcpPair(t) + defer peerA.Close() + defer peerB.Close() + + ctx := t.Context() + + const idle = 150 * time.Millisecond + + start := time.Now() + done := make(chan struct{}) + go func() { + Relay(ctx, relayA, relayB, Options{IdleTimeout: idle}) + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("relay did not close on idle") + } + + elapsed := time.Since(start) + require.GreaterOrEqual(t, elapsed, idle, + "relay must not close before the idle timeout elapses") + require.Less(t, elapsed, idle+500*time.Millisecond, + "relay should close shortly after the idle timeout") +} From 6b08e89c7bb318f51c24b467a34fe05e13e17fcf Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 11 May 2026 16:59:33 +0900 Subject: [PATCH 14/27] [relay] Preserve non-standard port in WS dialer URL prep (#6061) --- shared/relay/client/dialer/ws/ws.go | 32 ++++++---- shared/relay/client/dialer/ws/ws_test.go | 76 ++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 12 deletions(-) create mode 100644 shared/relay/client/dialer/ws/ws_test.go diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 301486514..8a13ba126 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -9,7 +9,6 @@ import ( "net" "net/http" "net/url" - "strings" "github.com/coder/websocket" log "github.com/sirupsen/logrus" @@ -35,13 +34,7 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, var underlying net.Conn opts := createDialOptions(serverName, &underlying) - parsedURL, err := url.Parse(wsURL) - if err != nil { - return nil, err - } - parsedURL.Path = relay.WebSocketURLPath - - wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts) + wsConn, resp, err := websocket.Dial(ctx, wsURL, opts) if err != nil { if errors.Is(err, 
context.Canceled) { return nil, err @@ -57,12 +50,27 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, return conn, nil } +// prepareURL rewrites a rel://host[:port] or rels://host[:port] address into a +// ws://host[:port]/relay or wss://host[:port]/relay URL, preserving any +// non-standard port from the input. func prepareURL(address string) (string, error) { - if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") { - return "", fmt.Errorf("unsupported scheme: %s", address) + parsed, err := url.Parse(address) + if err != nil { + return "", fmt.Errorf("parse relay address %q: %w", address, err) } - - return strings.Replace(address, "rel", "ws", 1), nil + switch parsed.Scheme { + case "rel": + parsed.Scheme = "ws" + case "rels": + parsed.Scheme = "wss" + default: + return "", fmt.Errorf("unsupported scheme: %s", parsed.Scheme) + } + if parsed.Host == "" { + return "", fmt.Errorf("missing host in relay address %q", address) + } + parsed.Path = relay.WebSocketURLPath + return parsed.String(), nil } // httpClientNbDialer builds the http client used by the websocket library. 
diff --git a/shared/relay/client/dialer/ws/ws_test.go b/shared/relay/client/dialer/ws/ws_test.go new file mode 100644 index 000000000..7357adbc0 --- /dev/null +++ b/shared/relay/client/dialer/ws/ws_test.go @@ -0,0 +1,76 @@ +package ws + +import ( + "testing" +) + +func TestPrepareURL(t *testing.T) { + tests := []struct { + name string + input string + want string + wantErr bool + }{ + { + name: "rel scheme with non-standard port", + input: "rel://test-domain-2:45678", + want: "ws://test-domain-2:45678/relay", + }, + { + name: "rels scheme with non-standard port", + input: "rels://test-domain-2:45678", + want: "wss://test-domain-2:45678/relay", + }, + { + name: "rel scheme without port", + input: "rel://test-domain-2", + want: "ws://test-domain-2/relay", + }, + { + name: "rels scheme without port", + input: "rels://test-domain-2", + want: "wss://test-domain-2/relay", + }, + { + name: "rel scheme with IP and port", + input: "rel://1.2.3.4:45678", + want: "ws://1.2.3.4:45678/relay", + }, + { + name: "rel scheme with hostname starting with rel", + input: "rel://relay.example.com:45678", + want: "ws://relay.example.com:45678/relay", + }, + { + name: "rel scheme with IPv6 and port", + input: "rel://[2001:db8::1]:45678", + want: "ws://[2001:db8::1]:45678/relay", + }, + { + name: "rels scheme with IPv6 loopback and port", + input: "rels://[::1]:45678", + want: "wss://[::1]:45678/relay", + }, + { + name: "unsupported scheme", + input: "http://test-domain-2:45678", + wantErr: true, + }, + { + name: "no scheme", + input: "test-domain-2:45678", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := prepareURL(tt.input) + if (err != nil) != tt.wantErr { + t.Fatalf("prepareURL(%q) err = %v, wantErr %v", tt.input, err, tt.wantErr) + } + if got != tt.want { + t.Errorf("prepareURL(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} From a4114a5e453bbbe287610ed0510de79ea057905d Mon Sep 17 00:00:00 2001 From: Viktor Liu 
<17948409+lixmal@users.noreply.github.com> Date: Mon, 11 May 2026 17:00:23 +0900 Subject: [PATCH 15/27] [client] Skip DNS upstream failover on definitive EDE (#6089) --- .github/workflows/golangci-lint.yml | 2 +- client/internal/dns/upstream.go | 99 +++++++++++++++++++- client/internal/dns/upstream_test.go | 129 +++++++++++++++++++++++++++ go.mod | 2 +- go.sum | 4 +- util/capture/text.go | 33 ++++++- 6 files changed, 261 insertions(+), 8 deletions(-) diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 62dfe9bce..7b7b32ec0 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA + ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals skip: go.mod,go.sum,**/proxy/web/** golangci: strategy: diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index a26536f6e..39064f26c 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -30,6 +30,27 @@ import ( var currentMTU uint16 = iface.DefaultMTU +// nonRetryableEDECodes lists EDE info codes (RFC 8914) for which a SERVFAIL +// from one upstream means another upstream would return the same answer: +// DNSSEC validation outcomes and policy-based blocks. Transient errors +// (network, cached, not ready) are not included. 
+var nonRetryableEDECodes = map[uint16]struct{}{ + dns.ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: {}, + dns.ExtendedErrorCodeUnsupportedDSDigestType: {}, + dns.ExtendedErrorCodeDNSSECIndeterminate: {}, + dns.ExtendedErrorCodeDNSBogus: {}, + dns.ExtendedErrorCodeSignatureExpired: {}, + dns.ExtendedErrorCodeSignatureNotYetValid: {}, + dns.ExtendedErrorCodeDNSKEYMissing: {}, + dns.ExtendedErrorCodeRRSIGsMissing: {}, + dns.ExtendedErrorCodeNoZoneKeyBitSet: {}, + dns.ExtendedErrorCodeNSECMissing: {}, + dns.ExtendedErrorCodeBlocked: {}, + dns.ExtendedErrorCodeCensored: {}, + dns.ExtendedErrorCodeFiltered: {}, + dns.ExtendedErrorCodeProhibited: {}, +} + // privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate. type privateClientIface interface { Name() string @@ -250,6 +271,18 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re var t time.Duration var err error + // Advertise EDNS0 so the upstream may include Extended DNS Errors + // (RFC 8914) in failure responses; we use those to short-circuit + // failover for definitive answers like DNSSEC validation failures. + // Operate on a copy so the inbound request is unchanged: a client that + // did not advertise EDNS0 must not see an OPT in the response. 
+ hadEdns := r.IsEdns0() != nil + reqUp := r + if !hadEdns { + reqUp = r.Copy() + reqUp.SetEdns0(upstreamUDPSize(), false) + } + var startTime time.Time var upstreamProto *upstreamProtocolResult func() { @@ -257,7 +290,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re defer cancel() ctx, upstreamProto = contextWithupstreamProtocolResult(ctx) startTime = time.Now() - rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) + rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp) }() if err != nil { @@ -269,13 +302,49 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re } if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { + if code, ok := nonRetryableEDE(rm); ok { + resutil.SetMeta(w, "ede", edeName(code)) + if !hadEdns { + stripOPT(rm) + } + u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) + return nil + } return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]} } + if !hadEdns { + stripOPT(rm) + } u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) return nil } +// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams, +// derived from the tunnel MTU and bounded against underflow. +func upstreamUDPSize() uint16 { + if currentMTU > ipUDPHeaderSize { + return currentMTU - ipUDPHeaderSize + } + return dns.MinMsgSize +} + +// stripOPT removes any OPT pseudo-RRs from the response's Extra section so +// the response complies with RFC 6891 when the client did not advertise EDNS0. 
+func stripOPT(rm *dns.Msg) { + if len(rm.Extra) == 0 { + return + } + out := rm.Extra[:0] + for _, rr := range rm.Extra { + if _, ok := rr.(*dns.OPT); ok { + continue + } + out = append(out, rr) + } + rm.Extra = out +} + func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure { if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { return &upstreamFailure{upstream: upstream, reason: err.Error()} @@ -337,6 +406,34 @@ func formatFailures(failures []upstreamFailure) string { return strings.Join(parts, ", ") } +// nonRetryableEDE returns the first non-retryable EDE code carried in the +// response, if any. +func nonRetryableEDE(rm *dns.Msg) (uint16, bool) { + opt := rm.IsEdns0() + if opt == nil { + return 0, false + } + for _, o := range opt.Option { + ede, ok := o.(*dns.EDNS0_EDE) + if !ok { + continue + } + if _, ok := nonRetryableEDECodes[ede.InfoCode]; ok { + return ede.InfoCode, true + } + } + return 0, false +} + +// edeName returns a human-readable name for an EDE code, falling back to +// the numeric code when unknown. 
+func edeName(code uint16) string { + if name, ok := dns.ExtendedErrorCodeToString[code]; ok { + return name + } + return fmt.Sprintf("EDE %d", code) +} + // ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) { diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 1797fdad8..d6aec05ca 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -770,3 +770,132 @@ func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) { assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records") assert.True(t, rm2.Truncated, "response should be truncated for small buffer client") } + +func msgWithEDE(rcode int, codes ...uint16) *dns.Msg { + m := new(dns.Msg) + m.Response = true + m.Rcode = rcode + if len(codes) == 0 { + return m + } + opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}} + opt.SetUDPSize(dns.MinMsgSize) + for _, c := range codes { + opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: c}) + } + m.Extra = append(m.Extra, opt) + return m +} + +func TestNonRetryableEDE(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + wantOK bool + wantCode uint16 + }{ + {name: "no edns0", msg: msgWithEDE(dns.RcodeServerFailure)}, + { + name: "opt without ede", + msg: func() *dns.Msg { + m := msgWithEDE(dns.RcodeServerFailure) + opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}} + opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID}) + m.Extra = []dns.RR{opt} + return m + }(), + }, + {name: "ede dnsbogus", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus), wantOK: true, wantCode: dns.ExtendedErrorCodeDNSBogus}, + {name: "ede signature expired", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeSignatureExpired), wantOK: true, wantCode: 
dns.ExtendedErrorCodeSignatureExpired}, + {name: "ede blocked", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeBlocked), wantOK: true, wantCode: dns.ExtendedErrorCodeBlocked}, + {name: "ede prohibited", msg: msgWithEDE(dns.RcodeRefused, dns.ExtendedErrorCodeProhibited), wantOK: true, wantCode: dns.ExtendedErrorCodeProhibited}, + {name: "ede cached error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeCachedError)}, + {name: "ede network error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError)}, + {name: "ede not ready retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNotReady)}, + { + name: "first non-retryable wins", + msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError, dns.ExtendedErrorCodeDNSBogus), + wantOK: true, + wantCode: dns.ExtendedErrorCodeDNSBogus, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + code, ok := nonRetryableEDE(tc.msg) + assert.Equal(t, tc.wantOK, ok, "ok should match") + if tc.wantOK { + assert.Equal(t, tc.wantCode, code, "code should match") + } + }) + } +} + +func TestEDEName(t *testing.T) { + assert.Equal(t, "DNSSEC Bogus", edeName(dns.ExtendedErrorCodeDNSBogus)) + assert.Equal(t, "Signature Expired", edeName(dns.ExtendedErrorCodeSignatureExpired)) + assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric") +} + +func TestStripOPT(t *testing.T) { + rm := &dns.Msg{ + Extra: []dns.RR{ + &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}, + &dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)}, + }, + } + stripOPT(rm) + assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept") + _, isOPT := rm.Extra[0].(*dns.OPT) + assert.False(t, isOPT, "remaining record must not be OPT") +} + +func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) { + upstream1 := netip.MustParseAddrPort("192.0.2.1:53") + upstream2 := 
netip.MustParseAddrPort("192.0.2.2:53") + + servfailWithEDE := msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus) + successResp := buildMockResponse(dns.RcodeSuccess, "192.0.2.100") + + var queried []string + tracking := &trackingMockClient{ + inner: &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + upstream1.String(): {msg: servfailWithEDE}, + upstream2.String(): {msg: successResp}, + }, + rtt: time.Millisecond, + }, + queriedUpstreams: &queried, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: tracking, + upstreamServers: []netip.AddrPort{upstream1, upstream2}, + upstreamTimeout: UpstreamTimeout, + } + + var written *dns.Msg + w := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + written = m + return nil + }, + } + + // Client query without EDNS0 must not see an OPT in the response. + q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + resolver.ServeDNS(w, q) + + require.NotNil(t, written, "response must be written") + assert.Equal(t, dns.RcodeServerFailure, written.Rcode, "SERVFAIL must propagate") + assert.Len(t, queried, 1, "only first upstream should be queried") + assert.Equal(t, upstream1.String(), queried[0]) + for _, rr := range written.Extra { + _, isOPT := rr.(*dns.OPT) + assert.False(t, isOPT, "synthetic OPT must not leak to a non-EDNS0 client") + } +} diff --git a/go.mod b/go.mod index 84aeab941..5704887ce 100644 --- a/go.mod +++ b/go.mod @@ -72,7 +72,7 @@ require ( github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/mdlayher/socket v0.5.1 github.com/mdp/qrterminal/v3 v3.2.1 - github.com/miekg/dns v1.1.59 + github.com/miekg/dns v1.1.72 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 diff 
--git a/go.sum b/go.sum index 851d1ce66..42652169c 100644 --- a/go.sum +++ b/go.sum @@ -455,8 +455,8 @@ github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFe github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= -github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= -github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= +github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= +github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= diff --git a/util/capture/text.go b/util/capture/text.go index fbb26654e..a6a6dd28b 100644 --- a/util/capture/text.go +++ b/util/capture/text.go @@ -12,6 +12,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/miekg/dns" ) // TextWriter writes human-readable one-line-per-packet summaries. 
@@ -594,19 +595,45 @@ func formatDNSResponse(d *layers.DNS, rd string, plen int) string { anCount := d.ANCount nsCount := d.NSCount arCount := d.ARCount + ede := formatEDE(d) if d.ResponseCode != layers.DNSResponseCodeNoErr { - return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen) + return fmt.Sprintf("%04x %d/%d/%d %s%s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, ede, plen) } if anCount > 0 && len(d.Answers) > 0 { rr := d.Answers[0] if rdata := shortRData(&rr); rdata != "" { - return fmt.Sprintf("%04x %d/%d/%d %s %s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, plen) + return fmt.Sprintf("%04x %d/%d/%d %s %s%s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, ede, plen) } } - return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen) + return fmt.Sprintf("%04x %d/%d/%d%s (%d)", d.ID, anCount, nsCount, arCount, ede, plen) +} + +// dnsOPTCodeEDE is the EDNS0 option code for Extended DNS Errors (RFC 8914). +const dnsOPTCodeEDE layers.DNSOptionCode = layers.DNSOptionCode(dns.EDNS0EDE) + +// formatEDE returns " EDE=Name" for the first Extended DNS Error option +// found in the response, or empty string if none is present. 
+func formatEDE(d *layers.DNS) string { + for _, rr := range d.Additionals { + if rr.Type != layers.DNSTypeOPT { + continue + } + for _, opt := range rr.OPT { + if opt.Code != dnsOPTCodeEDE || len(opt.Data) < 2 { + continue + } + info := binary.BigEndian.Uint16(opt.Data[:2]) + name, ok := dns.ExtendedErrorCodeToString[info] + if !ok { + name = fmt.Sprintf("%d", info) + } + return " EDE=" + name + } + } + return "" } func shortRData(rr *layers.DNSResourceRecord) string { From 07cbfdbedec2804f2014c42049f23268b0dc2ec7 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Mon, 11 May 2026 14:31:38 +0200 Subject: [PATCH 16/27] [proxy] feature: bring your own proxy (#5627) --- .../reverseproxy/domain/manager/manager.go | 22 +- .../domain/manager/manager_test.go | 110 +++++ .../modules/reverseproxy/proxy/manager.go | 8 +- .../reverseproxy/proxy/manager/manager.go | 64 ++- .../proxy/manager/manager_test.go | 337 +++++++++++++++ .../reverseproxy/proxy/manager_mock.go | 79 +++- .../modules/reverseproxy/proxy/proxy.go | 12 +- .../reverseproxy/proxytoken/handler.go | 195 +++++++++ .../reverseproxy/proxytoken/handler_test.go | 275 ++++++++++++ .../modules/reverseproxy/service/interface.go | 2 + .../reverseproxy/service/interface_mock.go | 29 ++ .../reverseproxy/service/manager/api.go | 24 + .../reverseproxy/service/manager/manager.go | 20 +- .../service/manager/manager_test.go | 6 +- management/internals/server/boot.go | 2 +- management/internals/shared/grpc/proxy.go | 192 ++++++-- .../shared/grpc/proxy_address_test.go | 29 ++ .../internals/shared/grpc/proxy_auth.go | 3 - .../shared/grpc/proxy_group_access_test.go | 18 + .../internals/shared/grpc/proxy_test.go | 55 +++ .../shared/grpc/validate_session_test.go | 28 +- management/server/account_test.go | 2 +- management/server/http/handler.go | 4 + .../proxy/auth_callback_integration_test.go | 9 + .../testing/testing_tools/channel/channel.go | 4 +- management/server/store/sql_store.go | 121 +++++- 
management/server/store/store.go | 13 +- management/server/store/store_mock.go | 157 ++++++- proxy/management_byop_integration_test.go | 409 ++++++++++++++++++ proxy/management_integration_test.go | 34 +- shared/management/http/api/openapi.yml | 165 +++++++ shared/management/http/api/types.gen.go | 41 ++ 32 files changed, 2352 insertions(+), 117 deletions(-) create mode 100644 management/internals/modules/reverseproxy/domain/manager/manager_test.go create mode 100644 management/internals/modules/reverseproxy/proxy/manager/manager_test.go create mode 100644 management/internals/modules/reverseproxy/proxytoken/handler.go create mode 100644 management/internals/modules/reverseproxy/proxytoken/handler_test.go create mode 100644 management/internals/shared/grpc/proxy_address_test.go create mode 100644 proxy/management_byop_integration_test.go diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 2c4c1372e..ab899e0bf 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -31,6 +31,7 @@ type store interface { type proxyManager interface { GetActiveClusterAddresses(ctx context.Context) ([]string, error) + GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool @@ -71,8 +72,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d var ret []*domain.Domain // Add connected proxy clusters as free domains. - // The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io"). 
- allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + // For BYOP accounts, only their own cluster is returned; otherwise shared clusters. + allowList, err := m.getClusterAllowList(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) return nil, err @@ -126,8 +127,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName return nil, status.NewPermissionDeniedError() } - // Verify the target cluster is in the available clusters - allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + // Verify the target cluster is in the available clusters for this account + allowList, err := m.getClusterAllowList(ctx, accountID) if err != nil { return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err) } @@ -273,7 +274,7 @@ func (m Manager) GetClusterDomains() []string { // For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. // For custom domains, the cluster is determined by checking the registered custom domain's target cluster. 
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) { - allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + allowList, err := m.getClusterAllowList(ctx, accountID) if err != nil { return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err) } @@ -298,6 +299,17 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain) } +func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) { + byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID) + if err != nil { + return nil, fmt.Errorf("get BYOP cluster addresses: %w", err) + } + if len(byopAddresses) > 0 { + return byopAddresses, nil + } + return m.proxyManager.GetActiveClusterAddresses(ctx) +} + func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) { bestCluster := "" bestLen := -1 diff --git a/management/internals/modules/reverseproxy/domain/manager/manager_test.go b/management/internals/modules/reverseproxy/domain/manager/manager_test.go new file mode 100644 index 000000000..fdeb0765f --- /dev/null +++ b/management/internals/modules/reverseproxy/domain/manager/manager_test.go @@ -0,0 +1,110 @@ +package manager + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockProxyManager struct { + getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error) + getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error) +} + +func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { + if m.getActiveClusterAddressesFunc != nil { + return m.getActiveClusterAddressesFunc(ctx) + } + return nil, nil +} + +func (m *mockProxyManager) 
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + if m.getActiveClusterAddressesForAccountFunc != nil { + return m.getActiveClusterAddressesForAccountFunc(ctx, accountID) + } + return nil, nil +} + +func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} + +func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} + +func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool { + return nil +} + +func TestGetClusterAllowList_BYOPProxy(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) { + assert.Equal(t, "acc-123", accID) + return []string{"byop.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist") + return nil, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"byop.example.com"}, result) +} + +func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return nil, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result) +} + +func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, 
error) { + return nil, errors.New("db error") + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails") + return nil, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "BYOP cluster addresses") +} + +func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return []string{"eu.proxy.netbird.io"}, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"eu.proxy.netbird.io"}, result) +} + diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index 53c52b3aa..07ea6f0ab 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -11,15 +11,19 @@ import ( // Manager defines the interface for proxy operations type Manager interface { - Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) + Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) Disconnect(ctx context.Context, proxyID, sessionID string) error Heartbeat(ctx context.Context, p *Proxy) error GetActiveClusterAddresses(ctx context.Context) ([]string, error) - GetActiveClusters(ctx context.Context) ([]Cluster, error) + GetActiveClusterAddressesForAccount(ctx context.Context, 
accountID string) ([]string, error) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool CleanupStale(ctx context.Context, inactivityDuration time.Duration) error + GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) + CountAccountProxies(ctx context.Context, accountID string) (int64, error) + IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) + DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error } // OIDCValidationConfig contains the OIDC configuration needed for token validation. diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index 341e8c943..b72a6ebe5 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -16,11 +16,16 @@ type store interface { DisconnectProxy(ctx context.Context, proxyID, sessionID string) error UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) - GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) + GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error + GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) + CountProxiesByAccountID(ctx 
context.Context, accountID string) (int64, error) + IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) + DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error } // Manager handles all proxy operations @@ -44,7 +49,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) { // Connect registers a new proxy connection in the database. // capabilities may be nil for old proxies that do not report them. -func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) { +func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) { now := time.Now() var caps proxy.Capabilities if capabilities != nil { @@ -55,9 +60,10 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress SessionID: sessionID, ClusterAddress: clusterAddress, IPAddress: ipAddress, + AccountID: accountID, LastSeen: now, ConnectedAt: &now, - Status: "connected", + Status: proxy.StatusConnected, Capabilities: caps, } @@ -77,7 +83,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress } // Disconnect marks a proxy as disconnected in the database. -func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error { +func (m *Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error { if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil { log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err) return err @@ -92,7 +98,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) erro } // Heartbeat updates the proxy's last seen timestamp. 
-func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error { +func (m *Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error { if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil { log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err) return err @@ -104,7 +110,7 @@ func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error { } // GetActiveClusterAddresses returns all unique cluster addresses for active proxies -func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { +func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { addresses, err := m.store.GetActiveProxyClusterAddresses(ctx) if err != nil { log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) @@ -113,16 +119,6 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error return addresses, nil } -// GetActiveClusters returns all active proxy clusters with their connected proxy count. -func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) { - clusters, err := m.store.GetActiveProxyClusters(ctx) - if err != nil { - log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err) - return nil, err - } - return clusters, nil -} - // ClusterSupportsCustomPorts returns whether any active proxy in the cluster // supports custom ports. Returns nil when no proxy has reported capabilities. 
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { @@ -142,10 +138,44 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string } // CleanupStale removes proxies that haven't sent heartbeat in the specified duration -func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { +func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err) return err } return nil } + +func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err) + return nil, err + } + return addresses, nil +} + +func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) { + return m.store.GetProxyByAccountID(ctx, accountID) +} + +func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) { + return m.store.CountProxiesByAccountID(ctx, accountID) +} + +func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) { + conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID) + if err != nil { + return false, err + } + return !conflicting, nil +} + +func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + if err := m.store.DeleteAccountCluster(ctx, clusterAddress, accountID); err != nil { + log.WithContext(ctx).Errorf("failed to delete cluster %s for account %s: %v", clusterAddress, accountID, err) + return err + } + return nil +} + diff --git 
a/management/internals/modules/reverseproxy/proxy/manager/manager_test.go b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go new file mode 100644 index 000000000..3c53fe684 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go @@ -0,0 +1,337 @@ +package manager + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" +) + +type mockStore struct { + saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error + disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error + updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error + getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error) + getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error) + cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error + getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error) + countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error) + isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error) + deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error +} + +func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { + if m.saveProxyFunc != nil { + return m.saveProxyFunc(ctx, p) + } + return nil +} +func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error { + if m.disconnectProxyFunc != nil { + return m.disconnectProxyFunc(ctx, proxyID, sessionID) + } + return nil +} +func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error { + if m.updateProxyHeartbeatFunc != nil { + return 
m.updateProxyHeartbeatFunc(ctx, p) + } + return nil +} +func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { + if m.getActiveProxyClusterAddressesFunc != nil { + return m.getActiveProxyClusterAddressesFunc(ctx) + } + return nil, nil +} +func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + if m.getActiveProxyClusterAddressesForAccFunc != nil { + return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID) + } + return nil, nil +} +func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) { + return nil, nil +} +func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error { + if m.cleanupStaleProxiesFunc != nil { + return m.cleanupStaleProxiesFunc(ctx, d) + } + return nil +} +func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) { + if m.getProxyByAccountIDFunc != nil { + return m.getProxyByAccountIDFunc(ctx, accountID) + } + return nil, fmt.Errorf("proxy not found for account %s", accountID) +} +func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) { + if m.countProxiesByAccountIDFunc != nil { + return m.countProxiesByAccountIDFunc(ctx, accountID) + } + return 0, nil +} +func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) { + if m.isClusterAddressConflictingFunc != nil { + return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID) + } + return false, nil +} +func (m *mockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + if m.deleteAccountClusterFunc != nil { + return m.deleteAccountClusterFunc(ctx, clusterAddress, accountID) + } + return nil +} +func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} +func (m *mockStore) 
GetClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} +func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool { + return nil +} + +func newTestManager(s store) *Manager { + meter := noop.NewMeterProvider().Meter("test") + m, err := NewManager(s, meter) + if err != nil { + panic(err) + } + return m +} + +func TestConnect_WithAccountID(t *testing.T) { + accountID := "acc-123" + + var savedProxy *proxy.Proxy + s := &mockStore{ + saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error { + savedProxy = p + return nil + }, + } + + mgr := newTestManager(s) + _, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", &accountID, nil) + require.NoError(t, err) + + require.NotNil(t, savedProxy) + assert.Equal(t, "proxy-1", savedProxy.ID) + assert.Equal(t, "session-1", savedProxy.SessionID) + assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress) + assert.Equal(t, "10.0.0.1", savedProxy.IPAddress) + assert.Equal(t, &accountID, savedProxy.AccountID) + assert.Equal(t, proxy.StatusConnected, savedProxy.Status) + assert.NotNil(t, savedProxy.ConnectedAt) +} + +func TestConnect_WithoutAccountID(t *testing.T) { + var savedProxy *proxy.Proxy + s := &mockStore{ + saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error { + savedProxy = p + return nil + }, + } + + mgr := newTestManager(s) + _, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil) + require.NoError(t, err) + + require.NotNil(t, savedProxy) + assert.Nil(t, savedProxy.AccountID) + assert.Equal(t, proxy.StatusConnected, savedProxy.Status) +} + +func TestConnect_StoreError(t *testing.T) { + s := &mockStore{ + saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error { + return errors.New("db error") + }, + } + + mgr := newTestManager(s) + _, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", nil, nil) + 
assert.Error(t, err) +} + +func TestIsClusterAddressAvailable(t *testing.T) { + tests := []struct { + name string + conflicting bool + storeErr error + wantResult bool + wantErr bool + }{ + { + name: "available - no conflict", + conflicting: false, + wantResult: true, + }, + { + name: "not available - conflict exists", + conflicting: true, + wantResult: false, + }, + { + name: "store error", + storeErr: errors.New("db error"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &mockStore{ + isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) { + return tt.conflicting, tt.storeErr + }, + } + + mgr := newTestManager(s) + result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123") + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantResult, result) + }) + } +} + +func TestCountAccountProxies(t *testing.T) { + tests := []struct { + name string + count int64 + storeErr error + wantCount int64 + wantErr bool + }{ + { + name: "no proxies", + count: 0, + wantCount: 0, + }, + { + name: "one proxy", + count: 1, + wantCount: 1, + }, + { + name: "store error", + storeErr: errors.New("db error"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &mockStore{ + countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) { + return tt.count, tt.storeErr + }, + } + + mgr := newTestManager(s) + count, err := mgr.CountAccountProxies(context.Background(), "acc-123") + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantCount, count) + }) + } +} + +func TestGetAccountProxy(t *testing.T) { + accountID := "acc-123" + + t.Run("found", func(t *testing.T) { + expected := &proxy.Proxy{ + ID: "proxy-1", + ClusterAddress: "byop.example.com", + AccountID: &accountID, + Status: proxy.StatusConnected, + } 
+ s := &mockStore{ + getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) { + assert.Equal(t, accountID, accID) + return expected, nil + }, + } + + mgr := newTestManager(s) + p, err := mgr.GetAccountProxy(context.Background(), accountID) + require.NoError(t, err) + assert.Equal(t, expected, p) + }) + + t.Run("not found", func(t *testing.T) { + s := &mockStore{ + getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) { + return nil, errors.New("not found") + }, + } + + mgr := newTestManager(s) + _, err := mgr.GetAccountProxy(context.Background(), accountID) + assert.Error(t, err) + }) +} + +func TestDeleteAccountCluster(t *testing.T) { + t.Run("success", func(t *testing.T) { + var deletedCluster, deletedAccount string + s := &mockStore{ + deleteAccountClusterFunc: func(_ context.Context, clusterAddress, accountID string) error { + deletedCluster = clusterAddress + deletedAccount = accountID + return nil + }, + } + + mgr := newTestManager(s) + err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123") + require.NoError(t, err) + assert.Equal(t, "cluster.example.com", deletedCluster) + assert.Equal(t, "acc-123", deletedAccount) + }) + + t.Run("store error", func(t *testing.T) { + s := &mockStore{ + deleteAccountClusterFunc: func(_ context.Context, _, _ string) error { + return errors.New("db error") + }, + } + + mgr := newTestManager(s) + err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123") + assert.Error(t, err) + }) +} + +func TestGetActiveClusterAddressesForAccount(t *testing.T) { + expected := []string{"byop.example.com"} + s := &mockStore{ + getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) { + assert.Equal(t, "acc-123", accID) + return expected, nil + }, + } + + mgr := newTestManager(s) + result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123") + require.NoError(t, 
err) + assert.Equal(t, expected, result) +} diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index 98d97b3c6..a0e360a1b 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -93,18 +93,18 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte } // Connect mocks base method. -func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) { +func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities) + ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities) ret0, _ := ret[0].(*Proxy) ret1, _ := ret[1].(error) return ret0, ret1 } // Connect indicates an expected call of Connect. -func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities) } // Disconnect mocks base method. 
@@ -136,19 +136,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx) } -// GetActiveClusters mocks base method. -func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) { +func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActiveClusters", ctx) - ret0, _ := ret[0].([]Cluster) + ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID) + ret0, _ := ret[0].([]string) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetActiveClusters indicates an expected call of GetActiveClusters. -func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID) } // Heartbeat mocks base method. @@ -165,6 +163,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p) } +// GetAccountProxy mocks base method. +func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID) + ret0, _ := ret[0].(*Proxy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountProxy indicates an expected call of GetAccountProxy. 
+func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID) +} + +// CountAccountProxies mocks base method. +func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAccountProxies indicates an expected call of CountAccountProxies. +func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID) +} + +// IsClusterAddressAvailable mocks base method. +func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable. +func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID) +} + +// DeleteAccountCluster mocks base method. 
+func (m *MockManager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAccountCluster indicates an expected call of DeleteAccountCluster. +func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID) +} + // MockController is a mock of Controller interface. type MockController struct { ctrl *gomock.Controller diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index dcedb8811..64394799e 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -1,6 +1,13 @@ package proxy -import "time" +import ( + "time" +) + +const ( + StatusConnected = "connected" + StatusDisconnected = "disconnected" +) // Capabilities describes what a proxy can handle, as reported via gRPC. // Nil fields mean the proxy never reported this capability. @@ -21,6 +28,7 @@ type Proxy struct { SessionID string `gorm:"type:varchar(36)"` ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"` IPAddress string `gorm:"type:varchar(45)"` + AccountID *string `gorm:"type:varchar(255);index:idx_proxy_account_id"` LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` ConnectedAt *time.Time DisconnectedAt *time.Time @@ -36,6 +44,8 @@ func (Proxy) TableName() string { // Cluster represents a group of proxy nodes serving the same address. 
type Cluster struct { + ID string Address string ConnectedProxies int + SelfHosted bool } diff --git a/management/internals/modules/reverseproxy/proxytoken/handler.go b/management/internals/modules/reverseproxy/proxytoken/handler.go new file mode 100644 index 000000000..728cdf723 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxytoken/handler.go @@ -0,0 +1,195 @@ +package proxytoken + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/gorilla/mux" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +type handler struct { + store store.Store + permissionsManager permissions.Manager +} + +func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) { + h := &handler{store: s, permissionsManager: permissionsManager} + router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS") + router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS") +} + +func (h *handler) createToken(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create) + if err != nil { + 
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + var req api.ProxyTokenRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + if req.Name == "" || len(req.Name) > 255 { + util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w) + return + } + + var expiresIn time.Duration + if req.ExpiresIn != nil { + if *req.ExpiresIn < 0 { + util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w) + return + } + if *req.ExpiresIn > 0 { + expiresIn = time.Duration(*req.ExpiresIn) * time.Second + } + } + + accountID := userAuth.AccountId + generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId) + if err != nil { + util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w) + return + } + + if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil { + util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w) + return + } + + resp := toProxyTokenCreatedResponse(generated) + util.WriteJSONObject(r.Context(), w, resp) +} + +func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read) + if err != nil { + util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + tokens, err := 
h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId) + if err != nil { + util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w) + return + } + + resp := make([]api.ProxyToken, 0, len(tokens)) + for _, token := range tokens { + resp = append(resp, toProxyTokenResponse(token)) + } + + util.WriteJSONObject(r.Context(), w, resp) +} + +func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete) + if err != nil { + util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w) + return + } + if !ok { + util.WriteErrorResponse("permission denied", http.StatusForbidden, w) + return + } + + tokenID := mux.Vars(r)["tokenId"] + if tokenID == "" { + util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w) + return + } + + token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID) + if err != nil { + if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound { + util.WriteErrorResponse("token not found", http.StatusNotFound, w) + } else { + util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w) + } + return + } + + if token.AccountID == nil || *token.AccountID != userAuth.AccountId { + util.WriteErrorResponse("token not found", http.StatusNotFound, w) + return + } + + if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil { + util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken { + resp := 
api.ProxyToken{ + Id: token.ID, + Name: token.Name, + Revoked: token.Revoked, + } + if !token.CreatedAt.IsZero() { + resp.CreatedAt = token.CreatedAt + } + if token.ExpiresAt != nil { + resp.ExpiresAt = token.ExpiresAt + } + if token.LastUsed != nil { + resp.LastUsed = token.LastUsed + } + return resp +} + +func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated { + base := toProxyTokenResponse(&generated.ProxyAccessToken) + plainToken := string(generated.PlainToken) + return api.ProxyTokenCreated{ + Id: base.Id, + Name: base.Name, + CreatedAt: base.CreatedAt, + ExpiresAt: base.ExpiresAt, + LastUsed: base.LastUsed, + Revoked: base.Revoked, + PlainToken: plainToken, + } +} diff --git a/management/internals/modules/reverseproxy/proxytoken/handler_test.go b/management/internals/modules/reverseproxy/proxytoken/handler_test.go new file mode 100644 index 000000000..a28752909 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxytoken/handler_test.go @@ -0,0 +1,275 @@ +package proxytoken + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +func authContext(accountID, userID string) context.Context { + return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{ + AccountId: accountID, + UserId: userID, + }) +} + 
+func TestCreateToken_AccountScoped(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + var savedToken *types.ProxyAccessToken + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, token *types.ProxyAccessToken) error { + savedToken = token + return nil + }, + ) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + body := `{"name": "my-token"}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext(accountID, "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + var resp api.ProxyTokenCreated + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + + assert.NotEmpty(t, resp.PlainToken) + assert.Equal(t, "my-token", resp.Name) + assert.False(t, resp.Revoked) + + require.NotNil(t, savedToken) + require.NotNil(t, savedToken.AccountID) + assert.Equal(t, accountID, *savedToken.AccountID) + assert.Equal(t, "user-1", savedToken.CreatedBy) +} + +func TestCreateToken_WithExpiration(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var savedToken *types.ProxyAccessToken + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, token *types.ProxyAccessToken) error { + savedToken = token + return nil + }, + ) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + body 
:= `{"name": "expiring-token", "expires_in": 3600}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + require.NotNil(t, savedToken) + require.NotNil(t, savedToken.ExpiresAt) + assert.True(t, savedToken.ExpiresAt.After(time.Now())) +} + +func TestCreateToken_EmptyName(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil) + + h := &handler{ + permissionsManager: permsMgr, + } + + body := `{"name": ""}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestCreateToken_PermissionDenied(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil) + + h := &handler{ + permissionsManager: permsMgr, + } + + body := `{"name": "test"}` + req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body)) + req = req.WithContext(authContext("acc-123", "user-1")) + w := httptest.NewRecorder() + + h.createToken(w, req) + assert.Equal(t, http.StatusForbidden, w.Code) +} + +func TestListTokens(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + now := time.Now() + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), 
store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{ + {ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false}, + {ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true}, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil) + req = req.WithContext(authContext(accountID, "user-1")) + w := httptest.NewRecorder() + + h.listTokens(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + var resp []api.ProxyToken + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + require.Len(t, resp, 2) + assert.Equal(t, "tok-1", resp[0].Id) + assert.False(t, resp[0].Revoked) + assert.Equal(t, "tok-2", resp[1].Id) + assert.True(t, resp[1].Revoked) +} + +func TestRevokeToken_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + accountID := "acc-123" + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{ + ID: "tok-1", + Name: "test-token", + AccountID: &accountID, + }, nil) + mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil) + req = req.WithContext(authContext(accountID, "user-1")) + req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"}) + w := httptest.NewRecorder() + + h.revokeToken(w, req) + 
assert.Equal(t, http.StatusOK, w.Code) +} + +func TestRevokeToken_WrongAccount(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + otherAccount := "acc-other" + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{ + ID: "tok-1", + AccountID: &otherAccount, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil) + req = req.WithContext(authContext("acc-123", "user-1")) + req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"}) + w := httptest.NewRecorder() + + h.revokeToken(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestRevokeToken_ManagementWideToken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{ + ID: "tok-1", + AccountID: nil, + }, nil) + + permsMgr := permissions.NewMockManager(ctrl) + permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil) + + h := &handler{ + store: mockStore, + permissionsManager: permsMgr, + } + + req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil) + req = req.WithContext(authContext("acc-123", "user-1")) + req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"}) + w := httptest.NewRecorder() + + h.revokeToken(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} diff --git a/management/internals/modules/reverseproxy/service/interface.go 
b/management/internals/modules/reverseproxy/service/interface.go index a49cbea35..6a94aa32b 100644 --- a/management/internals/modules/reverseproxy/service/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -10,6 +10,7 @@ import ( type Manager interface { GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) + DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) @@ -28,4 +29,5 @@ type Manager interface { RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error StartExposeReaper(ctx context.Context) + GetServiceByDomain(ctx context.Context, domain string) (*Service, error) } diff --git a/management/internals/modules/reverseproxy/service/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go index cc5ccbb8e..83b2162ed 100644 --- a/management/internals/modules/reverseproxy/service/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -79,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID) } +// DeleteAccountCluster mocks base method. 
+func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAccountCluster indicates an expected call of DeleteAccountCluster. +func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress) +} + // DeleteService mocks base method. func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { m.ctrl.T.Helper() @@ -138,6 +152,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID) } +// GetServiceByDomain mocks base method. +func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain) + ret0, _ := ret[0].(*Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceByDomain indicates an expected call of GetServiceByDomain. +func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain) +} + // GetGlobalServices mocks base method. 
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/service/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go index cd81efa88..08272077c 100644 --- a/management/internals/modules/reverseproxy/service/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -35,6 +35,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma accesslogsmanager.RegisterEndpoints(router, accessLogsManager) router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/clusters/{clusterAddress}", h.deleteCluster).Methods("DELETE", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS") router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS") @@ -195,10 +196,33 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) { apiClusters := make([]api.ProxyCluster, 0, len(clusters)) for _, c := range clusters { apiClusters = append(apiClusters, api.ProxyCluster{ + Id: c.ID, Address: c.Address, ConnectedProxies: c.ConnectedProxies, + SelfHosted: c.SelfHosted, }) } util.WriteJSONObject(r.Context(), w, apiClusters) } + +func (h *handler) deleteCluster(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + clusterAddress := mux.Vars(r)["clusterAddress"] + if clusterAddress == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "cluster address is required"), w) + return + } + + if err := h.manager.DeleteAccountCluster(r.Context(), userAuth.AccountId, userAuth.UserId, clusterAddress); err != nil { + 
util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index d03a8dc82..c866d8f75 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -122,7 +122,21 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin return nil, status.NewPermissionDeniedError() } - return m.store.GetActiveProxyClusters(ctx) + return m.store.GetActiveProxyClusters(ctx, accountID) +} + +// DeleteAccountCluster removes all proxy registrations for the given cluster address +// owned by the account. +func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + return m.store.DeleteAccountCluster(ctx, clusterAddress, accountID) } func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { @@ -986,6 +1000,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]* return services, nil } +func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID) if err != nil { diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go 
b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 46e79f1e5..47b8b3865 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { t.Helper() tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t)) pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t)) - srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil) return srv } @@ -714,7 +714,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) require.NoError(t, err) @@ -1138,7 +1138,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) { tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) require.NoError(t, err) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 
f2ab0a2c4..7c655f020 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -193,7 +193,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer { - proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) + proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store()) s.AfterInit(func(s *BaseServer) { proxyService.SetServiceManager(s.ServiceManager()) proxyService.SetProxyController(s.ServiceProxyController()) diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 6763a3ba3..9e5027547 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "errors" "fmt" + "net" "net/http" "net/url" "os" @@ -50,6 +51,11 @@ type ProxyOIDCConfig struct { KeysLocation string } +// ProxyTokenChecker checks whether a proxy access token is still valid. 
+type ProxyTokenChecker interface { + IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) +} + // ProxyServiceServer implements the ProxyService gRPC server type ProxyServiceServer struct { proto.UnimplementedProxyServiceServer @@ -78,6 +84,9 @@ type ProxyServiceServer struct { // Store for one-time authentication tokens tokenStore *OneTimeTokenStore + // Checker for proxy access token validity + tokenChecker ProxyTokenChecker + // OIDC configuration for proxy authentication oidcConfig ProxyOIDCConfig @@ -123,6 +132,8 @@ type proxyConnection struct { proxyID string sessionID string address string + accountID *string + tokenID string capabilities *proto.ProxyCapabilities stream proto.ProxyService_GetMappingUpdateServer sendChan chan *proto.GetMappingUpdateResponse @@ -130,8 +141,19 @@ type proxyConnection struct { cancel context.CancelFunc } +func enforceAccountScope(ctx context.Context, requestAccountID string) error { + token := GetProxyTokenFromContext(ctx) + if token == nil || token.AccountID == nil { + return nil + } + if requestAccountID == "" || *token.AccountID != requestAccountID { + return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID) + } + return nil +} + // NewProxyServiceServer creates a new proxy service server. 
-func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { +func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer { ctx, cancel := context.WithCancel(context.Background()) s := &ProxyServiceServer{ accessLogManager: accessLogMgr, @@ -141,6 +163,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, + tokenChecker: tokenChecker, snapshotBatchSize: snapshotBatchSizeFromEnv(), cancel: cancel, } @@ -200,6 +223,25 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest return status.Errorf(codes.InvalidArgument, "proxy address is invalid") } + var accountID *string + token := GetProxyTokenFromContext(ctx) + if token != nil && token.AccountID != nil { + accountID = token.AccountID + + available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID) + if err != nil { + return status.Errorf(codes.Internal, "check cluster address: %v", err) + } + if !available { + return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress) + } + } + + var tokenID string + if token != nil { + tokenID = token.ID + } + sessionID := uuid.NewString() if old, loaded := s.connectedProxies.Load(proxyID); loaded { @@ -217,6 +259,8 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest proxyID: proxyID, sessionID: sessionID, address: proxyAddress, + accountID: accountID, + tokenID: tokenID, capabilities: req.GetCapabilities(), stream: stream, sendChan: make(chan 
*proto.GetMappingUpdateResponse, 100), @@ -224,7 +268,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest cancel: cancel, } - // Register proxy in database with capabilities var caps *proxy.Capabilities if c := req.GetCapabilities(); c != nil { caps = &proxy.Capabilities{ @@ -233,10 +276,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest SupportsCrowdsec: c.SupportsCrowdsec, } } - proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps) + proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps) if err != nil { - log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) cancel() + if accountID != nil { + return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err) + } + log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) return status.Errorf(codes.Internal, "register proxy in database: %v", err) } @@ -266,6 +312,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest "session_id": sessionID, "address": proxyAddress, "cluster_addr": proxyAddress, + "account_id": accountID, "total_proxies": len(s.GetConnectedProxies()), }).Info("Proxy registered in cluster") defer func() { @@ -286,7 +333,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest log.Infof("Proxy %s session %s disconnected", proxyID, sessionID) }() - go s.heartbeat(connCtx, proxyRecord) + go s.heartbeat(connCtx, conn, proxyRecord) select { case err := <-errChan: @@ -298,8 +345,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } } -// heartbeat updates the proxy's last_seen timestamp every minute -func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) { +// heartbeat updates the proxy's last_seen timestamp every minute and +// disconnects the proxy 
if its access token has been revoked. +func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() @@ -309,6 +357,19 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) { if err := s.proxyManager.Heartbeat(ctx, p); err != nil { log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err) } + + if conn.tokenID != "" && s.tokenChecker != nil { + valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID) + if err != nil { + log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err) + continue + } + if !valid { + log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID) + conn.cancel() + return + } + } case <-ctx.Done(): log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID) return @@ -316,8 +377,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) { } } -// sendSnapshot sends the initial snapshot of services to the connecting proxy. -// Only entries matching the proxy's cluster address are sent. 
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { if !isProxyAddressValid(conn.address) { return fmt.Errorf("proxy address is invalid") @@ -355,7 +414,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec } func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) { - services, err := s.serviceManager.GetGlobalServices(ctx) + var services []*rpservice.Service + var err error + if conn.accountID != nil { + services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID) + } else { + services, err = s.serviceManager.GetGlobalServices(ctx) + } if err != nil { return nil, fmt.Errorf("get services from store: %w", err) } @@ -380,8 +445,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * return mappings, nil } -// isProxyAddressValid validates a proxy address +// isProxyAddressValid validates a proxy address (domain name or IP address) func isProxyAddressValid(addr string) bool { + if addr == "" { + return false + } + if net.ParseIP(addr) != nil { + return true + } _, err := domain.ValidateDomains([]string{addr}) return err == nil } @@ -405,6 +476,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) { accessLog := req.GetLog() + if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil { + return nil, err + } + fields := log.Fields{ "service_id": accessLog.GetServiceId(), "account_id": accessLog.GetAccountId(), @@ -442,11 +517,32 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA // Management should call this when services are created/updated/removed. 
// For create/update operations a unique one-time auth token is generated per // proxy so that every replica can independently authenticate with management. +// BYOP proxies only receive updates for their own account's services. func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) { log.Debugf("Broadcasting service update to all connected proxy servers") + updateAccountIDs := make(map[string]struct{}) + for _, m := range update.Mapping { + if m.AccountId != "" { + updateAccountIDs[m.AccountId] = struct{}{} + } + } s.connectedProxies.Range(func(key, value interface{}) bool { conn := value.(*proxyConnection) - resp := s.perProxyMessage(update, conn.proxyID) + connUpdate := update + if conn.accountID != nil && len(updateAccountIDs) > 0 { + if _, ok := updateAccountIDs[*conn.accountID]; !ok { + return true + } + filtered := filterMappingsForAccount(update.Mapping, *conn.accountID) + if len(filtered) == 0 { + return true + } + connUpdate = &proto.GetMappingUpdateResponse{ + Mapping: filtered, + InitialSyncComplete: update.InitialSyncComplete, + } + } + resp := s.perProxyMessage(connUpdate, conn.proxyID) if resp == nil { log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID) conn.cancel() @@ -463,6 +559,26 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes }) } +// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect. 
+func (s *ProxyServiceServer) ForceDisconnect(proxyID string) { + if connVal, ok := s.connectedProxies.Load(proxyID); ok { + conn := connVal.(*proxyConnection) + conn.cancel() + s.connectedProxies.Delete(proxyID) + log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy") + } +} + +func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping { + var filtered []*proto.ProxyMapping + for _, m := range mappings { + if m.AccountId == accountID { + filtered = append(filtered, m) + } + } + return filtered +} + // GetConnectedProxies returns a list of connected proxy IDs func (s *ProxyServiceServer) GetConnectedProxies() []string { var proxies []string @@ -531,6 +647,9 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd continue } conn := connVal.(*proxyConnection) + if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId { + continue + } if !proxyAcceptsMapping(conn, update) { log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id) continue @@ -618,6 +737,10 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { } func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { log.WithContext(ctx).Debugf("failed to get service from store: %v", err) @@ -737,6 +860,10 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic // SendStatusUpdate handles status updates from proxy clients. 
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + accountID := req.GetAccountId() serviceID := req.GetServiceId() protoStatus := req.GetStatus() @@ -807,6 +934,10 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status { // CreateProxyPeer handles proxy peer creation with one-time token authentication func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + serviceID := req.GetServiceId() accountID := req.GetAccountId() token := req.GetToken() @@ -861,6 +992,10 @@ func strPtr(s string) *string { } func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) { + if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { + return nil, err + } + redirectURL, err := url.Parse(req.GetRedirectUrl()) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err) @@ -989,21 +1124,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL // GenerateSessionToken creates a signed session JWT for the given domain and user. 
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { - // Find the service by domain to get its signing key - services, err := s.serviceManager.GetGlobalServices(ctx) + service, err := s.getServiceByDomain(ctx, domain) if err != nil { - return "", fmt.Errorf("get services: %w", err) - } - - var service *rpservice.Service - for _, svc := range services { - if svc.Domain == domain { - service = svc - break - } - } - if service == nil { - return "", fmt.Errorf("service not found for domain: %s", domain) + return "", fmt.Errorf("service not found for domain %s: %w", domain, err) } if service.SessionPrivateKey == "" { @@ -1101,6 +1224,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val }, nil } + if err := enforceAccountScope(ctx, service.AccountID); err != nil { + return nil, err + } + pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey) if err != nil { log.WithFields(log.Fields{ @@ -1184,18 +1311,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val } func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) { - services, err := s.serviceManager.GetGlobalServices(ctx) - if err != nil { - return nil, fmt.Errorf("get services: %w", err) - } - - for _, service := range services { - if service.Domain == domain { - return service, nil - } - } - - return nil, fmt.Errorf("service not found for domain: %s", domain) + return s.serviceManager.GetServiceByDomain(ctx, domain) } func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error { diff --git a/management/internals/shared/grpc/proxy_address_test.go b/management/internals/shared/grpc/proxy_address_test.go new file mode 100644 index 000000000..824a57226 --- /dev/null +++ b/management/internals/shared/grpc/proxy_address_test.go @@ -0,0 +1,29 @@ +package grpc + 
+import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsProxyAddressValid(t *testing.T) { + tests := []struct { + name string + addr string + valid bool + }{ + {name: "valid domain", addr: "eu.proxy.netbird.io", valid: true}, + {name: "valid subdomain", addr: "byop.proxy.example.com", valid: true}, + {name: "valid IPv4", addr: "10.0.0.1", valid: true}, + {name: "valid IPv4 public", addr: "203.0.113.10", valid: true}, + {name: "valid IPv6", addr: "::1", valid: true}, + {name: "valid IPv6 full", addr: "2001:db8::1", valid: true}, + {name: "empty string", addr: "", valid: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr)) + }) + } +} diff --git a/management/internals/shared/grpc/proxy_auth.go b/management/internals/shared/grpc/proxy_auth.go index dd593dfa0..9888e8eee 100644 --- a/management/internals/shared/grpc/proxy_auth.go +++ b/management/internals/shared/grpc/proxy_auth.go @@ -153,9 +153,6 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types return nil, status.Errorf(codes.Unauthenticated, "invalid token") } - // TODO: Enforce AccountID scope for "bring your own proxy" feature. - // Currently tokens are management-wide; AccountID field is reserved for future use. 
- if !token.IsValid() { return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked") } diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index 0fa9a0dc1..46dad5b56 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -53,6 +53,10 @@ func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, return nil } +func (m *mockReverseProxyManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error { + return nil +} + func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error { return nil } @@ -91,6 +95,20 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {} +func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) { + if m.err != nil { + return nil, m.err + } + for _, services := range m.proxiesByAccount { + for _, svc := range services { + if svc.Domain == domain { + return svc, nil + } + } + } + return nil, errors.New("service not found for domain: " + domain) +} + func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { return nil, nil } diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index 5a7a457df..0379edc6d 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -12,9 +12,12 @@ import ( cachestore "github.com/eko/gocache/lib/v4/store" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" 
nbcache "github.com/netbirdio/netbird/management/server/cache" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -316,6 +319,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { assert.Contains(t, err.Error(), "invalid state format") } +func scopedCtx(accountID string) context.Context { + token := &types.ProxyAccessToken{ + ID: "token-1", + AccountID: &accountID, + } + return context.WithValue(context.Background(), ProxyTokenContextKey, token) +} + +func globalCtx() context.Context { + token := &types.ProxyAccessToken{ + ID: "token-global", + } + return context.WithValue(context.Background(), ProxyTokenContextKey, token) +} + +func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) { + err := enforceAccountScope(scopedCtx("acc-1"), "acc-1") + assert.NoError(t, err) +} + +func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) { + err := enforceAccountScope(scopedCtx("acc-1"), "acc-2") + require.Error(t, err) + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.PermissionDenied, st.Code()) +} + +func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) { + err := enforceAccountScope(scopedCtx("acc-1"), "") + require.Error(t, err) + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.PermissionDenied, st.Code()) +} + +func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) { + err := enforceAccountScope(globalCtx(), "acc-1") + assert.NoError(t, err) + + err = enforceAccountScope(globalCtx(), "acc-2") + assert.NoError(t, err) + + err = enforceAccountScope(globalCtx(), "") + assert.NoError(t, err) +} + +func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) { + err := enforceAccountScope(context.Background(), "acc-1") + assert.NoError(t, err) +} + func TestValidateState_RejectsInvalidHMAC(t *testing.T) { ctx := context.Background() pkceStore := NewPKCEVerifierStore(ctx, 
testCacheStore(t)) diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index d1d7fc8b7..6cd95f988 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) - proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) + proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil) proxyService.SetServiceManager(serviceManager) createTestProxies(t, ctx, testStore) @@ -318,13 +318,17 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {} +func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { return nil, nil } type testValidateSessionProxyManager struct{} -func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error { +func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error { return nil } @@ -340,6 +344,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co return nil, nil } +func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) { + return nil, nil +} + func (m *testValidateSessionProxyManager) GetActiveClusters(_ 
context.Context) ([]proxy.Cluster, error) { return nil, nil } @@ -348,6 +356,22 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time return nil } +func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) { + return 0, nil +} + +func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) { + return true, nil +} + +func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error { + return nil +} + func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { return nil } diff --git a/management/server/account_test.go b/management/server/account_test.go index 6bb875f99..65b27df49 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3113,7 +3113,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU return nil, nil, err } - proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager) + proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil) proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{}) if err != nil { return nil, nil, err diff --git a/management/server/http/handler.go b/management/server/http/handler.go index b9ea605d3..1e2c710db 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken" 
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -144,6 +145,9 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks if serviceManager != nil && reverseProxyDomainManager != nil { reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) } + + proxytoken.RegisterEndpoints(accountManager.GetStore(), permissionsManager, router) + // Register OAuth callback handler for proxy authentication if proxyGRPCServer != nil { oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies) diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index c99acab63..30d8aa0e7 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -216,6 +216,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { nil, usersManager, nil, + nil, ) proxyService.SetServiceManager(&testServiceManager{store: testStore}) @@ -389,6 +390,10 @@ func (m *testServiceManager) DeleteService(_ context.Context, _, _, _ string) er return nil } +func (m *testServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error { + return nil +} + func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { return nil } @@ -435,6 +440,10 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri func (m *testServiceManager) StartExposeReaper(_ context.Context) {} +func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *testServiceManager) GetActiveClusters(_ 
context.Context, _, _ string) ([]nbproxy.Cluster, error) { return nil, nil } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 1a8b83c7e..3c4ea98d0 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -109,7 +109,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee if err != nil { t.Fatalf("Failed to create proxy manager: %v", err) } - proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil) domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) if err != nil { @@ -238,7 +238,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin if err != nil { t.Fatalf("Failed to create proxy manager: %v", err) } - proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil) domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) if err != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 065a0d306..4c2f0be52 100644 --- a/management/server/store/sql_store.go +++ 
b/management/server/store/sql_store.go @@ -4513,6 +4513,47 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e return nil } +func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var tokens []*types.ProxyAccessToken + result := tx.Where("account_id = ?", accountID).Find(&tokens) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error) + } + + return tokens, nil +} + +func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) { + token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID) + if err != nil { + return false, err + } + return token.IsValid(), nil +} + +func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var token types.ProxyAccessToken + result := tx.Take(&token, idQueryCondition, tokenID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "proxy access token not found") + } + return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error) + } + + return &token, nil +} + // MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token. func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error { result := s.db.Model(&types.ProxyAccessToken{}). @@ -5487,7 +5528,7 @@ func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID strin Model(&proxy.Proxy{}). Where("id = ? 
AND session_id = ?", proxyID, sessionID). Updates(map[string]any{ - "status": "disconnected", + "status": proxy.StatusDisconnected, "disconnected_at": now, "last_seen": now, }) @@ -5518,7 +5559,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err if result.RowsAffected == 0 { p.LastSeen = now p.ConnectedAt = &now - p.Status = "connected" + p.Status = proxy.StatusConnected if err := s.db.Create(p).Error; err != nil { log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err) } @@ -5527,13 +5568,15 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err return nil } -// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies +// GetActiveProxyClusterAddresses returns the unique cluster addresses of active +// shared proxies (those without an account scope). BYOP cluster addresses are +// excluded; use GetActiveProxyClusterAddressesForAccount to retrieve them. func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { var addresses []string result := s.db. Model(&proxy.Proxy{}). - Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). + Where("account_id IS NULL AND status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)). Distinct("cluster_address"). Pluck("cluster_address", &addresses) @@ -5545,13 +5588,75 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string return addresses, nil } -// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count. -func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { +func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { + var addresses []string + + result := s.db. + Model(&proxy.Proxy{}). + Where("account_id = ? AND status = ? 
AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)). + Distinct("cluster_address"). + Pluck("cluster_address", &addresses) + + if result.Error != nil { + return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account") + } + + return addresses, nil +} + +func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) { + var p proxy.Proxy + result := s.db.Where("account_id = ?", accountID).Take(&p) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "proxy not found for account") + } + return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error) + } + return &p, nil +} + +func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) { + var count int64 + result := s.db.Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count) + if result.Error != nil { + return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error) + } + return count, nil +} + +func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) { + var count int64 + result := s.db. + Model(&proxy.Proxy{}). + Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID). + Count(&count) + if result.Error != nil { + return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error) + } + return count > 0, nil +} + +func (s *SqlStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + result := s.db. + Where("cluster_address = ? AND account_id = ?", clusterAddress, accountID). 
+ Delete(&proxy.Proxy{}) + if result.Error != nil { + return status.Errorf(status.Internal, "delete account cluster: %v", result.Error) + } + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "cluster not found") + } + return nil +} + +func (s *SqlStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) { var clusters []proxy.Cluster result := s.db.Model(&proxy.Proxy{}). - Select("cluster_address as address, COUNT(*) as connected_proxies"). - Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). + Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies, COUNT(account_id) > 0 as self_hosted"). + Where("status = ? AND last_seen > ? AND (account_id IS NULL OR account_id = ?)", + proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold), accountID). Group("cluster_address"). Scan(&clusters) diff --git a/management/server/store/store.go b/management/server/store/store.go index db98bc644..aa601c33f 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -114,6 +114,9 @@ type Store interface { GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) + GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) + GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) + IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error RevokeProxyAccessToken(ctx context.Context, tokenID string) error MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error @@ -288,11 +291,16 @@ type 
Store interface { DisconnectProxy(ctx context.Context, proxyID, sessionID string) error UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) - GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) + GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error + GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) + CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) + IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) + DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) @@ -496,6 +504,9 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key") }, + func(db *gorm.DB) error { + return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique") + }, } } diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 6c2c9bbc3..9780c521e 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -166,20 +166,6 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, 
inactivityDuration) } -// GetClusterSupportsCrowdSec mocks base method. -func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr) - ret0, _ := ret[0].(*bool) - return ret0 -} - -// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec. -func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) -} - // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -238,6 +224,21 @@ func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID) } +// CountProxiesByAccountID mocks base method. +func (m *MockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountProxiesByAccountID", ctx, accountID) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountProxiesByAccountID indicates an expected call of CountProxiesByAccountID. +func (mr *MockStoreMockRecorder) CountProxiesByAccountID(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountProxiesByAccountID", reflect.TypeOf((*MockStore)(nil).CountProxiesByAccountID), ctx, accountID) +} + // CreateAccessLog mocks base method. 
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error { m.ctrl.T.Helper() @@ -576,6 +577,20 @@ func (mr *MockStoreMockRecorder) DeletePostureChecks(ctx, accountID, postureChec return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID) } +// DeleteAccountCluster mocks base method. +func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAccountCluster indicates an expected call of DeleteAccountCluster. +func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID) +} + // DeleteRoute mocks base method. func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error { m.ctrl.T.Helper() @@ -1302,19 +1317,34 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx) } -// GetActiveProxyClusters mocks base method. -func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { +// GetActiveProxyClusterAddressesForAccount mocks base method. 
+func (m *MockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx) + ret := m.ctrl.Call(m, "GetActiveProxyClusterAddressesForAccount", ctx, accountID) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveProxyClusterAddressesForAccount indicates an expected call of GetActiveProxyClusterAddressesForAccount. +func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID) +} + +// GetActiveProxyClusters mocks base method. +func (m *MockStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx, accountID) ret0, _ := ret[0].([]proxy.Cluster) ret1, _ := ret[1].(error) return ret0, ret1 } // GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters. -func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx, accountID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx, accountID) } // GetAllAccounts mocks base method. 
@@ -1390,6 +1420,20 @@ func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr) } +// GetClusterSupportsCrowdSec mocks base method. +func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec. +func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) +} + // GetClusterSupportsCustomPorts mocks base method. func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { m.ctrl.T.Helper() @@ -1959,6 +2003,51 @@ func (mr *MockStoreMockRecorder) GetProxyAccessTokenByHashedToken(ctx, lockStren return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken) } +// GetProxyAccessTokenByID mocks base method. +func (m *MockStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types2.ProxyAccessToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyAccessTokenByID", ctx, lockStrength, tokenID) + ret0, _ := ret[0].(*types2.ProxyAccessToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyAccessTokenByID indicates an expected call of GetProxyAccessTokenByID. 
+func (mr *MockStoreMockRecorder) GetProxyAccessTokenByID(ctx, lockStrength, tokenID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByID), ctx, lockStrength, tokenID) +} + +// GetProxyAccessTokensByAccountID mocks base method. +func (m *MockStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types2.ProxyAccessToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyAccessTokensByAccountID", ctx, lockStrength, accountID) + ret0, _ := ret[0].([]*types2.ProxyAccessToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyAccessTokensByAccountID indicates an expected call of GetProxyAccessTokensByAccountID. +func (mr *MockStoreMockRecorder) GetProxyAccessTokensByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokensByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokensByAccountID), ctx, lockStrength, accountID) +} + +// GetProxyByAccountID mocks base method. +func (m *MockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyByAccountID", ctx, accountID) + ret0, _ := ret[0].(*proxy.Proxy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyByAccountID indicates an expected call of GetProxyByAccountID. +func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID) +} + // GetResourceGroups mocks base method. 
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) { m.ctrl.T.Helper() @@ -2391,6 +2480,21 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID) } +// IsClusterAddressConflicting mocks base method. +func (m *MockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClusterAddressConflicting", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsClusterAddressConflicting indicates an expected call of IsClusterAddressConflicting. +func (mr *MockStoreMockRecorder) IsClusterAddressConflicting(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressConflicting", reflect.TypeOf((*MockStore)(nil).IsClusterAddressConflicting), ctx, clusterAddress, accountID) +} + // IsPrimaryAccount mocks base method. func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { m.ctrl.T.Helper() @@ -2407,6 +2511,21 @@ func (mr *MockStoreMockRecorder) IsPrimaryAccount(ctx, accountID interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPrimaryAccount", reflect.TypeOf((*MockStore)(nil).IsPrimaryAccount), ctx, accountID) } +// IsProxyAccessTokenValid mocks base method. 
+func (m *MockStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsProxyAccessTokenValid", ctx, tokenID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsProxyAccessTokenValid indicates an expected call of IsProxyAccessTokenValid. +func (mr *MockStoreMockRecorder) IsProxyAccessTokenValid(ctx, tokenID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProxyAccessTokenValid", reflect.TypeOf((*MockStore)(nil).IsProxyAccessTokenValid), ctx, tokenID) +} + // ListCustomDomains mocks base method. func (m *MockStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) { m.ctrl.T.Helper() diff --git a/proxy/management_byop_integration_test.go b/proxy/management_byop_integration_test.go new file mode 100644 index 000000000..c0fbe682a --- /dev/null +++ b/proxy/management_byop_integration_test.go @@ -0,0 +1,409 @@ +package proxy + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + grpcstatus "google.golang.org/grpc/status" + + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" + 
"github.com/netbirdio/netbird/shared/management/proto" +) + +type byopTestSetup struct { + store store.Store + proxyService *nbgrpc.ProxyServiceServer + grpcServer *grpc.Server + grpcAddr string + cleanup func() + + accountA string + accountB string + accountAToken types.PlainProxyToken + accountBToken types.PlainProxyToken + accountACluster string + accountBCluster string +} + +func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup { + t.Helper() + ctx := context.Background() + + testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + + accountAID := "byop-account-a" + accountBID := "byop-account-b" + + for _, acc := range []*types.Account{ + {Id: accountAID, Domain: "a.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()}, + {Id: accountBID, Domain: "b.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()}, + } { + require.NoError(t, testStore.SaveAccount(ctx, acc)) + } + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + pubKey := base64.StdEncoding.EncodeToString(pub) + privKey := base64.StdEncoding.EncodeToString(priv) + + clusterA := "byop-a.proxy.test" + clusterB := "byop-b.proxy.test" + + services := []*service.Service{ + { + ID: "svc-a1", AccountID: accountAID, Name: "App A1", + Domain: "app1." + clusterA, ProxyCluster: clusterA, Enabled: true, + SessionPrivateKey: privKey, SessionPublicKey: pubKey, + Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, Protocol: "http", TargetId: "peer-a1", TargetType: "peer", Enabled: true}}, + }, + { + ID: "svc-a2", AccountID: accountAID, Name: "App A2", + Domain: "app2." 
+ clusterA, ProxyCluster: clusterA, Enabled: true, + SessionPrivateKey: privKey, SessionPublicKey: pubKey, + Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, Protocol: "http", TargetId: "peer-a2", TargetType: "peer", Enabled: true}}, + }, + { + ID: "svc-b1", AccountID: accountBID, Name: "App B1", + Domain: "app1." + clusterB, ProxyCluster: clusterB, Enabled: true, + SessionPrivateKey: privKey, SessionPublicKey: pubKey, + Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.3", Port: 8080, Protocol: "http", TargetId: "peer-b1", TargetType: "peer", Enabled: true}}, + }, + } + for _, svc := range services { + require.NoError(t, testStore.CreateService(ctx, svc)) + } + + tokenA, err := types.CreateNewProxyAccessToken("byop-token-a", 0, &accountAID, "admin-a") + require.NoError(t, err) + require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenA.ProxyAccessToken)) + + tokenB, err := types.CreateNewProxyAccessToken("byop-token-b", 0, &accountBID, "admin-b") + require.NoError(t, err) + require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenB.ProxyAccessToken)) + + cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) + + meter := noop.NewMeterProvider().Meter("test") + realProxyManager, err := proxymanager.NewManager(testStore, meter) + require.NoError(t, err) + + oidcConfig := nbgrpc.ProxyOIDCConfig{ + Issuer: "https://fake-issuer.example.com", + ClientID: "test-client", + HMACKey: []byte("test-hmac-key"), + } + + usersManager := users.NewManager(testStore) + + proxyService := nbgrpc.NewProxyServiceServer( + &testAccessLogManager{}, + tokenStore, + pkceStore, + oidcConfig, + nil, + usersManager, + realProxyManager, + nil, + ) + + svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore} + proxyService.SetServiceManager(svcMgr) + + 
proxyController := &testProxyController{} + proxyService.SetProxyController(proxyController) + + _, streamInterceptor, authClose := nbgrpc.NewProxyAuthInterceptors(testStore) + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + grpcServer := grpc.NewServer(grpc.StreamInterceptor(streamInterceptor)) + proto.RegisterProxyServiceServer(grpcServer, proxyService) + + go func() { + if err := grpcServer.Serve(lis); err != nil { + t.Logf("gRPC server error: %v", err) + } + }() + + return &byopTestSetup{ + store: testStore, + proxyService: proxyService, + grpcServer: grpcServer, + grpcAddr: lis.Addr().String(), + cleanup: func() { + grpcServer.GracefulStop() + authClose() + storeCleanup() + }, + accountA: accountAID, + accountB: accountBID, + accountAToken: tokenA.PlainToken, + accountBToken: tokenB.PlainToken, + accountACluster: clusterA, + accountBCluster: clusterB, + } +} + +func byopContext(ctx context.Context, token types.PlainProxyToken) context.Context { + md := metadata.Pairs("authorization", "Bearer "+string(token)) + return metadata.NewOutgoingContext(ctx, md) +} + +func receiveBYOPMappings(t *testing.T, stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping { + t.Helper() + var mappings []*proto.ProxyMapping + for { + msg, err := stream.Recv() + require.NoError(t, err) + mappings = append(mappings, msg.GetMapping()...) 
+ if msg.GetInitialSyncComplete() { + break + } + } + return mappings +} + +func TestIntegration_BYOPProxy_ReceivesOnlyAccountServices(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel() + + stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-a", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + mappings := receiveBYOPMappings(t, stream) + + assert.Len(t, mappings, 2, "BYOP proxy should receive only account A's 2 services") + for _, m := range mappings { + assert.Equal(t, setup.accountA, m.GetAccountId(), "all mappings should belong to account A") + t.Logf("received mapping: id=%s domain=%s account=%s", m.GetId(), m.GetDomain(), m.GetAccountId()) + } + + ids := map[string]bool{} + for _, m := range mappings { + ids[m.GetId()] = true + } + assert.True(t, ids["svc-a1"], "should contain svc-a1") + assert.True(t, ids["svc-a2"], "should contain svc-a2") + assert.False(t, ids["svc-b1"], "should NOT contain account B's svc-b1") +} + +func TestIntegration_BYOPProxy_AccountBReceivesOnlyItsServices(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second) + defer cancel() + + stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-b", + Version: "test-v1", + Address: 
setup.accountBCluster, + }) + require.NoError(t, err) + + mappings := receiveBYOPMappings(t, stream) + + assert.Len(t, mappings, 1, "BYOP proxy B should receive only 1 service") + assert.Equal(t, "svc-b1", mappings[0].GetId()) + assert.Equal(t, setup.accountB, mappings[0].GetAccountId()) +} + +func TestIntegration_BYOPProxy_MultiplePerAccount(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel1() + + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-a-first", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + mappings1 := receiveBYOPMappings(t, stream1) + assert.Len(t, mappings1, 2, "first BYOP proxy should receive account A's 2 services") + + ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-a-second", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + mappings2 := receiveBYOPMappings(t, stream2) + assert.Len(t, mappings2, 2, "second BYOP proxy from same account should also receive the 2 services") + for _, m := range mappings2 { + assert.Equal(t, setup.accountA, m.GetAccountId()) + } +} + +func TestIntegration_BYOPProxy_ClusterAddressConflict(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := 
proto.NewProxyServiceClient(conn) + + ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel1() + + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-a-cluster", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + _ = receiveBYOPMappings(t, stream1) + + ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: "byop-proxy-b-conflict", + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + _, err = stream2.Recv() + require.Error(t, err) + + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.AlreadyExists, st.Code(), "cluster address conflict should return AlreadyExists") + t.Logf("expected rejection: %s", st.Message()) +} + +func TestIntegration_BYOPProxy_SameProxyReconnects(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + proxyID := "byop-proxy-reconnect" + + ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + firstMappings := receiveBYOPMappings(t, stream1) + cancel1() + + time.Sleep(200 * time.Millisecond) + + ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, 
&proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: setup.accountACluster, + }) + require.NoError(t, err) + + secondMappings := receiveBYOPMappings(t, stream2) + + assert.Equal(t, len(firstMappings), len(secondMappings), "reconnect should receive same mappings") + + firstIDs := map[string]bool{} + for _, m := range firstMappings { + firstIDs[m.GetId()] = true + } + for _, m := range secondMappings { + assert.True(t, firstIDs[m.GetId()], "mapping %s should be present on reconnect", m.GetId()) + } +} + +func TestIntegration_BYOPProxy_UnauthenticatedRejected(t *testing.T) { + setup := setupBYOPIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "no-auth-proxy", + Version: "test-v1", + Address: "some.cluster.io", + }) + require.NoError(t, err) + + _, err = stream.Recv() + require.Error(t, err) + + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Unauthenticated, st.Code()) +} diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 99bbdad0c..9fd3d2ce9 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "encoding/base64" "errors" + "fmt" "net" "sync" "sync/atomic" @@ -140,6 +141,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { nil, usersManager, proxyManager, + nil, ) // Use store-backed service manager @@ -201,8 +203,8 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, // testProxyManager is a mock implementation of proxy.Manager for testing. 
type testProxyManager struct{} -func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) { - return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil +func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) { + return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: nbproxy.StatusConnected}, nil } func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error { @@ -217,6 +219,10 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin return nil, nil } +func (m *testProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) { + return nil, nil +} + func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) { return nil, nil } @@ -237,6 +243,22 @@ func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) erro return nil } +func (m *testProxyManager) GetAccountProxy(_ context.Context, accountID string) (*nbproxy.Proxy, error) { + return nil, fmt.Errorf("proxy not found for account %s", accountID) +} + +func (m *testProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) { + return 0, nil +} + +func (m *testProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) { + return true, nil +} + +func (m *testProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error { + return nil +} + // testProxyController is a mock implementation of rpservice.ProxyController for testing. 
type testProxyController struct{} @@ -290,6 +312,10 @@ func (m *storeBackedServiceManager) DeleteService(ctx context.Context, accountID return nil } +func (m *storeBackedServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error { + return nil +} + func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { return nil } @@ -336,6 +362,10 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _, func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {} +func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) { + return m.store.GetServiceByDomain(ctx, domain) +} + func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { return nil, nil } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 82fca0782..942f3aa45 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -3355,10 +3355,64 @@ components: example: false required: - enabled + ProxyTokenRequest: + type: object + properties: + name: + type: string + description: Human-readable token name + example: "my-proxy-token" + expires_in: + type: integer + minimum: 0 + description: Token expiration in seconds (0 = never expires) + example: 0 + required: + - name + ProxyToken: + type: object + properties: + id: + type: string + name: + type: string + expires_at: + type: string + format: date-time + created_at: + type: string + format: date-time + last_used: + type: string + format: date-time + revoked: + type: boolean + required: + - id + - name + - created_at + - revoked + ProxyTokenCreated: + type: object + description: Returned on creation — plain_token is shown only once + allOf: + - $ref: '#/components/schemas/ProxyToken' + - type: object + properties: + plain_token: + type: string + 
description: The plain text token (shown only once) + example: "nbx_abc123..." + required: + - plain_token ProxyCluster: type: object description: A proxy cluster represents a group of proxy nodes serving the same address properties: + id: + type: string + description: Unique identifier of a proxy in this cluster + example: "chlfq4q5r8kc73b0qjpg" address: type: string description: Cluster address used for CNAME targets @@ -3367,9 +3421,15 @@ components: type: integer description: Number of proxy nodes connected in this cluster example: 3 + self_hosted: + type: boolean + description: Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner + example: false required: + - id - address - connected_proxies + - self_hosted ReverseProxyDomainType: type: string description: Type of Reverse Proxy Domain @@ -11375,6 +11435,111 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/clusters/{clusterAddress}: + delete: + summary: Delete a self-hosted proxy cluster + description: Removes all self-hosted (BYOP) proxy registrations for the given cluster address owned by the account. 
+ tags: [ Services ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: clusterAddress + required: true + schema: + type: string + description: The address of the proxy cluster + responses: + '200': + description: Proxy cluster deleted successfully + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/proxy-tokens: + get: + summary: List Proxy Tokens + description: Returns all proxy access tokens for the account + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of proxy tokens + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ProxyToken' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Proxy Token + description: Generate an account-scoped proxy access token for self-hosted proxy registration + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ProxyTokenRequest' + responses: + '200': + description: Proxy token created (plain token shown once) + content: + application/json: + schema: + $ref: '#/components/schemas/ProxyTokenCreated' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/proxy-tokens/{tokenId}: + delete: + summary: Revoke a Proxy Token + description: 
Revoke an account-scoped proxy access token + tags: [ Self-Hosted Proxies ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: tokenId + required: true + schema: + type: string + description: The unique identifier of the proxy token + responses: + '200': + description: Token revoked + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" /api/reverse-proxies/services: get: summary: List all Services diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 4b94ea01c..b3bb475a9 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -3785,11 +3785,49 @@ type ProxyAccessLogsResponse struct { // ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address type ProxyCluster struct { + // Id Unique identifier of a proxy in this cluster + Id string `json:"id"` + // Address Cluster address used for CNAME targets Address string `json:"address"` // ConnectedProxies Number of proxy nodes connected in this cluster ConnectedProxies int `json:"connected_proxies"` + + // SelfHosted Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner + SelfHosted bool `json:"self_hosted"` +} + +// ProxyToken defines model for ProxyToken. +type ProxyToken struct { + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Id string `json:"id"` + LastUsed *time.Time `json:"last_used,omitempty"` + Name string `json:"name"` + Revoked bool `json:"revoked"` +} + +// ProxyTokenCreated defines model for ProxyTokenCreated. 
+type ProxyTokenCreated struct { + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Id string `json:"id"` + LastUsed *time.Time `json:"last_used,omitempty"` + Name string `json:"name"` + + // PlainToken The plain text token (shown only once) + PlainToken string `json:"plain_token"` + Revoked bool `json:"revoked"` +} + +// ProxyTokenRequest defines model for ProxyTokenRequest. +type ProxyTokenRequest struct { + // ExpiresIn Token expiration in seconds (0 = never expires) + ExpiresIn *int `json:"expires_in,omitempty"` + + // Name Human-readable token name + Name string `json:"name"` } // Resource defines model for Resource. @@ -5160,6 +5198,9 @@ type PutApiPostureChecksPostureCheckIdJSONRequestBody = PostureCheckUpdate // PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType. type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest +// PostApiReverseProxiesProxyTokensJSONRequestBody defines body for PostApiReverseProxiesProxyTokens for application/json ContentType. +type PostApiReverseProxiesProxyTokensJSONRequestBody = ProxyTokenRequest + // PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType. 
type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest From 946ce4c3da24126ae34cc3b482277923863cbddb Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 12 May 2026 00:48:21 +0900 Subject: [PATCH 17/27] [client] Fix --config flag default to point at profile path (#6122) --- client/cmd/root.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/cmd/root.go b/client/cmd/root.go index 29d4328a1..0a0aa4197 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -143,7 +143,7 @@ func init() { rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets WireGuard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output") - rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location") + rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", profilemanager.DefaultConfigPath, "Overrides the default profile file location") rootCmd.AddCommand(upCmd) rootCmd.AddCommand(downCmd) From 96672dd1f8b116d2c40339bb94c61ed812c508e0 Mon Sep 17 00:00:00 2001 From: Nicolas Frati Date: Tue, 12 May 2026 13:50:35 +0200 Subject: [PATCH 18/27] [management] chores: update dex version (#6124) * chores: update dex version * chore: update dex fork --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 5704887ce..7c1a95e79 100644 --- a/go.mod +++ b/go.mod @@ -341,8 +341,8 @@ replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801 replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 -replace 
github.com/dexidp/dex => github.com/netbirdio/dex v0.244.1-0.20260415145816-a0c6b40ff9f2 +replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1 -replace github.com/dexidp/dex/api/v2 => github.com/netbirdio/dex/api/v2 v2.0.0-20260415145816-a0c6b40ff9f2 +replace github.com/dexidp/dex/api/v2 => github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1 replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0 diff --git a/go.sum b/go.sum index 42652169c..53789f49d 100644 --- a/go.sum +++ b/go.sum @@ -485,10 +485,10 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/netbirdio/dex v0.244.1-0.20260415145816-a0c6b40ff9f2 h1:AP7OM/JnTogod3rVcLsMuilSG94kWQCr3z6R4rfVXnc= -github.com/netbirdio/dex v0.244.1-0.20260415145816-a0c6b40ff9f2/go.mod h1:+trSlzHNmdJGvz0oLEyyiuaPstUeD7YO6B3Fx9nyziY= -github.com/netbirdio/dex/api/v2 v2.0.0-20260415145816-a0c6b40ff9f2 h1:HEEGJPsVw7/p7SEL3HWP4vaInxHo8OJSEaOkHpUAk+M= -github.com/netbirdio/dex/api/v2 v2.0.0-20260415145816-a0c6b40ff9f2/go.mod h1:awuTyT29CYALpEyET0S307EgNlPWrc7fFKRAyhsO45M= +github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1 h1:4TaYr9O4xX0D2kszeOLclTiCbA3eHq3xWV+9ILJbIYs= +github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1/go.mod h1:IHH+H8vK2GfqtIt5u/5OdPh18yk0oDHuj2vz5+Goetg= +github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1 h1:neE7z+FPUkldl3faK/Jt+hJK2L+1XfQ1W33TQhU9m88= +github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1/go.mod h1:awuTyT29CYALpEyET0S307EgNlPWrc7fFKRAyhsO45M= github.com/netbirdio/easyjson v0.9.0 
h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus= github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= From 1224d6e1eeb04d6d43ca05f3bdbdc697b0b7a182 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 12 May 2026 21:52:56 +0900 Subject: [PATCH 19/27] [client] Persist management URL and pre-shared key overrides on login (#6065) --- client/server/login_overrides_test.go | 93 +++++++++++++++++++++++++++ client/server/server.go | 33 +++++++++- 2 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 client/server/login_overrides_test.go diff --git a/client/server/login_overrides_test.go b/client/server/login_overrides_test.go new file mode 100644 index 000000000..c45557c59 --- /dev/null +++ b/client/server/login_overrides_test.go @@ -0,0 +1,93 @@ +package server + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/profilemanager" +) + +func TestPersistLoginOverrides(t *testing.T) { + strPtr := func(s string) *string { return &s } + + tests := []struct { + name string + initialMgmtURL string + initialPSK string + newMgmtURL string + newPSK *string + wantMgmtURL string + wantPSK string + }{ + { + name: "persist new management URL", + initialMgmtURL: "https://old.example.com:33073", + newMgmtURL: "https://new.example.com:33073", + wantMgmtURL: "https://new.example.com:33073", + }, + { + name: "persist new pre-shared key", + initialMgmtURL: "https://existing.example.com:33073", + initialPSK: "old-key", + newPSK: strPtr("new-key"), + wantMgmtURL: "https://existing.example.com:33073", + wantPSK: "new-key", + }, + { + name: "persist both", + initialMgmtURL: "https://old.example.com:33073", + initialPSK: "old-key", + newMgmtURL: "https://new.example.com:33073", + newPSK: strPtr("new-key"), + 
wantMgmtURL: "https://new.example.com:33073", + wantPSK: "new-key", + }, + { + name: "no inputs preserves existing", + initialMgmtURL: "https://existing.example.com:33073", + initialPSK: "existing-key", + wantMgmtURL: "https://existing.example.com:33073", + wantPSK: "existing-key", + }, + { + name: "empty PSK pointer is ignored", + initialMgmtURL: "https://existing.example.com:33073", + initialPSK: "existing-key", + newPSK: strPtr(""), + wantMgmtURL: "https://existing.example.com:33073", + wantPSK: "existing-key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + origDefault := profilemanager.DefaultConfigPath + t.Cleanup(func() { profilemanager.DefaultConfigPath = origDefault }) + + dir := t.TempDir() + profilemanager.DefaultConfigPath = filepath.Join(dir, "default.json") + + seed := profilemanager.ConfigInput{ + ConfigPath: profilemanager.DefaultConfigPath, + ManagementURL: tt.initialMgmtURL, + } + if tt.initialPSK != "" { + seed.PreSharedKey = strPtr(tt.initialPSK) + } + _, err := profilemanager.UpdateOrCreateConfig(seed) + require.NoError(t, err, "seed config") + + activeProf := &profilemanager.ActiveProfileState{Name: "default"} + err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK) + require.NoError(t, err, "persistLoginOverrides") + + cfg, err := profilemanager.ReadConfig(profilemanager.DefaultConfigPath) + require.NoError(t, err, "read back config") + + require.Equal(t, tt.wantMgmtURL, cfg.ManagementURL.String(), "management URL") + require.Equal(t, tt.wantPSK, cfg.PreSharedKey, "pre-shared key") + }) + } +} diff --git a/client/server/server.go b/client/server/server.go index bc8de8f9f..397fb37e4 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -490,6 +490,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.mutex.Unlock() + if err := persistLoginOverrides(activeProf, msg.ManagementUrl, msg.OptionalPreSharedKey); err != nil { + log.Errorf("failed to persist 
login overrides: %v", err) + return nil, fmt.Errorf("persist login overrides: %w", err) + } + config, _, err := s.getConfig(activeProf) if err != nil { log.Errorf("failed to get active profile config: %v", err) @@ -964,7 +969,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe return &proto.LogoutResponse{}, nil } -// GetConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist +// getConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, bool, error) { cfgPath, err := activeProf.FilePath() if err != nil { @@ -1766,3 +1771,29 @@ func sendTerminalNotification() error { return wallCmd.Wait() } + +// persistLoginOverrides writes management URL and pre-shared key from a LoginRequest to the +// active profile config so that subsequent reads pick them up. Empty/nil values are ignored. 
+func persistLoginOverrides(activeProf *profilemanager.ActiveProfileState, managementURL string, preSharedKey *string) error { + if preSharedKey != nil && *preSharedKey == "" { + preSharedKey = nil + } + if managementURL == "" && preSharedKey == nil { + return nil + } + + cfgPath, err := activeProf.FilePath() + if err != nil { + return fmt.Errorf("active profile file path: %w", err) + } + + input := profilemanager.ConfigInput{ + ConfigPath: cfgPath, + ManagementURL: managementURL, + PreSharedKey: preSharedKey, + } + if _, err := profilemanager.UpdateOrCreateConfig(input); err != nil { + return fmt.Errorf("update config: %w", err) + } + return nil +} From 9126a192ca3f6845b798c948e1b1bc05bb7db965 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 12 May 2026 22:05:53 +0900 Subject: [PATCH 20/27] [client] Set 0644 perms on SSH client config after os.CreateTemp (#6126) --- client/ssh/config/manager.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index b58bf2233..20695cb4d 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -252,6 +252,10 @@ func (m *Manager) writeSSHConfig(sshConfig string) error { return fmt.Errorf("write SSH config file %s: %w", tmpPath, err) } + if err := os.Chmod(tmpPath, 0644); err != nil { + return fmt.Errorf("chmod SSH config file %s: %w", tmpPath, err) + } + if err := os.Rename(tmpPath, sshConfigPath); err != nil { return fmt.Errorf("rename SSH config %s -> %s: %w", tmpPath, sshConfigPath, err) } From ab2a8794e7a41693fff303725c96399ca190e8ff Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 14 May 2026 12:30:42 +0200 Subject: [PATCH 21/27] [client] Add short flags for status command options (#6137) * [client] Add short flags for status command options * uppercase filters --- client/cmd/status.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/client/cmd/status.go 
b/client/cmd/status.go index dae30e854..103b3044a 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -43,16 +43,16 @@ func init() { ipsFilterMap = make(map[string]struct{}) prefixNamesFilterMap = make(map[string]struct{}) statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format") - statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format") - statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format") - statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33") - statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer") + statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format") + statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format") + statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33") + statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer") statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6") - statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1") - statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") - statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection 
status(idle|connecting|connected), e.g., --filter-by-status connected") - statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") - statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)") + statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1") + statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") + statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") + statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") + statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)") } func statusFunc(cmd *cobra.Command, args []string) error { From 77b479286e399660ef2bdcbe7983363946660574 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Thu, 14 May 2026 13:27:50 +0200 Subject: [PATCH 22/27] [management] fix offline statuses for public proxy clusters (#6133) --- .../reverseproxy/domain/manager/manager.go | 23 ++++++- .../domain/manager/manager_test.go | 60 ++++++++++++++++--- .../reverseproxy/service/manager/manager.go | 59 +++++++++--------- management/server/store/sql_store.go | 28 +++++++-- 
management/server/store/sql_store_test.go | 52 ++++++++++++++++ 5 files changed, 177 insertions(+), 45 deletions(-) diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index ab899e0bf..2790b5f20 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -304,10 +304,27 @@ func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]s if err != nil { return nil, fmt.Errorf("get BYOP cluster addresses: %w", err) } - if len(byopAddresses) > 0 { - return byopAddresses, nil + publicAddresses, err := m.proxyManager.GetActiveClusterAddresses(ctx) + if err != nil { + return nil, fmt.Errorf("get public cluster addresses: %w", err) } - return m.proxyManager.GetActiveClusterAddresses(ctx) + seen := make(map[string]struct{}, len(byopAddresses)+len(publicAddresses)) + merged := make([]string, 0, len(byopAddresses)+len(publicAddresses)) + for _, addr := range byopAddresses { + if _, ok := seen[addr]; ok { + continue + } + seen[addr] = struct{}{} + merged = append(merged, addr) + } + for _, addr := range publicAddresses { + if _, ok := seen[addr]; ok { + continue + } + seen[addr] = struct{}{} + merged = append(merged, addr) + } + return merged, nil } func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) { diff --git a/management/internals/modules/reverseproxy/domain/manager/manager_test.go b/management/internals/modules/reverseproxy/domain/manager/manager_test.go index fdeb0765f..5e7bbfc36 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager_test.go @@ -40,22 +40,37 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) return nil } -func TestGetClusterAllowList_BYOPProxy(t *testing.T) { 
+func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) { pm := &mockProxyManager{ getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) { assert.Equal(t, "acc-123", accID) return []string{"byop.example.com"}, nil }, getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { - t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist") - return nil, nil + return []string{"eu.proxy.netbird.io"}, nil }, } mgr := Manager{proxyManager: pm} result, err := mgr.getClusterAllowList(context.Background(), "acc-123") require.NoError(t, err) - assert.Equal(t, []string{"byop.example.com"}, result) + assert.Equal(t, []string{"byop.example.com", "eu.proxy.netbird.io"}, result) +} + +func TestGetClusterAllowList_DeduplicatesBYOPAndPublic(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{"shared.example.com", "byop.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return []string{"shared.example.com", "eu.proxy.netbird.io"}, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"shared.example.com", "byop.example.com", "eu.proxy.netbird.io"}, result) } func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) { @@ -79,10 +94,6 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) { getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { return nil, errors.New("db error") }, - getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { - t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails") - return nil, nil - }, } mgr := Manager{proxyManager: pm} @@ -92,6 +103,23 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t 
*testing.T) { assert.Contains(t, err.Error(), "BYOP cluster addresses") } +func TestGetClusterAllowList_PublicError_ReturnsError(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{"byop.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return nil, errors.New("db error") + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "public cluster addresses") +} + func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) { pm := &mockProxyManager{ getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { @@ -108,3 +136,19 @@ func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) { assert.Equal(t, []string{"eu.proxy.netbird.io"}, result) } +func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{"byop.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return nil, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"byop.example.com"}, result) +} + diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index c866d8f75..4a8598afb 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -306,6 +306,10 @@ func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, clus 
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error { customPorts := m.clusterCustomPorts(ctx, svc) + if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil { + return err + } + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if svc.Domain != "" { if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { @@ -321,10 +325,6 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc * return err } - if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil { - return err - } - if err := transaction.CreateService(ctx, svc); err != nil { return fmt.Errorf("create service: %w", err) } @@ -435,6 +435,10 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error { customPorts := m.clusterCustomPorts(ctx, svc) + if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil { + return err + } + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil { return err @@ -448,10 +452,6 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee return err } - if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil { - return err - } - if err := transaction.CreateService(ctx, svc); err != nil { return fmt.Errorf("create service: %w", err) } @@ -552,10 +552,22 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se svcForCaps.ProxyCluster = effectiveCluster customPorts := m.clusterCustomPorts(ctx, &svcForCaps) + if err := validateTargetReferences(ctx, m.store, accountID, service.Targets); err != nil { + return nil, err + } + + // 
Validate subdomain requirement *before* the transaction: the underlying + // capability lookup talks to the main DB pool, and SQLite's single-connection + // pool would self-deadlock if this ran while the tx already held the only + // connection. + if err := m.validateSubdomainRequirement(ctx, service.Domain, effectiveCluster); err != nil { + return nil, err + } + var updateInfo serviceUpdateInfo err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts) + return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts, effectiveCluster) }) return &updateInfo, err @@ -585,7 +597,7 @@ func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string, return existing.ProxyCluster, nil } -func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error { +func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool, effectiveCluster string) error { existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID) if err != nil { return err @@ -603,17 +615,13 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St updateInfo.domainChanged = existingService.Domain != service.Domain if updateInfo.domainChanged { - if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil { + if err := m.handleDomainChange(ctx, transaction, service, effectiveCluster); err != nil { return err } } else { service.ProxyCluster = existingService.ProxyCluster } - if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil { - return err - } - m.preserveExistingAuthSecrets(service, 
existingService) if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil { return err @@ -628,9 +636,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St if err := m.checkPortConflict(ctx, transaction, service); err != nil { return err } - if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { - return err - } if err := transaction.UpdateService(ctx, service); err != nil { return fmt.Errorf("update service: %w", err) } @@ -638,20 +643,18 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St return nil } -func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error { +// handleDomainChange validates the new domain is free inside the transaction +// and applies the pre-resolved cluster (computed outside the tx by +// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks +// to the main DB pool and would self-deadlock under SQLite (max_open_conns=1) +// because the transaction already holds the only connection. 
+func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, svc *service.Service, effectiveCluster string) error { if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil { return err } - - if m.clusterDeriver != nil { - newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain) - if err != nil { - log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain) - } else { - svc.ProxyCluster = newCluster - } + if effectiveCluster != "" { + svc.ProxyCluster = effectiveCluster } - return nil } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 4c2f0be52..893ee2168 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/netip" + "net/url" "os" "path/filepath" "runtime" @@ -2794,12 +2795,27 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe connStr = filepath.Join(dataDir, filePath) } - // Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows - if hasQuery { - connStr += "?" + query - } else if runtime.GOOS != "windows" { + // Compose query parameters. User-provided ?_busy_timeout (or its mattn alias + // ?_timeout) overrides our default; otherwise inject 30s so SQLite waits at + // most that long on a lock instead of blocking the only Go-side connection. + // mattn/go-sqlite3 applies PRAGMA from the DSN on every fresh connection, so + // the value survives ConnMaxIdleTime/ConnMaxLifetime recycling. cache=shared + // stays the default on non-Windows for the same reason as before. 
+ parsed, _ := url.ParseQuery(query) + var defaults []string + if parsed.Get("_busy_timeout") == "" && parsed.Get("_timeout") == "" { + defaults = append(defaults, "_busy_timeout=30000") + } + if !hasQuery && runtime.GOOS != "windows" { // To avoid `The process cannot access the file because it is being used by another process` on Windows - connStr += "?cache=shared" + defaults = append(defaults, "cache=shared") + } + parts := defaults + if hasQuery { + parts = append(parts, query) + } + if len(parts) > 0 { + connStr += "?" + strings.Join(parts, "&") } db, err := gorm.Open(sqlite.Open(connStr), getGormConfig()) @@ -3402,7 +3418,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) } func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { - timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout) + timeoutCtx, cancel := context.WithTimeout(ctx, s.transactionTimeout) defer cancel() startTime := time.Now() diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 2819265c3..7515add62 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -4592,3 +4592,55 @@ func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) { require.NoError(t, err) assert.Equal(t, 0, len(remainingRecords)) } + +// TestNewSqliteStore_BusyTimeoutApplied opens a fresh SQLite store and verifies +// that the _busy_timeout DSN parameter took effect at the driver level. Without +// this, lock contention on the single SQLite connection waits indefinitely on +// the Go side and can be hidden behind the 5-minute transactionTimeout. 
+func TestNewSqliteStore_BusyTimeoutApplied(t *testing.T) { + dir := t.TempDir() + store, err := NewSqliteStore(context.Background(), dir, nil, true) + require.NoError(t, err) + t.Cleanup(func() { + _ = store.Close(context.Background()) + }) + + sqlDB, err := store.db.DB() + require.NoError(t, err) + row := sqlDB.QueryRow("PRAGMA busy_timeout") + var busyTimeout int + require.NoError(t, row.Scan(&busyTimeout)) + assert.Equal(t, 30000, busyTimeout, "SQLite busy_timeout must be set via DSN so it survives connection recycling") +} + +// TestNewSqliteStore_BusyTimeoutRespectsUserOverride confirms that an operator +// passing _busy_timeout or its mattn alias _timeout via NB_STORE_ENGINE_SQLITE_FILE +// wins over our 30s default. This guards the DSN merge logic in NewSqliteStore. +func TestNewSqliteStore_BusyTimeoutRespectsUserOverride(t *testing.T) { + cases := []struct { + name string + envFile string + expected int + }{ + {name: "explicit _busy_timeout wins", envFile: "store.db?_busy_timeout=5000", expected: 5000}, + {name: "alias _timeout wins", envFile: "store.db?_timeout=7000", expected: 7000}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Setenv("NB_STORE_ENGINE_SQLITE_FILE", tc.envFile) + dir := t.TempDir() + store, err := NewSqliteStore(context.Background(), dir, nil, true) + require.NoError(t, err) + t.Cleanup(func() { + _ = store.Close(context.Background()) + }) + + sqlDB, err := store.db.DB() + require.NoError(t, err) + row := sqlDB.QueryRow("PRAGMA busy_timeout") + var busyTimeout int + require.NoError(t, row.Scan(&busyTimeout)) + assert.Equal(t, tc.expected, busyTimeout) + }) + } +} From ea9fab4396fc5513f7c62e3465dd361dc8bb9e91 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 14 May 2026 23:05:33 +0900 Subject: [PATCH 23/27] [management] Allocate and preserve IPv6 overlay addresses for embedded proxy peers (#6132) --- management/server/account.go | 12 +++++++++ 
management/server/peer.go | 17 +++++++----- management/server/types/account.go | 30 ++++++++++----------- management/server/types/group.go | 5 +++- management/server/types/ipv6_groups_test.go | 30 +++++++++++++++++++++ 5 files changed, 70 insertions(+), 24 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 364c0c37b..77a46a069 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2487,6 +2487,18 @@ func (am *DefaultAccountManager) buildIPv6AllowedPeers(ctx context.Context, tran allowedPeers[peerID] = struct{}{} } } + + // Embedded proxy peers sit outside regular group membership but must + // participate in any v6-enabled overlay to reach v6-only peers. + peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + if err != nil { + return nil, fmt.Errorf("get peers: %w", err) + } + for _, p := range peers { + if p.ProxyMeta.Embedded { + allowedPeers[p.ID] = struct{}{} + } + } return allowedPeers, nil } diff --git a/management/server/peer.go b/management/server/peer.go index 8a39fbbb8..c3b130ba2 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -762,16 +762,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe newPeer.IP = freeIP if len(settings.IPv6EnabledGroups) > 0 && network.NetV6.IP != nil { - var allGroupID string - if !peer.ProxyMeta.Embedded { - allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, "All") - if err != nil { - log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err) - } else { + // Embedded proxy peers are not group members but participate in any + // IPv6-enabled overlay so reverse-proxy traffic reaches v6-only peers. 
+ allocate := peer.ProxyMeta.Embedded + if !allocate { + var allGroupID string + if allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, types.GroupAllName); err == nil { allGroupID = allGroup.ID + } else { + log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err) } + allocate = peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID) } - if peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID) { + if allocate { v6Prefix, err := netip.ParsePrefix(network.NetV6.String()) if err != nil { return nil, nil, nil, fmt.Errorf("parse IPv6 prefix: %w", err) diff --git a/management/server/types/account.go b/management/server/types/account.go index 49600163a..870333a60 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -598,28 +598,21 @@ func (a *Account) GetPeerGroups(peerID string) LookupMap { return groupList } -// PeerIPv6Allowed reports whether the given peer is in any of the account's IPv6 enabled groups. +// PeerIPv6Allowed reports whether the given peer participates in the IPv6 overlay. // Returns false if IPv6 is disabled or no groups are configured. func (a *Account) PeerIPv6Allowed(peerID string) bool { - if len(a.Settings.IPv6EnabledGroups) == 0 { - return false - } - - for _, groupID := range a.Settings.IPv6EnabledGroups { - group, ok := a.Groups[groupID] - if !ok { - continue - } - if slices.Contains(group.Peers, peerID) { - return true - } - } - return false + _, ok := a.peerIPv6AllowedSet()[peerID] + return ok } -// peerIPv6AllowedSet returns a set of peer IDs that belong to any IPv6-enabled group. +// peerIPv6AllowedSet returns the set of peer IDs that participate in the IPv6 overlay: +// members of any IPv6-enabled group, plus every embedded proxy peer (which sit outside +// regular group membership but must reach v6-enabled peers). 
func (a *Account) peerIPv6AllowedSet() map[string]struct{} { result := make(map[string]struct{}) + if len(a.Settings.IPv6EnabledGroups) == 0 { + return result + } for _, groupID := range a.Settings.IPv6EnabledGroups { group, ok := a.Groups[groupID] if !ok { @@ -629,6 +622,11 @@ func (a *Account) peerIPv6AllowedSet() map[string]struct{} { result[peerID] = struct{}{} } } + for id, p := range a.Peers { + if p != nil && p.ProxyMeta.Embedded { + result[id] = struct{}{} + } + } return result } diff --git a/management/server/types/group.go b/management/server/types/group.go index 00fdf7a69..b4f50080a 100644 --- a/management/server/types/group.go +++ b/management/server/types/group.go @@ -92,9 +92,12 @@ func (g *Group) HasPeers() bool { return len(g.Peers) > 0 } +// GroupAllName is the reserved name of the default group that contains every peer in an account. +const GroupAllName = "All" + // IsGroupAll checks if the group is a default "All" group. func (g *Group) IsGroupAll() bool { - return g.Name == "All" + return g.Name == GroupAllName } // AddPeer adds peerID to Peers if not present, returning true if added. 
diff --git a/management/server/types/ipv6_groups_test.go b/management/server/types/ipv6_groups_test.go index 5151e1b1f..766a9c92c 100644 --- a/management/server/types/ipv6_groups_test.go +++ b/management/server/types/ipv6_groups_test.go @@ -232,3 +232,33 @@ func TestIPv6RecalculationOnGroupChange(t *testing.T) { assert.True(t, account.PeerIPv6Allowed("peer3"), "peer3 now in infra") }) } + +func TestPeerIPv6AllowedEmbeddedProxy(t *testing.T) { + account := &Account{ + Peers: map[string]*nbpeer.Peer{ + "peer1": {ID: "peer1"}, + "proxy": {ID: "proxy", ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "netbird.test"}}, + }, + Groups: map[string]*Group{ + "group-devs": {ID: "group-devs", Peers: []string{"peer1"}}, + }, + Settings: &Settings{}, + } + + t.Run("embedded proxy allowed when any v6 group exists, without group membership", func(t *testing.T) { + account.Settings.IPv6EnabledGroups = []string{"group-devs"} + assert.True(t, account.PeerIPv6Allowed("proxy"), "embedded proxy participates in v6 overlay") + assert.True(t, account.PeerIPv6Allowed("peer1"), "regular peer in enabled group still allowed") + }) + + t.Run("embedded proxy denied when no v6 group enabled", func(t *testing.T) { + account.Settings.IPv6EnabledGroups = nil + assert.False(t, account.PeerIPv6Allowed("proxy"), "v6 disabled account-wide denies embedded proxies too") + }) + + t.Run("non-embedded peer outside any enabled group is not pulled in", func(t *testing.T) { + account.Settings.IPv6EnabledGroups = []string{"group-devs"} + account.Peers["lonely"] = &nbpeer.Peer{ID: "lonely"} + assert.False(t, account.PeerIPv6Allowed("lonely"), "embedded-proxy bypass must not leak to regular peers") + }) +} From 3f914090cbb345707a88b5edb20d9c1351873b4c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 14 May 2026 23:22:53 +0900 Subject: [PATCH 24/27] [client] Bracket IPv6 in embed listeners, expand debug bundle (#6134) --- client/embed/embed.go | 4 +- 
client/internal/debug/debug.go | 36 +++-- client/internal/debug/debug_linux.go | 195 +++++++++++++++++------- client/internal/debug/debug_nonlinux.go | 5 + 4 files changed, 178 insertions(+), 62 deletions(-) diff --git a/client/embed/embed.go b/client/embed/embed.go index 4b9445b97..8b669e547 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -336,7 +336,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) { if err != nil { return nil, fmt.Errorf("split host port: %w", err) } - listenAddr := fmt.Sprintf("%s:%s", addr, port) + listenAddr := net.JoinHostPort(addr.String(), port) tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr) if err != nil { @@ -357,7 +357,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) { if err != nil { return nil, fmt.Errorf("split host port: %w", err) } - listenAddr := fmt.Sprintf("%s:%s", addr, port) + listenAddr := net.JoinHostPort(addr.String(), port) udpAddr, err := net.ResolveUDPAddr("udp", listenAddr) if err != nil { diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 9c50f02b3..ebaf71b21 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -45,8 +45,11 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client. routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided. -iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. -nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. 
+iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided. +ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided. +ipset.txt: Anonymized ipset list output, if --system-info flag was provided. +nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided. +sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only). resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided. scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided. resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. @@ -165,22 +168,33 @@ The config.txt file contains anonymized configuration information of the NetBird Other non-sensitive configuration options are included without anonymization. 
Firewall Rules (Linux only) -The bundle includes two separate firewall rule files: +The bundle includes the following firewall-related files: iptables.txt: -- Complete iptables ruleset with packet counters using 'iptables -v -n -L' +- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L' - Includes all tables (filter, nat, mangle, raw, security) - Shows packet and byte counters for each rule - All IP addresses are anonymized - Chain names, table names, and other non-sensitive information remain unchanged +ip6tables.txt: +- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L' +- Same table coverage and anonymization as iptables.txt +- Omitted when ip6tables is not installed or no IPv6 rules are present + +ipset.txt: +- Output of 'ipset list' (family-agnostic) +- IP addresses are anonymized; set names and types remain unchanged + nftables.txt: -- Complete nftables ruleset obtained via 'nft -a list ruleset' +- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset' - Includes rule handle numbers and packet counters -- All tables, chains, and rules are included -- Shows packet and byte counters for each rule -- All IP addresses are anonymized -- Chain names, table names, and other non-sensitive information remain unchanged +- All IP addresses are anonymized; chain/table names remain unchanged + +sysctls.txt: +- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that netbird may read or modify +- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf +- Interface names anonymized when --anonymize is set IP Rules (Linux only) The ip_rules.txt file contains detailed IP routing rule information: @@ -412,6 +426,10 @@ func (g *BundleGenerator) addSystemInfo() { log.Errorf("failed to add firewall rules to debug bundle: %v", err) } + if err := 
g.addSysctls(); err != nil { + log.Errorf("failed to add sysctls to debug bundle: %v", err) + } + if err := g.addDNSInfo(); err != nil { log.Errorf("failed to add DNS info to debug bundle: %v", err) } diff --git a/client/internal/debug/debug_linux.go b/client/internal/debug/debug_linux.go index aedf88b79..40d864eda 100644 --- a/client/internal/debug/debug_linux.go +++ b/client/internal/debug/debug_linux.go @@ -124,15 +124,18 @@ func getSystemdLogs(serviceName string) (string, error) { // addFirewallRules collects and adds firewall rules to the archive func (g *BundleGenerator) addFirewallRules() error { log.Info("Collecting firewall rules") - iptablesRules, err := collectIPTablesRules() + g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt") + g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt") + + ipsetOutput, err := collectIPSets() if err != nil { - log.Warnf("Failed to collect iptables rules: %v", err) + log.Warnf("Failed to collect ipset information: %v", err) } else { if g.anonymize { - iptablesRules = g.anonymizer.AnonymizeString(iptablesRules) + ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput) } - if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil { - log.Warnf("Failed to add iptables rules to bundle: %v", err) + if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil { + log.Warnf("Failed to add ipset output to bundle: %v", err) } } @@ -151,44 +154,65 @@ func (g *BundleGenerator) addFirewallRules() error { return nil } -// collectIPTablesRules collects rules using both iptables-save and verbose listing -func collectIPTablesRules() (string, error) { - var builder strings.Builder - - saveOutput, err := collectIPTablesSave() +// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle. 
+func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) {
+	rules, err := collectIPTablesRules(saveBin, listBin)
 	if err != nil {
-		log.Warnf("Failed to collect iptables rules using iptables-save: %v", err)
-	} else {
-		builder.WriteString("=== iptables-save output ===\n")
+		log.Warnf("Failed to collect %s rules: %v", listBin, err)
+		return
+	}
+	if g.anonymize {
+		rules = g.anonymizer.AnonymizeString(rules)
+	}
+	if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil {
+		log.Warnf("Failed to add %s rules to bundle: %v", listBin, err)
+	}
+}
+
+// collectIPTablesRules collects rules using both <saveBin> and verbose listing via <listBin>.
+// Returns an error when neither command produced any output (e.g. the binary is missing),
+// so the caller can skip writing an empty file.
+func collectIPTablesRules(saveBin, listBin string) (string, error) {
+	var builder strings.Builder
+	var collected bool
+	var firstErr error
+
+	saveOutput, err := runCommand(saveBin)
+	switch {
+	case err != nil:
+		firstErr = err
+		log.Warnf("Failed to collect %s output: %v", saveBin, err)
+	case strings.TrimSpace(saveOutput) == "":
+		log.Debugf("%s produced no output, skipping", saveBin)
+	default:
+		builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin))
 		builder.WriteString(saveOutput)
 		builder.WriteString("\n")
+		collected = true
 	}
 
-	ipsetOutput, err := collectIPSets()
-	if err != nil {
-		log.Warnf("Failed to collect ipset information: %v", err)
-	} else {
-		builder.WriteString("=== ipset list output ===\n")
-		builder.WriteString(ipsetOutput)
-		builder.WriteString("\n")
-	}
-
-	builder.WriteString("=== iptables -v -n -L output ===\n")
+	listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin)
+	builder.WriteString(listHeader)
 	tables := []string{"filter", "nat", "mangle", "raw", "security"}
-
 	for _, table := range tables {
-		builder.WriteString(fmt.Sprintf("*%s\n", table))
-
-		stats, err := getTableStatistics(table)
+		stats, err := 
runCommand(listBin, "-v", "-n", "-L", "-t", table) if err != nil { - log.Warnf("Failed to get statistics for table %s: %v", table, err) + if firstErr == nil { + firstErr = err + } + log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err) continue } + builder.WriteString(fmt.Sprintf("*%s\n", table)) builder.WriteString(stats) builder.WriteString("\n") + collected = true } + if !collected { + return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr) + } return builder.String(), nil } @@ -214,34 +238,15 @@ func collectIPSets() (string, error) { return ipsets, nil } -// collectIPTablesSave uses iptables-save to get rule definitions -func collectIPTablesSave() (string, error) { - cmd := exec.Command("iptables-save") +// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure. +func runCommand(name string, args ...string) (string, error) { + cmd := exec.Command(name, args...) var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr if err := cmd.Run(); err != nil { - return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String()) - } - - rules := stdout.String() - if strings.TrimSpace(rules) == "" { - return "", fmt.Errorf("no iptables rules found") - } - - return rules, nil -} - -// getTableStatistics gets verbose statistics for an entire table using iptables command -func getTableStatistics(table string) (string, error) { - cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table) - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String()) + return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String()) } return stdout.String(), nil @@ -804,3 +809,91 @@ func formatSetKeyType(keyType nftables.SetDatatype) string { return fmt.Sprintf("type-%v", keyType) } } + +// addSysctls collects forwarding 
 and netbird-managed sysctl values and writes them to the bundle.
+func (g *BundleGenerator) addSysctls() error {
+	log.Info("Collecting sysctls")
+	content := collectSysctls()
+	if g.anonymize {
+		content = g.anonymizer.AnonymizeString(content)
+	}
+	if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil {
+		return fmt.Errorf("add sysctls to bundle: %w", err)
+	}
+	return nil
+}
+
+// collectSysctls reads every sysctl that the netbird client may modify, plus
+// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic.
+// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf.
+func collectSysctls() string {
+	var builder strings.Builder
+
+	writeSysctlGroup(&builder, "forwarding", []string{
+		"net.ipv4.ip_forward",
+		"net.ipv6.conf.all.forwarding",
+		"net.ipv6.conf.default.forwarding",
+	})
+	writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding"))
+	writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding"))
+	writeSysctlGroup(&builder, "rp_filter", append(
+		[]string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"},
+		listInterfaceSysctls("ipv4", "rp_filter")...,
+	))
+	writeSysctlGroup(&builder, "src_valid_mark", append(
+		[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
+		listInterfaceSysctls("ipv4", "src_valid_mark")...,
+	))
+	writeSysctlGroup(&builder, "conntrack", []string{
+		"net.netfilter.nf_conntrack_acct",
+		"net.netfilter.nf_conntrack_tcp_loose",
+	})
+	writeSysctlGroup(&builder, "tcp", []string{
+		"net.ipv4.tcp_tw_reuse",
+	})
+
+	return builder.String()
+}
+
+func writeSysctlGroup(builder *strings.Builder, title string, keys []string) {
+	builder.WriteString(fmt.Sprintf("=== %s ===\n", title))
+	for _, key := range keys {
+		value, err := readSysctl(key)
+		if err != nil {
+			builder.WriteString(fmt.Sprintf("%s = <error: %v>\n", key, err))
+			continue
+		}
+		
builder.WriteString(fmt.Sprintf("%s = %s\n", key, value))
+	}
+	builder.WriteString("\n")
+}
+
+// listInterfaceSysctls returns net.ipvX.conf.<iface>.<leaf> keys for every
+// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default"
+// (callers add those explicitly so they appear first).
+func listInterfaceSysctls(family, leaf string) []string {
+	dir := fmt.Sprintf("/proc/sys/net/%s/conf", family)
+	entries, err := os.ReadDir(dir)
+	if err != nil {
+		return nil
+	}
+	var keys []string
+	for _, e := range entries {
+		name := e.Name()
+		if name == "all" || name == "default" {
+			continue
+		}
+		keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf))
+	}
+	sort.Strings(keys)
+	return keys
+}
+
+func readSysctl(key string) (string, error) {
+	path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
+	value, err := os.ReadFile(path)
+	if err != nil {
+		return "", err
+	}
+	return strings.TrimSpace(string(value)), nil
+}
diff --git a/client/internal/debug/debug_nonlinux.go b/client/internal/debug/debug_nonlinux.go
index ace53bd94..878fee40f 100644
--- a/client/internal/debug/debug_nonlinux.go
+++ b/client/internal/debug/debug_nonlinux.go
@@ -17,3 +17,8 @@ func (g *BundleGenerator) addIPRules() error {
 	// IP rules are only supported on Linux
 	return nil
 }
+
+func (g *BundleGenerator) addSysctls() error {
+	// Sysctl collection is only supported on Linux
+	return nil
+}
From 07e5450117dd0451aaeefc18729a822115587e69 Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Thu, 14 May 2026 23:42:40 +0900
Subject: [PATCH 25/27] [management] Bracket IPv6 reverse-proxy target hosts when building URL Host field (#6141)

---
 .../modules/reverseproxy/service/service.go   | 18 ++++-
 .../reverseproxy/service/service_test.go      | 77 +++++++++++++++++++
 2 files changed, 93 insertions(+), 2 deletions(-)

diff --git a/management/internals/modules/reverseproxy/service/service.go 
b/management/internals/modules/reverseproxy/service/service.go index 769e037bc..166a66a5f 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -381,13 +381,14 @@ func (s *Service) buildPathMappings() []*proto.PathMapping { } // HTTP/HTTPS: build full URL + hostNoBrackets := strings.TrimSuffix(strings.TrimPrefix(target.Host, "["), "]") targetURL := url.URL{ Scheme: target.Protocol, - Host: target.Host, + Host: bracketIPv6Host(hostNoBrackets), Path: "/", } if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) { - targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10)) + targetURL.Host = net.JoinHostPort(hostNoBrackets, strconv.FormatUint(uint64(target.Port), 10)) } path := "/" @@ -405,6 +406,19 @@ func (s *Service) buildPathMappings() []*proto.PathMapping { return pathMappings } +// bracketIPv6Host wraps host in square brackets when it is an IPv6 literal, as +// required for the Host field of net/url.URL (RFC 3986 §3.2.2). v4-mapped IPv6 +// addresses are bracketed too since their textual form contains colons. 
+func bracketIPv6Host(host string) string { + if strings.HasPrefix(host, "[") { + return host + } + if addr, err := netip.ParseAddr(host); err == nil && addr.Is6() { + return "[" + host + "]" + } + return host +} + func operationToProtoType(op Operation) proto.ProxyMappingUpdateType { switch op { case Create: diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go index ff54cb79f..f1349ff65 100644 --- a/management/internals/modules/reverseproxy/service/service_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -351,6 +351,83 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) { port: 80, wantTarget: "https://10.0.0.1:80/", }, + { + name: "domain host without port is unchanged", + protocol: "http", + host: "example.com", + port: 0, + wantTarget: "http://example.com/", + }, + { + name: "domain host with non-default port is unchanged", + protocol: "http", + host: "example.com", + port: 8080, + wantTarget: "http://example.com:8080/", + }, + { + name: "ipv6 host without port is bracketed", + protocol: "http", + host: "fb00:cafe:1::3", + port: 0, + wantTarget: "http://[fb00:cafe:1::3]/", + }, + { + name: "ipv6 host with default port omits port and brackets host", + protocol: "http", + host: "fb00:cafe:1::3", + port: 80, + wantTarget: "http://[fb00:cafe:1::3]/", + }, + { + name: "ipv6 host with non-default port is bracketed", + protocol: "http", + host: "fb00:cafe:1::3", + port: 8080, + wantTarget: "http://[fb00:cafe:1::3]:8080/", + }, + { + name: "ipv6 loopback without port is bracketed", + protocol: "http", + host: "::1", + port: 0, + wantTarget: "http://[::1]/", + }, + { + name: "ipv6 host with 5-digit port is bracketed", + protocol: "http", + host: "fb00:cafe::1", + port: 18080, + wantTarget: "http://[fb00:cafe::1]:18080/", + }, + { + name: "pre-bracketed ipv6 without port stays single-bracketed", + protocol: "http", + host: "[fb00:cafe::1]", 
+ port: 0, + wantTarget: "http://[fb00:cafe::1]/", + }, + { + name: "pre-bracketed ipv6 with port is not double-bracketed", + protocol: "http", + host: "[fb00:cafe::1]", + port: 8080, + wantTarget: "http://[fb00:cafe::1]:8080/", + }, + { + name: "v4-mapped ipv6 host without port is bracketed", + protocol: "http", + host: "::ffff:10.0.0.1", + port: 0, + wantTarget: "http://[::ffff:10.0.0.1]/", + }, + { + name: "full-form 8-group ipv6 without port is bracketed", + protocol: "http", + host: "fb00:cafe:1:0:0:0:0:3", + port: 0, + wantTarget: "http://[fb00:cafe:1:0:0:0:0:3]/", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 2ccae7ec479c106efb6d7a7edff4bb55affb2aa4 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 15 May 2026 23:58:47 +0900 Subject: [PATCH 26/27] [client] Mirror v4 exit selection onto v6 pair and honour SkipAutoApply per route (#6150) --- client/internal/routemanager/manager.go | 5 +- .../internal/routeselector/routeselector.go | 84 ++++++----- .../routeselector/routeselector_test.go | 131 ++++++++++++++++++ client/ui/network.go | 10 +- 4 files changed, 197 insertions(+), 33 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index e5d9363ca..907f1f592 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -704,7 +704,10 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI } func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool { - return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR + if len(routes) == 0 { + return false + } + return route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network) } func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) { diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go 
index 30afc013b..2ddc24bf2 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "slices" + "strings" "sync" "github.com/hashicorp/go-multierror" @@ -12,10 +13,6 @@ import ( "github.com/netbirdio/netbird/route" ) -const ( - exitNodeCIDR = "0.0.0.0/0" -) - type RouteSelector struct { mu sync.RWMutex deselectedRoutes map[route.NetID]struct{} @@ -124,13 +121,7 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { rs.mu.RLock() defer rs.mu.RUnlock() - if rs.deselectAll { - return false - } - - _, deselected := rs.deselectedRoutes[routeID] - isSelected := !deselected - return isSelected + return rs.isSelectedLocked(routeID) } // FilterSelected removes unselected routes from the provided map. @@ -144,23 +135,22 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { filtered := route.HAMap{} for id, rt := range routes { - netID := id.NetID() - _, deselected := rs.deselectedRoutes[netID] - if !deselected { + if !rs.isDeselectedLocked(id.NetID()) { filtered[id] = rt } } return filtered } -// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route +// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route. +// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of +// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing +// transparently apply to the synthesized v6 entry. 
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool { rs.mu.RLock() defer rs.mu.RUnlock() - _, selected := rs.selectedRoutes[routeID] - _, deselected := rs.deselectedRoutes[routeID] - return selected || deselected + return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID)) } func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap { @@ -174,7 +164,7 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap filtered := make(route.HAMap, len(routes)) for id, rt := range routes { netID := id.NetID() - if rs.isDeselected(netID) { + if rs.isDeselectedLocked(netID) { continue } @@ -189,13 +179,48 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap return filtered } -func (rs *RouteSelector) isDeselected(netID route.NetID) bool { +// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit +// state of its own, so selections made on the v4 entry govern the v6 entry automatically. +// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route +// would make it inherit unrelated v4 state. Must be called with rs.mu held. 
+func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID { + name := string(id) + if !strings.HasSuffix(name, route.V6ExitSuffix) { + return id + } + if _, ok := rs.selectedRoutes[id]; ok { + return id + } + if _, ok := rs.deselectedRoutes[id]; ok { + return id + } + return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix)) +} + +func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool { + if rs.deselectAll { + return false + } + _, deselected := rs.deselectedRoutes[routeID] + return !deselected +} + +func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool { + if rs.deselectAll { + return true + } _, deselected := rs.deselectedRoutes[netID] - return deselected || rs.deselectAll + return deselected +} + +func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool { + _, selected := rs.selectedRoutes[routeID] + _, deselected := rs.deselectedRoutes[routeID] + return selected || deselected } func isExitNode(rt []*route.Route) bool { - return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR + return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network)) } func (rs *RouteSelector) applyExitNodeFilter( @@ -204,26 +229,23 @@ func (rs *RouteSelector) applyExitNodeFilter( rt []*route.Route, out route.HAMap, ) { - - if rs.hasUserSelections() { - // user made explicit selects/deselects - if rs.IsSelected(netID) { + // Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also + // drops the synthesized v6 entry that lacks its own explicit state. 
+ effective := rs.effectiveNetID(netID) + if rs.hasUserSelectionForRouteLocked(effective) { + if rs.isSelectedLocked(effective) { out[id] = rt } return } - // no explicit selections: only include routes marked !SkipAutoApply (=AutoApply) + // no explicit selection for this route: defer to management's SkipAutoApply flag sel := collectSelected(rt) if len(sel) > 0 { out[id] = sel } } -func (rs *RouteSelector) hasUserSelections() bool { - return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0 -} - func collectSelected(rt []*route.Route) []*route.Route { var sel []*route.Route for _, r := range rt { diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index 5faea2456..3f0d9f120 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -330,6 +330,137 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) { assert.Len(t, filtered, 0) // No routes should be selected } +// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection +// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute +// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its +// v4 base, so legacy persisted selections that predate v6 pairing transparently +// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected) +// stay literal so unrelated routes named "*-v6" don't inherit unrelated state. 
+func TestRouteSelector_V6ExitPairInherits(t *testing.T) { + all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"} + + t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all)) + + assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection") + + // unrelated v6 with no v4 base touched is unaffected + assert.False(t, rs.HasUserSelectionForRoute("exit2-v6")) + }) + + t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all)) + + // A non-exit route literally named "corp-v6" must not inherit "corp"'s state + // via the mirror; the mirror only applies in exit-node code paths. + assert.False(t, rs.IsSelected("corp")) + assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state") + }) + + t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all)) + require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all)) + + v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")} + v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")} + routes := route.HAMap{ + "exit1|0.0.0.0/0": {v4Route}, + "exit1-v6|::/0": {v6Route}, + } + + filtered := rs.FilterSelectedExitNodes(routes) + assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0")) + assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base") + }) + + t.Run("non-v6-suffix routes unaffected", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all)) + + // 
A route literally named "exit1-something" must not pair-resolve. + assert.False(t, rs.HasUserSelectionForRoute("exit1-something")) + }) + + t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all)) + + v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")} + v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")} + routes := route.HAMap{ + "exit1|0.0.0.0/0": {v4Route}, + "exit1-v6|::/0": {v6Route}, + } + + filtered := rs.FilterSelectedExitNodes(routes) + assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair") + }) + + t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all)) + + // A non-default-route entry named "corp-v6" is not an exit node and + // must not be skipped because its v4 base "corp" is deselected. + corpV6 := &route.Route{NetID: "corp-v6", Network: netip.MustParsePrefix("10.0.0.0/8")} + routes := route.HAMap{ + "corp-v6|10.0.0.0/8": {corpV6}, + } + + filtered := rs.FilterSelectedExitNodes(routes) + assert.Contains(t, filtered, route.HAUniqueID("corp-v6|10.0.0.0/8"), + "non-exit *-v6 routes must not inherit unrelated v4 state in FilterSelectedExitNodes") + }) +} + +// TestRouteSelector_SkipAutoApplyPerRoute verifies that management's +// SkipAutoApply flag governs each untouched route independently, even when +// the user has explicit selections on other routes. 
+func TestRouteSelector_SkipAutoApplyPerRoute(t *testing.T) { + autoApplied := &route.Route{ + NetID: "Auto", + Network: netip.MustParsePrefix("0.0.0.0/0"), + SkipAutoApply: false, + } + skipApply := &route.Route{ + NetID: "Skip", + Network: netip.MustParsePrefix("0.0.0.0/0"), + SkipAutoApply: true, + } + routes := route.HAMap{ + "Auto|0.0.0.0/0": {autoApplied}, + "Skip|0.0.0.0/0": {skipApply}, + } + + rs := routeselector.NewRouteSelector() + // User makes an unrelated explicit selection elsewhere. + require.NoError(t, rs.DeselectRoutes([]route.NetID{"Unrelated"}, []route.NetID{"Auto", "Skip", "Unrelated"})) + + filtered := rs.FilterSelectedExitNodes(routes) + assert.Contains(t, filtered, route.HAUniqueID("Auto|0.0.0.0/0"), "AutoApply route should be included") + assert.NotContains(t, filtered, route.HAUniqueID("Skip|0.0.0.0/0"), "SkipAutoApply route should be excluded without explicit user selection") +} + +// TestRouteSelector_V6ExitIsExitNode verifies that ::/0 routes are recognized +// as exit nodes by the selector's filter path. 
+func TestRouteSelector_V6ExitIsExitNode(t *testing.T) { + v6Exit := &route.Route{ + NetID: "V6Only", + Network: netip.MustParsePrefix("::/0"), + SkipAutoApply: true, + } + routes := route.HAMap{ + "V6Only|::/0": {v6Exit}, + } + + rs := routeselector.NewRouteSelector() + filtered := rs.FilterSelectedExitNodes(routes) + assert.Empty(t, filtered, "::/0 should be treated as an exit node and respect SkipAutoApply") +} + func TestRouteSelector_NewRoutesBehavior(t *testing.T) { initialRoutes := []route.NetID{"route1", "route2", "route3"} newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"} diff --git a/client/ui/network.go b/client/ui/network.go index 1619f78a2..cd5d23558 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -193,7 +193,15 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network { } func isDefaultRoute(routeRange string) bool { - return routeRange == "0.0.0.0/0" || routeRange == "::/0" + // routeRange is the merged display string from the daemon, e.g. "0.0.0.0/0", + // "::/0", or "0.0.0.0/0, ::/0" when a v4 exit node has a paired v6 entry. 
+ for _, part := range strings.Split(routeRange, ",") { + switch strings.TrimSpace(part) { + case "0.0.0.0/0", "::/0": + return true + } + } + return false } func getExitNodeNetworks(routes []*proto.Network) []*proto.Network { From 9ed2e2a5b463077f8abe3e3926695f5dc9411e29 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 16 May 2026 00:07:38 +0900 Subject: [PATCH 27/27] [client] Drop DNS probes for passive health projection (#5971) --- client/internal/connect.go | 2 - client/internal/dns/host.go | 12 + client/internal/dns/host_android.go | 19 +- client/internal/dns/host_ios.go | 9 + client/internal/dns/host_windows.go | 121 ++- client/internal/dns/hosts_dns_holder.go | 1 + client/internal/dns/local/local.go | 2 - client/internal/dns/mock_server.go | 9 +- client/internal/dns/network_manager_unix.go | 211 ++++- client/internal/dns/server.go | 928 ++++++++++++-------- client/internal/dns/server_android.go | 2 +- client/internal/dns/server_test.go | 698 +++++++++++++-- client/internal/dns/systemd_linux.go | 151 +++- client/internal/dns/upstream.go | 683 +++++++------- client/internal/dns/upstream_android.go | 5 +- client/internal/dns/upstream_general.go | 5 +- client/internal/dns/upstream_ios.go | 17 +- client/internal/dns/upstream_test.go | 227 +++-- client/internal/engine.go | 16 +- client/internal/routemanager/manager.go | 34 + client/internal/routemanager/mock.go | 9 + client/ios/NetBirdSDK/client.go | 6 +- 22 files changed, 2294 insertions(+), 873 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index 8c0e9b1ba..ea884818f 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -116,7 +116,6 @@ func (c *ConnectClient) RunOniOS( fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager, - dnsAddresses []netip.AddrPort, stateFilePath string, ) error { // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of 
memory for the extension. @@ -126,7 +125,6 @@ func (c *ConnectClient) RunOniOS( FileDescriptor: fileDescriptor, NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, - HostDNSAddresses: dnsAddresses, StateFilePath: stateFilePath, } return c.run(mobileDependency, nil, "") diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index f7dc46a6b..48eacef29 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -16,6 +16,10 @@ type hostManager interface { restoreHostDNS() error supportCustomPort() bool string() string + // getOriginalNameservers returns the OS-side resolvers used as PriorityFallback + // upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android, + // hardcoded Quad9 on iOS, nil for noop / mock. + getOriginalNameservers() []netip.Addr } type SystemDNSSettings struct { @@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool { func (n noopHostConfigurator) string() string { return "noop" } + +func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr { + return nil +} + +func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr { + return nil +} diff --git a/client/internal/dns/host_android.go b/client/internal/dns/host_android.go index dfa3e5712..48b3e0301 100644 --- a/client/internal/dns/host_android.go +++ b/client/internal/dns/host_android.go @@ -1,14 +1,20 @@ package dns import ( + "net/netip" + "github.com/netbirdio/netbird/client/internal/statemanager" ) +// androidHostManager is a noop on the OS side (Android's VPN service handles +// DNS for us) but tracks the OS-reported resolver list pushed via +// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source. 
type androidHostManager struct { + holder *hostsDNSHolder } -func newHostManager() (*androidHostManager, error) { - return &androidHostManager{}, nil +func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) { + return &androidHostManager{holder: holder}, nil } func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error { @@ -26,3 +32,12 @@ func (a androidHostManager) supportCustomPort() bool { func (a androidHostManager) string() string { return "none" } + +func (a androidHostManager) getOriginalNameservers() []netip.Addr { + hosts := a.holder.get() + out := make([]netip.Addr, 0, len(hosts)) + for ap := range hosts { + out = append(out, ap.Addr()) + } + return out +} diff --git a/client/internal/dns/host_ios.go b/client/internal/dns/host_ios.go index 1c0ac63e9..860bb8b50 100644 --- a/client/internal/dns/host_ios.go +++ b/client/internal/dns/host_ios.go @@ -3,6 +3,7 @@ package dns import ( "encoding/json" "fmt" + "net/netip" log "github.com/sirupsen/logrus" @@ -20,6 +21,14 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) { }, nil } +func (a iosHostManager) getOriginalNameservers() []netip.Addr { + // Quad9 v4+v6: 9.9.9.9, 2620:fe::fe. 
+ return []netip.Addr{ + netip.AddrFrom4([4]byte{9, 9, 9, 9}), + netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}), + } +} + func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error { jsonData, err := json.Marshal(config) if err != nil { diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 4a8cf8cec..4f6ece532 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -7,6 +7,7 @@ import ( "io" "net/netip" "os/exec" + "slices" "strings" "syscall" "time" @@ -44,9 +45,11 @@ const ( nrptMaxDomainsPerRule = 50 - interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` - interfaceConfigNameServerKey = "NameServer" - interfaceConfigSearchListKey = "SearchList" + interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` + interfaceConfigPathV6 = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces` + interfaceConfigNameServerKey = "NameServer" + interfaceConfigDhcpNameSrvKey = "DhcpNameServer" + interfaceConfigSearchListKey = "SearchList" // Network interface DNS registration settings disableDynamicUpdateKey = "DisableDynamicUpdate" @@ -67,10 +70,11 @@ const ( ) type registryConfigurator struct { - guid string - routingAll bool - gpo bool - nrptEntryCount int + guid string + routingAll bool + gpo bool + nrptEntryCount int + origNameservers []netip.Addr } func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { @@ -94,6 +98,17 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { gpo: useGPO, } + origNameservers, err := configurator.captureOriginalNameservers() + switch { + case err != nil: + log.Warnf("capture original nameservers from non-WG adapters: %v", err) + case len(origNameservers) == 0: + log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty") + default: + 
log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers) + } + configurator.origNameservers = origNameservers + if err := configurator.configureInterface(); err != nil { log.Errorf("failed to configure interface settings: %v", err) } @@ -101,6 +116,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { return configurator, nil } +// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface +// registry key except the WG adapter. v4 and v6 servers live in separate +// hives (Tcpip vs Tcpip6) keyed by the same interface GUID. +func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) { + seen := make(map[netip.Addr]struct{}) + var out []netip.Addr + var merr *multierror.Error + for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} { + addrs, err := r.captureFromTcpipRoot(root) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err)) + continue + } + for _, addr := range addrs { + if _, dup := seen[addr]; dup { + continue + } + seen[addr] = struct{}{} + out = append(out, addr) + } + } + return out, nberrors.FormatErrorOrNil(merr) +} + +func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) { + root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ) + if err != nil { + return nil, fmt.Errorf("open key: %w", err) + } + defer closer(root) + + guids, err := root.ReadSubKeyNames(-1) + if err != nil { + return nil, fmt.Errorf("read subkeys: %w", err) + } + + var out []netip.Addr + for _, guid := range guids { + if strings.EqualFold(guid, r.guid) { + continue + } + out = append(out, readInterfaceNameservers(rootPath, guid)...) 
+ } + return out, nil +} + +func readInterfaceNameservers(rootPath, guid string) []netip.Addr { + keyPath := rootPath + "\\" + guid + k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE) + if err != nil { + return nil + } + defer closer(k) + + // Static NameServer wins over DhcpNameServer for actual resolution. + for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} { + raw, _, err := k.GetStringValue(name) + if err != nil || raw == "" { + continue + } + if out := parseRegistryNameservers(raw); len(out) > 0 { + return out + } + } + return nil +} + +func parseRegistryNameservers(raw string) []netip.Addr { + var out []netip.Addr + for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) { + addr, err := netip.ParseAddr(strings.TrimSpace(field)) + if err != nil { + continue + } + addr = addr.Unmap() + if !addr.IsValid() || addr.IsUnspecified() { + continue + } + // Drop unzoned link-local: not routable without a scope id. If + // the user wrote "fe80::1%eth0" ParseAddr preserves the zone. 
+ if addr.IsLinkLocalUnicast() && addr.Zone() == "" { + continue + } + out = append(out, addr) + } + return out +} + +func (r *registryConfigurator) getOriginalNameservers() []netip.Addr { + return slices.Clone(r.origNameservers) +} + func (r *registryConfigurator) supportCustomPort() bool { return false } diff --git a/client/internal/dns/hosts_dns_holder.go b/client/internal/dns/hosts_dns_holder.go index 980d917a7..9ecc397be 100644 --- a/client/internal/dns/hosts_dns_holder.go +++ b/client/internal/dns/hosts_dns_holder.go @@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) { h.mutex.Unlock() } +//nolint:unused func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} { h.mutex.RLock() l := h.unprotectedDNSList diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index e9d310f00..4a75a76b6 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -76,8 +76,6 @@ func (d *Resolver) ID() types.HandlerID { return "local-resolver" } -func (d *Resolver) ProbeAvailability(context.Context) {} - // ServeDNS handles a DNS request func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { logger := log.WithFields(log.Fields{ diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 548b1f54f..31fedd9e5 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -9,6 +9,7 @@ import ( dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string { return make([]string, 0) } -// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface -func (m *MockServer) ProbeAvailability() { -} - func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { if m.UpdateServerConfigFunc != 
nil { return m.UpdateServerConfigFunc(domains) @@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { return nil } -// SetRouteChecker mock implementation of SetRouteChecker from Server interface -func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) { +// SetRouteSources mock implementation of SetRouteSources from Server interface +func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) { // Mock implementation - no-op } diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index 66d82dcd7..3932e78b7 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "net/netip" + "slices" "strings" "time" @@ -32,6 +33,15 @@ const ( networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection" networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply" networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete" + networkManagerDbusDeviceIp4ConfigProperty = networkManagerDbusDeviceInterface + ".Ip4Config" + networkManagerDbusDeviceIp6ConfigProperty = networkManagerDbusDeviceInterface + ".Ip6Config" + networkManagerDbusDeviceIfaceProperty = networkManagerDbusDeviceInterface + ".Interface" + networkManagerDbusGetDevicesMethod = networkManagerDest + ".GetDevices" + networkManagerDbusIp4ConfigInterface = "org.freedesktop.NetworkManager.IP4Config" + networkManagerDbusIp6ConfigInterface = "org.freedesktop.NetworkManager.IP6Config" + networkManagerDbusIp4ConfigNameserverDataProperty = networkManagerDbusIp4ConfigInterface + ".NameserverData" + networkManagerDbusIp4ConfigNameserversProperty = networkManagerDbusIp4ConfigInterface + ".Nameservers" + networkManagerDbusIp6ConfigNameserversProperty = networkManagerDbusIp6ConfigInterface + ".Nameservers" networkManagerDbusDefaultBehaviorFlag 
networkManagerConfigBehavior = 0 networkManagerDbusIPv4Key = "ipv4" networkManagerDbusIPv6Key = "ipv6" @@ -51,9 +61,10 @@ var supportedNetworkManagerVersionConstraints = []string{ } type networkManagerDbusConfigurator struct { - dbusLinkObject dbus.ObjectPath - routingAll bool - ifaceName string + dbusLinkObject dbus.ObjectPath + routingAll bool + ifaceName string + origNameservers []netip.Addr } // the types below are based on dbus specification, each field is mapped to a dbus type @@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface) - return &networkManagerDbusConfigurator{ + c := &networkManagerDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), ifaceName: wgInterface, - }, nil + } + + origNameservers, err := c.captureOriginalNameservers() + switch { + case err != nil: + log.Warnf("capture original nameservers from NetworkManager: %v", err) + case len(origNameservers) == 0: + log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty") + default: + log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers) + } + c.origNameservers = origNameservers + return c, nil +} + +// captureOriginalNameservers reads DNS servers from every NM device's +// IP4Config / IP6Config except our WG device. 
+func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) { + devices, err := networkManagerListDevices() + if err != nil { + return nil, fmt.Errorf("list devices: %w", err) + } + + seen := make(map[netip.Addr]struct{}) + var out []netip.Addr + for _, dev := range devices { + if dev == n.dbusLinkObject { + continue + } + ifaceName := readNetworkManagerDeviceInterface(dev) + for _, addr := range readNetworkManagerDeviceDNS(dev) { + addr = addr.Unmap() + if !addr.IsValid() || addr.IsUnspecified() { + continue + } + // IP6Config.Nameservers is a byte slice without zone info; + // reattach the device's interface name so a captured fe80::… + // stays routable. + if addr.IsLinkLocalUnicast() && ifaceName != "" { + addr = addr.WithZone(ifaceName) + } + if _, dup := seen[addr]; dup { + continue + } + seen[addr] = struct{}{} + out = append(out, addr) + } + } + return out, nil +} + +func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string { + obj, closeConn, err := getDbusObject(networkManagerDest, devicePath) + if err != nil { + return "" + } + defer closeConn() + v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty) + if err != nil { + return "" + } + s, _ := v.Value().(string) + return s +} + +func networkManagerListDevices() ([]dbus.ObjectPath, error) { + obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) + if err != nil { + return nil, fmt.Errorf("dbus NetworkManager: %w", err) + } + defer closeConn() + var devs []dbus.ObjectPath + if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil { + return nil, err + } + return devs, nil +} + +func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr { + obj, closeConn, err := getDbusObject(networkManagerDest, devicePath) + if err != nil { + return nil + } + defer closeConn() + + var out []netip.Addr + if path := readNetworkManagerConfigPath(obj, 
networkManagerDbusDeviceIp4ConfigProperty); path != "" { + out = append(out, readIPv4ConfigDNS(path)...) + } + if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" { + out = append(out, readIPv6ConfigDNS(path)...) + } + return out +} + +func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath { + v, err := obj.GetProperty(property) + if err != nil { + return "" + } + path, ok := v.Value().(dbus.ObjectPath) + if !ok || path == "/" { + return "" + } + return path +} + +func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr { + obj, closeConn, err := getDbusObject(networkManagerDest, path) + if err != nil { + return nil + } + defer closeConn() + + // NameserverData (NM 1.13+) carries strings; older NMs only expose the + // legacy uint32 Nameservers property. + if out := readIPv4NameserverData(obj); len(out) > 0 { + return out + } + return readIPv4LegacyNameservers(obj) +} + +func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr { + v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty) + if err != nil { + return nil + } + entries, ok := v.Value().([]map[string]dbus.Variant) + if !ok { + return nil + } + var out []netip.Addr + for _, entry := range entries { + addrVar, ok := entry["address"] + if !ok { + continue + } + s, ok := addrVar.Value().(string) + if !ok { + continue + } + if a, err := netip.ParseAddr(s); err == nil { + out = append(out, a) + } + } + return out +} + +func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr { + v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty) + if err != nil { + return nil + } + raw, ok := v.Value().([]uint32) + if !ok { + return nil + } + out := make([]netip.Addr, 0, len(raw)) + for _, n := range raw { + var b [4]byte + binary.LittleEndian.PutUint32(b[:], n) + out = append(out, netip.AddrFrom4(b)) + } + return out +} + +func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr { + obj, 
closeConn, err := getDbusObject(networkManagerDest, path) + if err != nil { + return nil + } + defer closeConn() + v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty) + if err != nil { + return nil + } + raw, ok := v.Value().([][]byte) + if !ok { + return nil + } + out := make([]netip.Addr, 0, len(raw)) + for _, b := range raw { + if a, ok := netip.AddrFromSlice(b); ok { + out = append(out, a) + } + } + return out +} + +func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr { + return slices.Clone(n.origNameservers) } func (n *networkManagerDbusConfigurator) supportCustomPort() bool { diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 6fe2e21b6..e689f3586 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -6,11 +6,10 @@ import ( "fmt" "net/netip" "net/url" - "os" - "runtime" - "strconv" + "slices" "strings" "sync" + "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -25,11 +24,31 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) -const envSkipDNSProbe = "NB_SKIP_DNS_PROBE" +const ( + // healthLookback must exceed the upstream query timeout so one + // query per refresh cycle is enough to keep a group marked healthy. + healthLookback = 60 * time.Second + nsGroupHealthRefreshInterval = 10 * time.Second + // defaultWarningDelayBase is the starting grace window before a + // "Nameserver group unreachable" event fires for a group that's + // never been healthy and only has overlay upstreams with no + // Connected peer. Per-server and overridable; see warningDelayFor. 
+ defaultWarningDelayBase = 30 * time.Second + // warningDelayBonusCap caps the route-count bonus added to the + // base grace window. See warningDelayFor. + warningDelayBonusCap = 30 * time.Second +) + +// errNoUsableNameservers signals that a merged-domain group has no usable +// upstream servers. Callers should skip the group without treating it as a +// build failure. +var errNoUsableNameservers = errors.New("no usable nameservers") // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes type ReadyListener interface { @@ -54,10 +73,9 @@ type Server interface { UpdateDNSServer(serial uint64, update nbdns.Config) error OnUpdatedHostDNSServer(addrs []netip.AddrPort) SearchDomains() []string - ProbeAvailability() UpdateServerConfig(domains dnsconfig.ServerDomains) error PopulateManagementDomain(mgmtURL *url.URL) error - SetRouteChecker(func(netip.Addr) bool) + SetRouteSources(selected, active func() route.HAMap) SetFirewall(Firewall) } @@ -66,12 +84,47 @@ type nsGroupsByDomain struct { groups []*nbdns.NameServerGroup } -// hostManagerWithOriginalNS extends the basic hostManager interface -type hostManagerWithOriginalNS interface { - hostManager - getOriginalNameservers() []netip.Addr +// nsGroupID identifies a nameserver group by the tuple (server list, domain +// list) so config updates produce stable IDs across recomputations. +type nsGroupID string + +// nsHealthSnapshot is the input to projectNSGroupHealth, captured under +// s.mux so projection runs lock-free. +type nsHealthSnapshot struct { + groups []*nbdns.NameServerGroup + merged map[netip.AddrPort]UpstreamHealth + selected route.HAMap + active route.HAMap } +// nsGroupProj holds per-group state for the emission rules. +type nsGroupProj struct { + // unhealthySince is the start of the current Unhealthy streak, + // zero when the group is not currently Unhealthy. 
+ unhealthySince time.Time + // everHealthy is sticky: once the group has been Healthy at least + // once this session, subsequent failures skip warningDelay. + everHealthy bool + // warningActive tracks whether we've already published a warning + // for the current streak, so recovery emits iff a warning did. + warningActive bool +} + +// nsGroupVerdict is the outcome of evaluateNSGroupHealth. +type nsGroupVerdict int + +const ( + // nsVerdictUndecided means no upstream has a fresh observation + // (startup before first query, or records aged past healthLookback). + nsVerdictUndecided nsGroupVerdict = iota + // nsVerdictHealthy means at least one upstream's most-recent + // in-lookback observation is a success. + nsVerdictHealthy + // nsVerdictUnhealthy means at least one upstream has a recent + // failure and none has a fresher success. + nsVerdictUnhealthy +) + // DefaultServer dns server object type DefaultServer struct { ctx context.Context @@ -100,26 +153,46 @@ type DefaultServer struct { permanent bool hostsDNSHolder *hostsDNSHolder + // fallbackHandler is the upstream resolver currently registered at + // PriorityFallback. Tracked so registerFallback can Stop() the previous + // instance instead of leaking its context. + fallbackHandler handlerWithStop + // make sense on mobile only searchDomainNotifier *notifier iosDnsManager IosDnsManager statusRecorder *peer.Status stateManager *statemanager.Manager - routeMatch func(netip.Addr) bool + // selectedRoutes returns admin-enabled client routes. + selectedRoutes func() route.HAMap + // activeRoutes returns the subset whose peer is in StatusConnected. + activeRoutes func() route.HAMap - probeMu sync.Mutex - probeCancel context.CancelFunc - probeWg sync.WaitGroup + nsGroups []*nbdns.NameServerGroup + healthProjectMu sync.Mutex + // nsGroupProj is the per-group state used by the emission rules. + // Accessed only under healthProjectMu. 
+ nsGroupProj map[nsGroupID]*nsGroupProj + // warningDelayBase is the base grace window for health projection. + // Set at construction, mutated only by tests. Read by the + // refresher goroutine so never change it while one is running. + warningDelayBase time.Duration + // healthRefresh is buffered=1; writers coalesce, senders never block. + // See refreshHealth for the lock-order rationale. + healthRefresh chan struct{} } type handlerWithStop interface { dns.Handler Stop() - ProbeAvailability(context.Context) ID() types.HandlerID } +type upstreamHealthReporter interface { + UpstreamHealth() map[netip.AddrPort]UpstreamHealth +} + type handlerWrapper struct { domain string handler handlerWithStop @@ -174,7 +247,6 @@ func NewDefaultServerPermanentUpstream( ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true - ds.addHostRootZone() ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort()) ds.searchDomainNotifier = newNotifier(ds.SearchDomains()) ds.searchDomainNotifier.setListener(listener) @@ -182,21 +254,17 @@ func NewDefaultServerPermanentUpstream( return ds } -// NewDefaultServerIos returns a new dns server. It optimized for ios +// NewDefaultServerIos returns a new dns server. It optimized for ios. 
func NewDefaultServerIos( ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager, - hostsDnsList []netip.AddrPort, statusRecorder *peer.Status, disableSys bool, ) *DefaultServer { - log.Debugf("iOS host dns address list is: %v", hostsDnsList) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds.iosDnsManager = iosDnsManager - ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true - ds.addHostRootZone() return ds } @@ -230,6 +298,8 @@ func newDefaultServer( hostManager: &noopHostConfigurator{}, mgmtCacheResolver: mgmtCacheResolver, currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied + warningDelayBase: defaultWarningDelayBase, + healthRefresh: make(chan struct{}, 1), } // register with root zone, handler chain takes care of the routing @@ -238,12 +308,26 @@ func newDefaultServer( return defaultServer } -// SetRouteChecker sets the function used by upstream resolvers to determine -// whether an IP is routed through the tunnel. -func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) { +// SetRouteSources wires the route-manager accessors used by health +// projection to classify each upstream for emission timing. +func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) { s.mux.Lock() defer s.mux.Unlock() - s.routeMatch = f + s.selectedRoutes = selected + s.activeRoutes = active + + // Permanent / iOS constructors build the root handler before the + // engine wires route sources, so its selectedRoutes callback would + // otherwise remain nil and overlay upstreams would be classified + // as public. Propagate the new accessors to existing handlers. 
+ type routeSettable interface { + setSelectedRoutes(func() route.HAMap) + } + for _, entry := range s.dnsMuxMap { + if h, ok := entry.handler.(routeSettable); ok { + h.setSelectedRoutes(selected) + } + } } // RegisterHandler registers a handler for the given domains with the given priority. @@ -256,7 +340,6 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler // TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain for _, domain := range domains { - // convert to zone with simple ref counter s.extraDomains[toZone(domain)]++ } if !s.batchMode { @@ -357,6 +440,8 @@ func (s *DefaultServer) Initialize() (err error) { s.stateManager.RegisterState(&ShutdownState{}) + s.startHealthRefresher() + // Keep using noop host manager if dns off requested or running in netstack mode. // Netstack mode currently doesn't have a way to receive DNS requests. // TODO: Use listener on localhost in netstack mode when running as root. @@ -370,6 +455,13 @@ func (s *DefaultServer) Initialize() (err error) { return fmt.Errorf("initialize: %w", err) } s.hostManager = hostManager + // On mobile-permanent setups the seeded host DNS list is the only + // source until the first network-map arrives; register it now so DNS + // works in that window. Desktop host managers register fallback when + // applyConfiguration runs. 
+ if s.permanent { + s.registerFallback() + } return nil } @@ -394,13 +486,7 @@ func (s *DefaultServer) SetFirewall(fw Firewall) { // Stop stops the server func (s *DefaultServer) Stop() { - s.probeMu.Lock() - if s.probeCancel != nil { - s.probeCancel() - } s.ctxCancel() - s.probeMu.Unlock() - s.probeWg.Wait() s.shutdownWg.Wait() s.mux.Lock() @@ -411,6 +497,13 @@ func (s *DefaultServer) Stop() { } clear(s.extraDomains) + + // Clear health projection state so a subsequent Start doesn't + // inherit sticky flags (notably everHealthy) that would bypass + // the grace window during the next peer handshake. + s.healthProjectMu.Lock() + s.nsGroupProj = nil + s.healthProjectMu.Unlock() } func (s *DefaultServer) disableDNS() (retErr error) { @@ -424,10 +517,9 @@ func (s *DefaultServer) disableDNS() (retErr error) { return nil } - // Deregister original nameservers if they were registered as fallback - if srvs, ok := s.hostManager.(hostManagerWithOriginalNS); ok && len(srvs.getOriginalNameservers()) > 0 { - log.Debugf("deregistering original nameservers as fallback handlers") - s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + if s.fallbackHandler != nil { + log.Debugf("deregistering fallback handlers") + s.clearFallback() } if err := s.hostManager.restoreHostDNS(); err != nil { @@ -441,27 +533,16 @@ func (s *DefaultServer) disableDNS() (retErr error) { return nil } -// OnUpdatedHostDNSServer update the DNS servers addresses for root zones -// It will be applied if the mgm server do not enforce DNS settings for root zone +// OnUpdatedHostDNSServer updates the fallback DNS upstreams. Called by Android +// outside the engine's sync mux when the OS reports a network change, so it +// takes s.mux to serialize against host manager swaps in Initialize/enableDNS. 
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) { s.hostsDNSHolder.set(hostsDnsList) - - // Check if there's any root handler - var hasRootHandler bool - for _, handler := range s.dnsMuxMap { - if handler.domain == nbdns.RootZone { - hasRootHandler = true - break - } - } - - if hasRootHandler { - log.Debugf("on new host DNS config but skip to apply it") - return - } - log.Debugf("update host DNS settings: %+v", hostsDnsList) - s.addHostRootZone() + + s.mux.Lock() + defer s.mux.Unlock() + s.registerFallback() } // UpdateDNSServer processes an update received from the management service @@ -520,69 +601,6 @@ func (s *DefaultServer) SearchDomains() []string { return searchDomains } -// ProbeAvailability tests each upstream group's servers for availability -// and deactivates the group if no server responds. -// If a previous probe is still running, it will be cancelled before starting a new one. -func (s *DefaultServer) ProbeAvailability() { - if val := os.Getenv(envSkipDNSProbe); val != "" { - skipProbe, err := strconv.ParseBool(val) - if err != nil { - log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err) - } - if skipProbe { - log.Infof("skipping DNS probe due to %s", envSkipDNSProbe) - return - } - } - - s.probeMu.Lock() - - // don't start probes on a stopped server - if s.ctx.Err() != nil { - s.probeMu.Unlock() - return - } - - // cancel any running probe - if s.probeCancel != nil { - s.probeCancel() - s.probeCancel = nil - } - - // wait for the previous probe goroutines to finish while holding - // the mutex so no other caller can start a new probe concurrently - s.probeWg.Wait() - - // start a new probe - probeCtx, probeCancel := context.WithCancel(s.ctx) - s.probeCancel = probeCancel - - s.probeWg.Add(1) - defer s.probeWg.Done() - - // Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers. 
- s.mux.Lock() - handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap)) - for _, mux := range s.dnsMuxMap { - handlers = append(handlers, mux.handler) - } - s.mux.Unlock() - - var wg sync.WaitGroup - for _, handler := range handlers { - wg.Add(1) - go func(h handlerWithStop) { - defer wg.Done() - h.ProbeAvailability(probeCtx) - }(handler) - } - - s.probeMu.Unlock() - - wg.Wait() - probeCancel() -} - func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { s.mux.Lock() defer s.mux.Unlock() @@ -746,19 +764,17 @@ func (s *DefaultServer) applyHostConfig() { s.currentConfigHash = hash } - s.registerFallback(config) + s.registerFallback() } // registerFallback registers original nameservers as low-priority fallback handlers. -func (s *DefaultServer) registerFallback(config HostDNSConfig) { - hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS) - if !ok { - return - } - - originalNameservers := hostMgrWithNS.getOriginalNameservers() +// Replaces and Stop()s the previously-registered fallback handler so its +// context is released rather than leaked until GC. 
+func (s *DefaultServer) registerFallback() { + originalNameservers := s.hostManager.getOriginalNameservers() if len(originalNameservers) == 0 { - s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler") + s.clearFallback() return } @@ -775,21 +791,28 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { log.Errorf("failed to create upstream resolver for original nameservers: %v", err) return } - handler.routeMatch = s.routeMatch + handler.selectedRoutes = s.selectedRoutes + var servers []netip.AddrPort for _, ns := range originalNameservers { - if ns == config.ServerIP { - log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) - continue - } - - addrPort := netip.AddrPortFrom(ns, DefaultPort) - handler.upstreamServers = append(handler.upstreamServers, addrPort) + servers = append(servers, netip.AddrPortFrom(ns, DefaultPort)) } - handler.deactivate = func(error) { /* always active */ } - handler.reactivate = func() { /* always active */ } + handler.addRace(servers) + prev := s.fallbackHandler + s.fallbackHandler = handler s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback) + if prev != nil { + prev.Stop() + } +} + +func (s *DefaultServer) clearFallback() { + s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + if s.fallbackHandler != nil { + s.fallbackHandler.Stop() + s.fallbackHandler = nil + } } func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) { @@ -847,100 +870,99 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam groupedNS := groupNSGroupsByDomain(nameServerGroups) for _, domainGroup := range groupedNS { - basePriority := PriorityUpstream + priority := PriorityUpstream if domainGroup.domain == nbdns.RootZone { - basePriority = PriorityDefault + priority = 
PriorityDefault } - updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority) + update, err := s.buildMergedDomainHandler(domainGroup, priority) if err != nil { + if errors.Is(err, errNoUsableNameservers) { + log.Errorf("no usable nameservers for domain=%s", domainGroup.domain) + continue + } return nil, err } - muxUpdates = append(muxUpdates, updates...) + muxUpdates = append(muxUpdates, *update) } return muxUpdates, nil } -func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) { - var muxUpdates []handlerWrapper +// buildMergedDomainHandler merges every nameserver group that targets the +// same domain into one handler whose inner groups are raced in parallel. +func (s *DefaultServer) buildMergedDomainHandler(domainGroup nsGroupsByDomain, priority int) (*handlerWrapper, error) { + handler, err := newUpstreamResolver( + s.ctx, + s.wgInterface, + s.statusRecorder, + s.hostsDNSHolder, + domain.Domain(domainGroup.domain), + ) + if err != nil { + return nil, fmt.Errorf("create upstream resolver: %v", err) + } + handler.selectedRoutes = s.selectedRoutes - for i, nsGroup := range domainGroup.groups { - // Decrement priority by handler index (0, 1, 2, ...) 
to avoid conflicts - priority := basePriority - i - - // Check if we're about to overlap with the next priority tier - if s.leaksPriority(domainGroup, basePriority, priority) { - break - } - - log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority) - handler, err := newUpstreamResolver( - s.ctx, - s.wgInterface, - s.statusRecorder, - s.hostsDNSHolder, - domainGroup.domain, - ) - if err != nil { - return nil, fmt.Errorf("create upstream resolver: %v", err) - } - handler.routeMatch = s.routeMatch - - for _, ns := range nsGroup.NameServers { - if ns.NSType != nbdns.UDPNameServerType { - log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", - ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) - continue - } - - if ns.IP == s.service.RuntimeIP() { - log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP) - continue - } - - handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort()) - } - - if len(handler.upstreamServers) == 0 { - handler.Stop() - log.Errorf("received a nameserver group with an invalid nameserver list") + for _, nsGroup := range domainGroup.groups { + servers := s.filterNameServers(nsGroup.NameServers) + if len(servers) == 0 { + log.Warnf("nameserver group for domain=%s yielded no usable servers, skipping", domainGroup.domain) continue } - - // when upstream fails to resolve domain several times over all it servers - // it will calls this hook to exclude self from the configuration and - // reapply DNS settings, but it not touch the original configuration and serial number - // because it is temporal deactivation until next try - // - // after some period defined by upstream it tries to reactivate self by calling this hook - // everything we need here is just to re-apply current configuration because it already - // contains this upstream settings (temporal deactivation not removed it) - handler.deactivate, 
handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority) - - muxUpdates = append(muxUpdates, handlerWrapper{ - domain: domainGroup.domain, - handler: handler, - priority: priority, - }) + handler.addRace(servers) } - return muxUpdates, nil + if len(handler.upstreamServers) == 0 { + handler.Stop() + return nil, errNoUsableNameservers + } + + log.Debugf("creating merged handler for domain=%s with %d group(s) priority=%d", domainGroup.domain, len(handler.upstreamServers), priority) + + return &handlerWrapper{ + domain: domainGroup.domain, + handler: handler, + priority: priority, + }, nil } -func (s *DefaultServer) leaksPriority(domainGroup nsGroupsByDomain, basePriority int, priority int) bool { - if basePriority == PriorityUpstream && priority <= PriorityDefault { - log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", - domainGroup.domain, PriorityUpstream-PriorityDefault) - return true - } - if basePriority == PriorityDefault && priority <= PriorityFallback { - log.Warnf("too many handlers for domain=%s, would overlap with fallback priority tier (diff=%d). Skipping remaining handlers", - domainGroup.domain, PriorityDefault-PriorityFallback) - return true +func (s *DefaultServer) filterNameServers(nameServers []nbdns.NameServer) []netip.AddrPort { + var out []netip.AddrPort + for _, ns := range nameServers { + if ns.NSType != nbdns.UDPNameServerType { + log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", + ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) + continue + } + if ns.IP == s.service.RuntimeIP() { + log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP) + continue + } + out = append(out, ns.AddrPort()) } + return out +} - return false +// usableNameServers returns the subset of nameServers the handler would +// actually query. 
Matches filterNameServers without the warning logs, so +// it's safe to call on every health-projection tick. +func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []netip.AddrPort { + var runtimeIP netip.Addr + if s.service != nil { + runtimeIP = s.service.RuntimeIP() + } + var out []netip.AddrPort + for _, ns := range nameServers { + if ns.NSType != nbdns.UDPNameServerType { + continue + } + if runtimeIP.IsValid() && ns.IP == runtimeIP { + continue + } + out = append(out, ns.AddrPort()) + } + return out } func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { @@ -951,175 +973,356 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { } muxUpdateMap := make(registeredHandlerMap) - var containsRootUpdate bool for _, update := range muxUpdates { - if update.domain == nbdns.RootZone { - containsRootUpdate = true - } s.registerHandler([]string{update.domain}, update.handler, update.priority) muxUpdateMap[update.handler.ID()] = update } - // If there's no root update and we had a root handler, restore it - if !containsRootUpdate { - for _, existing := range s.dnsMuxMap { - if existing.domain == nbdns.RootZone { - s.addHostRootZone() - break - } - } - } - s.dnsMuxMap = muxUpdateMap } -// upstreamCallbacks returns two functions, the first one is used to deactivate -// the upstream resolver from the configuration, the second one is used to -// reactivate it. Not allowed to call reactivate before deactivate. 
-func (s *DefaultServer) upstreamCallbacks( - nsGroup *nbdns.NameServerGroup, - handler dns.Handler, - priority int, -) (deactivate func(error), reactivate func()) { - var removeIndex map[string]int - deactivate = func(err error) { - s.mux.Lock() - defer s.mux.Unlock() - - l := log.WithField("nameservers", nsGroup.NameServers) - l.Info("Temporarily deactivating nameservers group due to timeout") - - removeIndex = make(map[string]int) - for _, domain := range nsGroup.Domains { - removeIndex[domain] = -1 - } - if nsGroup.Primary { - removeIndex[nbdns.RootZone] = -1 - s.currentConfig.RouteAll = false - s.deregisterHandler([]string{nbdns.RootZone}, priority) - } - - for i, item := range s.currentConfig.Domains { - if _, found := removeIndex[item.Domain]; found { - s.currentConfig.Domains[i].Disabled = true - s.deregisterHandler([]string{item.Domain}, priority) - removeIndex[item.Domain] = i - } - } - - // Always apply host config when nameserver goes down, regardless of batch mode - s.applyHostConfig() - - go func() { - if err := s.stateManager.PersistState(s.ctx); err != nil { - l.Errorf("Failed to persist dns state: %v", err) - } - }() - - if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { - s.addHostRootZone() - } - - s.updateNSState(nsGroup, err, false) - } - - reactivate = func() { - s.mux.Lock() - defer s.mux.Unlock() - - for domain, i := range removeIndex { - if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain { - continue - } - s.currentConfig.Domains[i].Disabled = false - s.registerHandler([]string{domain}, handler, priority) - } - - l := log.WithField("nameservers", nsGroup.NameServers) - l.Debug("reactivate temporary disabled nameserver group") - - if nsGroup.Primary { - s.currentConfig.RouteAll = true - s.registerHandler([]string{nbdns.RootZone}, handler, priority) - } - - // Always apply host config when nameserver reactivates, regardless of batch mode - s.applyHostConfig() - - 
s.updateNSState(nsGroup, nil, true) - } - return -} - -func (s *DefaultServer) addHostRootZone() { - hostDNSServers := s.hostsDNSHolder.get() - if len(hostDNSServers) == 0 { - log.Debug("no host DNS servers available, skipping root zone handler creation") - return - } - - handler, err := newUpstreamResolver( - s.ctx, - s.wgInterface, - s.statusRecorder, - s.hostsDNSHolder, - nbdns.RootZone, - ) - if err != nil { - log.Errorf("unable to create a new upstream resolver, error: %v", err) - return - } - handler.routeMatch = s.routeMatch - - handler.upstreamServers = maps.Keys(hostDNSServers) - handler.deactivate = func(error) {} - handler.reactivate = func() {} - - s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) -} - +// updateNSGroupStates records the new group set and pokes the refresher. +// Must hold s.mux; projection runs async (see refreshHealth for why). func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { - var states []peer.NSGroupState + s.nsGroups = groups + select { + case s.healthRefresh <- struct{}{}: + default: + } +} - for _, group := range groups { - var servers []netip.AddrPort - for _, ns := range group.NameServers { - servers = append(servers, ns.AddrPort()) +// refreshHealth runs one projection cycle. Must not be called while +// holding s.mux: the route callbacks re-enter routemanager's lock. +func (s *DefaultServer) refreshHealth() { + s.mux.Lock() + groups := s.nsGroups + merged := s.collectUpstreamHealth() + selFn := s.selectedRoutes + actFn := s.activeRoutes + s.mux.Unlock() + + var selected, active route.HAMap + if selFn != nil { + selected = selFn() + } + if actFn != nil { + active = actFn() + } + + s.projectNSGroupHealth(nsHealthSnapshot{ + groups: groups, + merged: merged, + selected: selected, + active: active, + }) +} + +// projectNSGroupHealth applies the emission rules to the snapshot and +// publishes the resulting NSGroupStates. Serialized by healthProjectMu, +// lock-free wrt s.mux. 
+// +// Rules: +// - Healthy: emit recovery iff warningActive; set everHealthy. +// - Unhealthy: stamp unhealthySince on streak start; emit warning +// iff any of immediate / everHealthy / elapsed >= effective delay. +// - Undecided: no-op. +// +// "Immediate" means the group has at least one upstream that's public +// or overlay+Connected: no peer-startup race to wait out. +func (s *DefaultServer) projectNSGroupHealth(snap nsHealthSnapshot) { + if s.statusRecorder == nil { + return + } + + s.healthProjectMu.Lock() + defer s.healthProjectMu.Unlock() + + if s.nsGroupProj == nil { + s.nsGroupProj = make(map[nsGroupID]*nsGroupProj) + } + + now := time.Now() + delay := s.warningDelay(haMapRouteCount(snap.selected)) + states := make([]peer.NSGroupState, 0, len(snap.groups)) + seen := make(map[nsGroupID]struct{}, len(snap.groups)) + for _, group := range snap.groups { + servers := s.usableNameServers(group.NameServers) + if len(servers) == 0 { + continue + } + verdict, groupErr := evaluateNSGroupHealth(snap.merged, servers, now) + id := generateGroupKey(group) + seen[id] = struct{}{} + + immediate := s.groupHasImmediateUpstream(servers, snap) + + p, known := s.nsGroupProj[id] + if !known { + p = &nsGroupProj{} + s.nsGroupProj[id] = p } - state := peer.NSGroupState{ - ID: generateGroupKey(group), + enabled := true + switch verdict { + case nsVerdictHealthy: + enabled = s.projectHealthy(p, servers) + case nsVerdictUnhealthy: + enabled = s.projectUnhealthy(p, servers, immediate, now, delay) + case nsVerdictUndecided: + // Stay Available until evidence says otherwise, unless a + // warning is already active for this group. Also clear any + // prior Unhealthy streak so a later Unhealthy verdict starts + // a fresh grace window rather than inheriting a stale one. 
+ p.unhealthySince = time.Time{} + enabled = !p.warningActive + groupErr = nil + } + + states = append(states, peer.NSGroupState{ + ID: string(id), Servers: servers, Domains: group.Domains, - // The probe will determine the state, default enabled - Enabled: true, - Error: nil, - } - states = append(states, state) + Enabled: enabled, + Error: groupErr, + }) } - s.statusRecorder.UpdateDNSStates(states) -} - -func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, enabled bool) { - states := s.statusRecorder.GetDNSStates() - id := generateGroupKey(nsGroup) - for i, state := range states { - if state.ID == id { - states[i].Enabled = enabled - states[i].Error = err - break + for id := range s.nsGroupProj { + if _, ok := seen[id]; !ok { + delete(s.nsGroupProj, id) } } s.statusRecorder.UpdateDNSStates(states) } -func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { - var servers []string +// projectHealthy records a healthy tick on p and publishes a recovery +// event iff a warning was active for the current streak. Returns the +// Enabled flag to record in NSGroupState. +func (s *DefaultServer) projectHealthy(p *nsGroupProj, servers []netip.AddrPort) bool { + p.everHealthy = true + p.unhealthySince = time.Time{} + if !p.warningActive { + return true + } + log.Debugf("DNS health: group [%s] recovered, emitting event", joinAddrPorts(servers)) + s.statusRecorder.PublishEvent( + proto.SystemEvent_INFO, + proto.SystemEvent_DNS, + "Nameserver group recovered", + "DNS servers are reachable again.", + map[string]string{"upstreams": joinAddrPorts(servers)}, + ) + p.warningActive = false + return true +} + +// projectUnhealthy records an unhealthy tick on p, publishes the +// warning when the emission rules fire, and returns the Enabled flag +// to record in NSGroupState. 
+func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPort, immediate bool, now time.Time, delay time.Duration) bool { + streakStart := p.unhealthySince.IsZero() + if streakStart { + p.unhealthySince = now + } + reason := unhealthyEmitReason(immediate, p.everHealthy, now.Sub(p.unhealthySince), delay) + switch { + case reason != "" && !p.warningActive: + log.Debugf("DNS health: group [%s] unreachable, emitting event (reason=%s)", joinAddrPorts(servers), reason) + s.statusRecorder.PublishEvent( + proto.SystemEvent_WARNING, + proto.SystemEvent_DNS, + "Nameserver group unreachable", + "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", + map[string]string{"upstreams": joinAddrPorts(servers)}, + ) + p.warningActive = true + case streakStart && reason == "": + // One line per streak, not per tick. + log.Debugf("DNS health: group [%s] unreachable but holding warning for up to %v (overlay-routed, no connected peer)", joinAddrPorts(servers), delay) + } + return false +} + +// warningDelay returns the grace window for the given selected-route +// count. Scales gently: +1s per 100 routes, capped by +// warningDelayBonusCap. Parallel handshakes mean handshake time grows +// much slower than route count, so linear scaling would overcorrect. +// +// TODO: revisit the scaling curve with real-world data — the current +// values are a reasonable starting point, not a measured fit. +func (s *DefaultServer) warningDelay(routeCount int) time.Duration { + bonus := time.Duration(routeCount/100) * time.Second + if bonus > warningDelayBonusCap { + bonus = warningDelayBonusCap + } + return s.warningDelayBase + bonus +} + +// groupHasImmediateUpstream reports whether the group has at least one +// upstream in a classification that bypasses the grace window: public +// (outside the overlay range and not routed), or overlay/routed with a +// Connected peer. 
+// +// TODO(ipv6): include the v6 overlay prefix once it's plumbed in. +func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap nsHealthSnapshot) bool { + var overlayV4 netip.Prefix + if s.wgInterface != nil { + overlayV4 = s.wgInterface.Address().Network + } + for _, srv := range servers { + addr := srv.Addr().Unmap() + overlay := overlayV4.IsValid() && overlayV4.Contains(addr) + selMatched, selDynamic := haMapContains(snap.selected, addr) + // Treat an unknown (dynamic selected route) as possibly routed: + // the upstream might reach through a dynamic route whose Network + // hasn't resolved yet, and classifying as public would bypass + // the startup grace window. + routed := selMatched || selDynamic + if !overlay && !routed { + return true + } + if actMatched, _ := haMapContains(snap.active, addr); actMatched { + return true + } + } + return false +} + +// collectUpstreamHealth merges health snapshots across handlers, keeping +// the most recent success and failure per upstream when an address appears +// in more than one handler. 
+func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth { + merged := make(map[netip.AddrPort]UpstreamHealth) + for _, entry := range s.dnsMuxMap { + reporter, ok := entry.handler.(upstreamHealthReporter) + if !ok { + continue + } + for addr, h := range reporter.UpstreamHealth() { + existing, have := merged[addr] + if !have { + merged[addr] = h + continue + } + if h.LastOk.After(existing.LastOk) { + existing.LastOk = h.LastOk + } + if h.LastFail.After(existing.LastFail) { + existing.LastFail = h.LastFail + existing.LastErr = h.LastErr + } + merged[addr] = existing + } + } + return merged +} + +func (s *DefaultServer) startHealthRefresher() { + s.shutdownWg.Add(1) + go func() { + defer s.shutdownWg.Done() + ticker := time.NewTicker(nsGroupHealthRefreshInterval) + defer ticker.Stop() + for { + select { + case <-s.ctx.Done(): + return + case <-ticker.C: + case <-s.healthRefresh: + } + s.refreshHealth() + } + }() +} + +// evaluateNSGroupHealth decides a group's verdict from query records +// alone. Per upstream, the most-recent-in-lookback observation wins. +// Group is Healthy if any upstream is fresh-working, Unhealthy if any +// is fresh-broken with no fresh-working sibling, Undecided otherwise. 
+func evaluateNSGroupHealth(merged map[netip.AddrPort]UpstreamHealth, servers []netip.AddrPort, now time.Time) (nsGroupVerdict, error) { + anyWorking := false + anyBroken := false + var mostRecentFail time.Time + var mostRecentErr string + + for _, srv := range servers { + h, ok := merged[srv] + if !ok { + continue + } + switch classifyUpstreamHealth(h, now) { + case upstreamFresh: + anyWorking = true + case upstreamBroken: + anyBroken = true + if h.LastFail.After(mostRecentFail) { + mostRecentFail = h.LastFail + mostRecentErr = h.LastErr + } + } + } + + if anyWorking { + return nsVerdictHealthy, nil + } + if anyBroken { + if mostRecentErr == "" { + return nsVerdictUnhealthy, nil + } + return nsVerdictUnhealthy, errors.New(mostRecentErr) + } + return nsVerdictUndecided, nil +} + +// upstreamClassification is the per-upstream verdict within healthLookback. +type upstreamClassification int + +const ( + upstreamStale upstreamClassification = iota + upstreamFresh + upstreamBroken +) + +// classifyUpstreamHealth compares the last ok and last fail timestamps +// against healthLookback and returns which one (if any) counts. Fresh +// wins when both are in-window and ok is newer; broken otherwise. 
+func classifyUpstreamHealth(h UpstreamHealth, now time.Time) upstreamClassification { + okRecent := !h.LastOk.IsZero() && now.Sub(h.LastOk) <= healthLookback + failRecent := !h.LastFail.IsZero() && now.Sub(h.LastFail) <= healthLookback + switch { + case okRecent && failRecent: + if h.LastOk.After(h.LastFail) { + return upstreamFresh + } + return upstreamBroken + case okRecent: + return upstreamFresh + case failRecent: + return upstreamBroken + } + return upstreamStale +} + +func joinAddrPorts(servers []netip.AddrPort) string { + parts := make([]string, 0, len(servers)) + for _, s := range servers { + parts = append(parts, s.String()) + } + return strings.Join(parts, ", ") +} + +// generateGroupKey returns a stable identity for an NS group so health +// state (everHealthy / warningActive) survives reorderings in the +// configured nameserver or domain lists. +func generateGroupKey(nsGroup *nbdns.NameServerGroup) nsGroupID { + servers := make([]string, 0, len(nsGroup.NameServers)) for _, ns := range nsGroup.NameServers { servers = append(servers, ns.AddrPort().String()) } - return fmt.Sprintf("%v_%v", servers, nsGroup.Domains) + slices.Sort(servers) + domains := slices.Clone(nsGroup.Domains) + slices.Sort(domains) + return nsGroupID(fmt.Sprintf("%v_%v", servers, domains)) } // groupNSGroupsByDomain groups nameserver groups by their match domains @@ -1161,6 +1364,21 @@ func toZone(d domain.Domain) domain.Domain { ) } +// unhealthyEmitReason returns the tag of the rule that fires the +// warning now, or "" if the group is still inside its grace window. 
+func unhealthyEmitReason(immediate, everHealthy bool, elapsed, delay time.Duration) string { + switch { + case immediate: + return "immediate" + case everHealthy: + return "ever-healthy" + case elapsed >= delay: + return "grace-elapsed" + default: + return "" + } +} + // PopulateManagementDomain populates the DNS cache with management domain func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error { if s.mgmtCacheResolver != nil { diff --git a/client/internal/dns/server_android.go b/client/internal/dns/server_android.go index 7ca12d69d..b2cb26f65 100644 --- a/client/internal/dns/server_android.go +++ b/client/internal/dns/server_android.go @@ -1,5 +1,5 @@ package dns func (s *DefaultServer) initialize() (manager hostManager, err error) { - return newHostManager() + return newHostManager(s.hostsDNSHolder) } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 1026a29fc..722c2abd7 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -6,7 +6,7 @@ import ( "net" "net/netip" "os" - "strings" + "runtime" "testing" "time" @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -31,8 +32,10 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -101,16 +104,17 @@ func init() { formatter.SetTextFormatter(log.StandardLogger()) } -func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { +func 
generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase { var srvs []netip.AddrPort for _, srv := range servers { srvs = append(srvs, srv.AddrPort()) } - return &upstreamResolverBase{ - domain: domain, - upstreamServers: srvs, - cancel: func() {}, + u := &upstreamResolverBase{ + domain: domain.Domain(d), + cancel: func() {}, } + u.addRace(srvs) + return u } func TestUpdateDNSServer(t *testing.T) { @@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) { } } -func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { - hostManager := &mockHostConfigurator{} - server := DefaultServer{ - ctx: context.Background(), - service: NewServiceViaMemory(&mocWGIface{}), - localResolver: local.NewResolver(), - handlerChain: NewHandlerChain(), - hostManager: hostManager, - currentConfig: HostDNSConfig{ - Domains: []DomainConfig{ - {false, "domain0", false}, - {false, "domain1", false}, - {false, "domain2", false}, - }, - }, - statusRecorder: peer.NewRecorder("mgm"), - } - - var domainsUpdate string - hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { - domains := []string{} - for _, item := range config.Domains { - if item.Disabled { - continue - } - domains = append(domains, item.Domain) - } - domainsUpdate = strings.Join(domains, ",") - return nil - } - - deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{ - Domains: []string{"domain1"}, - NameServers: []nbdns.NameServer{ - {IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53}, - }, - }, nil, 0) - - deactivate(nil) - expected := "domain0,domain2" - domains := []string{} - for _, item := range server.currentConfig.Domains { - if item.Disabled { - continue - } - domains = append(domains, item.Domain) - } - got := strings.Join(domains, ",") - if expected != got { - t.Errorf("expected domains list: %q, got %q", expected, got) - } - - reactivate() - expected = "domain0,domain1,domain2" - domains = 
[]string{} - for _, item := range server.currentConfig.Domains { - if item.Disabled { - continue - } - domains = append(domains, item.Domain) - } - got = strings.Join(domains, ",") - if expected != got { - t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate) - } -} - func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { + skipUnlessAndroid(t) wgIFace, err := createWgInterfaceWithBind(t) if err != nil { t.Fatal("failed to initialize wg interface") @@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { } func TestDNSPermanent_updateUpstream(t *testing.T) { + skipUnlessAndroid(t) wgIFace, err := createWgInterfaceWithBind(t) if err != nil { t.Fatal("failed to initialize wg interface") @@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) { } func TestDNSPermanent_matchOnly(t *testing.T) { + skipUnlessAndroid(t) wgIFace, err := createWgInterfaceWithBind(t) if err != nil { t.Fatal("failed to initialize wg interface") @@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) { } } +// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path, +// which only matches a real production setup on android (NewDefaultServerPermanentUpstream +// + androidHostManager). On non-android the desktop host manager replaces it +// during Initialize and the assertion stops making sense. Skipped here until we +// have an android CI runner. 
+func skipUnlessAndroid(t *testing.T) { + t.Helper() + if runtime.GOOS != "android" { + t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS") + } +} + func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { t.Helper() ov := os.Getenv("NB_WG_KERNEL_DISABLED") @@ -1065,7 +1017,6 @@ type mockHandler struct { func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} func (m *mockHandler) Stop() {} -func (m *mockHandler) ProbeAvailability(context.Context) {} func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) } type mockService struct{} @@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) { assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) } +// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple +// admin-defined nameserver groups targeting the same domain collapse into a +// single handler with each group preserved as a sequential inner list. +func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) { + wgInterface := &mocWGIface{} + service := NewServiceViaMemory(wgInterface) + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: wgInterface, + service: service, + localResolver: local.NewResolver(), + handlerChain: NewHandlerChain(), + hostManager: &noopHostConfigurator{}, + dnsMuxMap: make(registeredHandlerMap), + } + + groups := []*nbdns.NameServerGroup{ + { + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53}, + }, + Domains: []string{"example.com"}, + }, + { + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53}, + {IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53}, + }, + Domains: []string{"example.com"}, + }, + } + + muxUpdates, err := server.buildUpstreamHandlerUpdate(groups) + require.NoError(t, err) + require.Len(t, muxUpdates, 1, 
"same-domain groups should merge into one handler") + assert.Equal(t, "example.com", muxUpdates[0].domain) + assert.Equal(t, PriorityUpstream, muxUpdates[0].priority) + + handler := muxUpdates[0].handler.(*upstreamResolver) + require.Len(t, handler.upstreamServers, 2, "handler should have two groups") + assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0]) + assert.Equal(t, upstreamRace{ + netip.MustParseAddrPort("192.0.2.2:53"), + netip.MustParseAddrPort("192.0.2.3:53"), + }, handler.upstreamServers[1]) +} + +// TestEvaluateNSGroupHealth covers the records-only verdict. The gate +// (overlay route selected-but-no-active-peer) is intentionally NOT an +// input to the evaluator anymore: the verdict drives the Enabled flag, +// which must always reflect what we actually observed. Gate-aware event +// suppression is tested separately in the projection test. +// +// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail, +// stale Ok, Ok newer than Fail, Fail newer than Ok}. +// Group verdict: any fresh-working → Healthy; any fresh-broken with no +// fresh-working → Unhealthy; otherwise Undecided. 
+func TestEvaluateNSGroupHealth(t *testing.T) { + now := time.Now() + a := netip.MustParseAddrPort("192.0.2.1:53") + b := netip.MustParseAddrPort("192.0.2.2:53") + + recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)} + recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"} + staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)} + staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"} + okThenFail := UpstreamHealth{ + LastOk: now.Add(-10 * time.Second), + LastFail: now.Add(-1 * time.Second), + LastErr: "timeout", + } + failThenOk := UpstreamHealth{ + LastOk: now.Add(-1 * time.Second), + LastFail: now.Add(-10 * time.Second), + LastErr: "timeout", + } + + tests := []struct { + name string + health map[netip.AddrPort]UpstreamHealth + servers []netip.AddrPort + wantVerdict nsGroupVerdict + wantErrSubst string + }{ + { + name: "no record, undecided", + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUndecided, + }, + { + name: "fresh success, healthy", + health: map[netip.AddrPort]UpstreamHealth{a: recentOk}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictHealthy, + }, + { + name: "fresh failure, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{a: recentFail}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "timeout", + }, + { + name: "only stale success, undecided", + health: map[netip.AddrPort]UpstreamHealth{a: staleOk}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUndecided, + }, + { + name: "only stale failure, undecided", + health: map[netip.AddrPort]UpstreamHealth{a: staleFail}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUndecided, + }, + { + name: "both fresh, fail newer, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{a: okThenFail}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "timeout", + }, + { + name: "both fresh, ok newer, healthy", + health: 
map[netip.AddrPort]UpstreamHealth{a: failThenOk}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictHealthy, + }, + { + name: "two upstreams, one success wins", + health: map[netip.AddrPort]UpstreamHealth{ + a: recentFail, + b: recentOk, + }, + servers: []netip.AddrPort{a, b}, + wantVerdict: nsVerdictHealthy, + }, + { + name: "two upstreams, one fail one unseen, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{ + a: recentFail, + }, + servers: []netip.AddrPort{a, b}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "timeout", + }, + { + name: "two upstreams, all recent failures, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{ + a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"}, + b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"}, + }, + servers: []netip.AddrPort{a, b}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "SERVFAIL", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now) + assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch") + if tc.wantErrSubst != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrSubst) + } else { + assert.NoError(t, err) + } + }) + } +} + +// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed +// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates +// without spinning up real handlers. +type healthStubHandler struct { + health map[netip.AddrPort]UpstreamHealth +} + +func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} +func (h *healthStubHandler) Stop() {} +func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" } +func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth { + return h.health +} + +// TestProjection_SteadyStateIsSilent guards against duplicate events: +// while a group stays Unhealthy tick after tick, only the first +// Unhealthy transition may emit. 
Same for staying Healthy. +func TestProjection_SteadyStateIsSilent(t *testing.T) { + fx := newProjTestFixture(t) + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "first fail emits warning") + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.tick() + fx.expectNoEvent("staying unhealthy must not re-emit") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "recovery on transition") + + fx.tick() + fx.tick() + fx.expectNoEvent("staying healthy must not re-emit") +} + +// projTestFixture is the common setup for the projection tests: a +// single-upstream group whose route classification the test can flip by +// assigning to selected/active. Callers drive failures/successes by +// mutating stub.health and calling refreshHealth. +type projTestFixture struct { + t *testing.T + recorder *peer.Status + events <-chan *proto.SystemEvent + server *DefaultServer + stub *healthStubHandler + group *nbdns.NameServerGroup + srv netip.AddrPort + selected route.HAMap + active route.HAMap +} + +func newProjTestFixture(t *testing.T) *projTestFixture { + t.Helper() + recorder := peer.NewRecorder("mgm") + sub := recorder.SubscribeToEvents() + t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) }) + + srv := netip.MustParseAddrPort("100.64.0.1:53") + fx := &projTestFixture{ + t: t, + recorder: recorder, + events: sub.Events(), + stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}}, + srv: srv, + group: &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}}, + }, + } + fx.server = &DefaultServer{ + ctx: context.Background(), + wgInterface: &mocWGIface{}, + statusRecorder: recorder, + dnsMuxMap: make(registeredHandlerMap), + selectedRoutes: func() route.HAMap { return fx.selected }, + activeRoutes: func() 
route.HAMap { return fx.active }, + warningDelayBase: defaultWarningDelayBase, + } + fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream} + + fx.server.mux.Lock() + fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group}) + fx.server.mux.Unlock() + return fx +} + +func (f *projTestFixture) setHealth(h UpstreamHealth) { + f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h} +} + +func (f *projTestFixture) tick() []peer.NSGroupState { + f.server.refreshHealth() + return f.recorder.GetDNSStates() +} + +func (f *projTestFixture) expectNoEvent(why string) { + f.t.Helper() + select { + case evt := <-f.events: + f.t.Fatalf("unexpected event (%s): %+v", why, evt) + case <-time.After(100 * time.Millisecond): + } +} + +func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent { + f.t.Helper() + select { + case evt := <-f.events: + assert.Contains(f.t, evt.Message, substr, why) + return evt + case <-time.After(time.Second): + f.t.Fatalf("expected event (%s) with %q", why, substr) + return nil + } +} + +var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16") +var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}} + +// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream +// that is not inside any selected route (public DNS) fires the warning +// on the first Unhealthy tick, no grace period. +func TestProjection_PublicFailEmitsImmediately(t *testing.T) { + fx := newProjTestFixture(t) + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + states := fx.tick() + require.Len(t, states, 1) + assert.False(t, states[0].Enabled) + fx.expectEvent("unreachable", "public DNS failure") +} + +// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2: +// the upstream is inside a selected route AND the route has a Connected +// peer. Tunnel is up, failure is real, emit immediately. 
+func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) { + fx := newProjTestFixture(t) + fx.selected = overlayMapForTest + fx.active = overlayMapForTest + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + states := fx.tick() + require.Len(t, states, 1) + assert.False(t, states[0].Enabled) + fx.expectEvent("unreachable", "overlay + connected failure") +} + +// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the +// upstream is routed but no peer is Connected (Connecting/Idle/missing). +// First tick: Unhealthy display, no warning. After the grace window +// elapses with no recovery, the warning fires. +func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) { + grace := 50 * time.Millisecond + fx := newProjTestFixture(t) + fx.server.warningDelayBase = grace + fx.selected = overlayMapForTest + // active stays nil: routed but not connected. + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + states := fx.tick() + require.Len(t, states, 1) + assert.False(t, states[0].Enabled, "display must reflect failure even during grace window") + fx.expectNoEvent("first fail tick within grace window") + + time.Sleep(grace + 10*time.Millisecond) + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "warning after grace window") +} + +// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream +// whose address is inside the WireGuard overlay range but is not +// covered by any selected route (peer-to-peer DNS without an explicit +// route). Until a peer reports Connected for that address, startup +// failures must be held just like the routed case. 
+func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) { + recorder := peer.NewRecorder("mgm") + sub := recorder.SubscribeToEvents() + t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) }) + + overlayPeer := netip.MustParseAddrPort("100.66.100.5:53") + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: &mocWGIface{}, + statusRecorder: recorder, + dnsMuxMap: make(registeredHandlerMap), + selectedRoutes: func() route.HAMap { return nil }, + activeRoutes: func() route.HAMap { return nil }, + warningDelayBase: 50 * time.Millisecond, + } + group := &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}}, + } + stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{ + overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}, + }} + server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} + + server.mux.Lock() + server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) + server.mux.Unlock() + server.refreshHealth() + + select { + case evt := <-sub.Events(): + t.Fatalf("unexpected event during grace window: %+v", evt) + case <-time.After(100 * time.Millisecond): + } + + time.Sleep(60 * time.Millisecond) + stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}} + server.refreshHealth() + + select { + case evt := <-sub.Events(): + assert.Contains(t, evt.Message, "unreachable") + case <-time.After(time.Second): + t.Fatal("expected warning after grace window") + } +} + +// TestProjection_StopClearsHealthState verifies that Stop wipes the +// per-group projection state so a subsequent Start doesn't inherit +// sticky flags (notably everHealthy) that would bypass the grace +// window during the next peer handshake. 
+func TestProjection_StopClearsHealthState(t *testing.T) { + wgIface := &mocWGIface{} + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: wgIface, + service: NewServiceViaMemory(wgIface), + hostManager: &noopHostConfigurator{}, + extraDomains: map[domain.Domain]int{}, + dnsMuxMap: make(registeredHandlerMap), + statusRecorder: peer.NewRecorder("mgm"), + selectedRoutes: func() route.HAMap { return nil }, + activeRoutes: func() route.HAMap { return nil }, + warningDelayBase: defaultWarningDelayBase, + currentConfigHash: ^uint64(0), + } + server.ctx, server.ctxCancel = context.WithCancel(context.Background()) + + srv := netip.MustParseAddrPort("8.8.8.8:53") + group := &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}}, + } + stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}} + server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} + + server.mux.Lock() + server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) + server.mux.Unlock() + server.refreshHealth() + + server.healthProjectMu.Lock() + p, ok := server.nsGroupProj[generateGroupKey(group)] + server.healthProjectMu.Unlock() + require.True(t, ok, "projection state should exist after tick") + require.True(t, p.everHealthy, "tick with success must set everHealthy") + + server.Stop() + + server.healthProjectMu.Lock() + cleared := server.nsGroupProj == nil + server.healthProjectMu.Unlock() + assert.True(t, cleared, "Stop must clear nsGroupProj") +} + +// TestProjection_OverlayRecoversDuringGrace covers the happy path of +// rule 3: startup failures while the peer is handshaking, then the peer +// comes up and a query succeeds before the grace window elapses. No +// warning should ever have fired, and no recovery either. 
+func TestProjection_OverlayRecoversDuringGrace(t *testing.T) { + fx := newProjTestFixture(t) + fx.server.warningDelayBase = 200 * time.Millisecond + fx.selected = overlayMapForTest + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectNoEvent("fail within grace, warning suppressed") + + fx.active = overlayMapForTest + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + states := fx.tick() + require.Len(t, states, 1) + assert.True(t, states[0].Enabled) + fx.expectNoEvent("recovery without prior warning must not emit") +} + +// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the +// whole design leans on: recovery events only appear when a warning +// event was actually emitted for the current streak. A Healthy verdict +// without a prior warning is silent, so the user never sees "recovered" +// out of thin air. +func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) { + fx := newProjTestFixture(t) + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + states := fx.tick() + require.Len(t, states, 1) + assert.True(t, states[0].Enabled) + fx.expectNoEvent("first healthy tick should not recover anything") + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "public fail emits immediately") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "recovery follows real warning") + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "second cycle warning") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "second cycle recovery") +} + +// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group +// has ever been Healthy, subsequent failures skip the grace window even +// if classification says "routed + not connected". The system has +// proved it can work, so any new failure is real. 
+func TestProjection_EverHealthyOverridesDelay(t *testing.T) { + fx := newProjTestFixture(t) + // Large base so any emission must come from the everHealthy bypass, not elapsed time. + fx.server.warningDelayBase = time.Hour + fx.selected = overlayMapForTest + fx.active = overlayMapForTest + + // Establish "ever healthy". + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectNoEvent("first healthy tick") + + // Peer drops. Query fails. Routed + not connected → normally grace, + // but everHealthy flag bypasses it. + fx.active = nil + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "failure after ever-healthy must be immediate") +} + +// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff +// from the design discussion: once a group has been healthy, a brief +// reconnect that produces a failing tick will fire warning + recovery. +// This is by design: user-visible blips are accurate signal, not noise. +func TestProjection_ReconnectBlipEmitsPair(t *testing.T) { + fx := newProjTestFixture(t) + fx.selected = overlayMapForTest + fx.active = overlayMapForTest + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "blip warning") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "blip recovery") +} + +// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream +// rule: a group with at least one public upstream is in the "immediate" +// category regardless of the other upstreams' routing, because the +// public one has no peer-startup excuse. Prevents public-DNS failures +// from being hidden behind a routed sibling. 
+func TestProjection_MixedGroupEmitsImmediately(t *testing.T) { + recorder := peer.NewRecorder("mgm") + sub := recorder.SubscribeToEvents() + t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) }) + events := sub.Events() + + public := netip.MustParseAddrPort("8.8.8.8:53") + overlay := netip.MustParseAddrPort("100.64.0.1:53") + overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}} + + server := &DefaultServer{ + ctx: context.Background(), + statusRecorder: recorder, + dnsMuxMap: make(registeredHandlerMap), + selectedRoutes: func() route.HAMap { return overlayMap }, + activeRoutes: func() route.HAMap { return nil }, + warningDelayBase: time.Hour, + } + group := &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{ + {IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())}, + {IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())}, + }, + } + stub := &healthStubHandler{ + health: map[netip.AddrPort]UpstreamHealth{ + public: {LastFail: time.Now(), LastErr: "servfail"}, + overlay: {LastFail: time.Now(), LastErr: "timeout"}, + }, + } + server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} + + server.mux.Lock() + server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) + server.mux.Unlock() + server.refreshHealth() + + select { + case evt := <-events: + assert.Contains(t, evt.Message, "unreachable") + case <-time.After(time.Second): + t.Fatal("expected immediate warning because group contains a public upstream") + } +} + func TestDNSLoopPrevention(t *testing.T) { wgInterface := &mocWGIface{} service := NewServiceViaMemory(wgInterface) @@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) { if tt.expectedHandlers > 0 { handler := muxUpdates[0].handler.(*upstreamResolver) - assert.Len(t, handler.upstreamServers, len(tt.expectedServers)) + flat := handler.flatUpstreams() + 
assert.Len(t, flat, len(tt.expectedServers)) if tt.shouldFilterOwnIP { - for _, upstream := range handler.upstreamServers { + for _, upstream := range flat { assert.NotEqual(t, dnsServerIP, upstream.Addr()) } } for _, expected := range tt.expectedServers { found := false - for _, upstream := range handler.upstreamServers { + for _, upstream := range flat { if upstream.Addr() == expected { found = true break diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 573dff540..bd301e177 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/netip" + "slices" "time" "github.com/godbus/dbus/v5" @@ -40,10 +41,17 @@ const ( ) type systemdDbusConfigurator struct { - dbusLinkObject dbus.ObjectPath - ifaceName string + dbusLinkObject dbus.ObjectPath + ifaceName string + wgIndex int + origNameservers []netip.Addr } +const ( + systemdDbusLinkDNSProperty = systemdDbusLinkInterface + ".DNS" + systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute" +) + // the types below are based on dbus specification, each field is mapped to a dbus type // see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types // see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types @@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index) - return &systemdDbusConfigurator{ + c := &systemdDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), ifaceName: wgInterface, - }, nil + wgIndex: iface.Index, + } + + origNameservers, err := c.captureOriginalNameservers() + switch { + case err != nil: + log.Warnf("capture original nameservers from systemd-resolved: %v", err) + case len(origNameservers) == 0: + log.Warnf("no original 
nameservers captured from systemd-resolved default-route links; DNS fallback will be empty") + default: + log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers) + } + c.origNameservers = origNameservers + return c, nil +} + +// captureOriginalNameservers reads per-link DNS from systemd-resolved for +// every default-route link except our own WG link. Non-default-route links +// (VPNs, docker bridges) are skipped because their upstreams wouldn't +// actually serve host queries. +func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("list interfaces: %w", err) + } + + seen := make(map[netip.Addr]struct{}) + var out []netip.Addr + for _, iface := range ifaces { + if !s.isCandidateLink(iface) { + continue + } + linkPath, err := getSystemdLinkPath(iface.Index) + if err != nil || !isSystemdLinkDefaultRoute(linkPath) { + continue + } + for _, addr := range readSystemdLinkDNS(linkPath) { + addr = normalizeSystemdAddr(addr, iface.Name) + if !addr.IsValid() { + continue + } + if _, dup := seen[addr]; dup { + continue + } + seen[addr] = struct{}{} + out = append(out, addr) + } + } + return out, nil +} + +func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool { + if iface.Index == s.wgIndex { + return false + } + if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 { + return false + } + return true +} + +// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches +// the link's iface name as zone for link-local v6 (Link.DNS strips it). +// Returns the zero Addr to signal "skip this entry". 
+func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr { + addr = addr.Unmap() + if !addr.IsValid() || addr.IsUnspecified() { + return netip.Addr{} + } + if addr.IsLinkLocalUnicast() { + return addr.WithZone(ifaceName) + } + return addr +} + +func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) { + obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode) + if err != nil { + return "", fmt.Errorf("dbus resolve1: %w", err) + } + defer closeConn() + var p string + if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil { + return "", err + } + return dbus.ObjectPath(p), nil +} + +func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool { + obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath) + if err != nil { + return false + } + defer closeConn() + v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty) + if err != nil { + return false + } + b, ok := v.Value().(bool) + return ok && b +} + +func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr { + obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath) + if err != nil { + return nil + } + defer closeConn() + v, err := obj.GetProperty(systemdDbusLinkDNSProperty) + if err != nil { + return nil + } + entries, ok := v.Value().([][]any) + if !ok { + return nil + } + var out []netip.Addr + for _, entry := range entries { + if len(entry) < 2 { + continue + } + raw, ok := entry[1].([]byte) + if !ok { + continue + } + addr, ok := netip.AddrFromSlice(raw) + if !ok { + continue + } + out = append(out, addr) + } + return out +} + +func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr { + return slices.Clone(s.origNameservers) } func (s *systemdDbusConfigurator) supportCustomPort() bool { diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 39064f26c..a4f713d68 100644 --- a/client/internal/dns/upstream.go +++ 
b/client/internal/dns/upstream.go @@ -1,3 +1,32 @@ +// Package dns implements the client-side DNS stack: listener/service on the +// peer's tunnel address, handler chain that routes questions by domain and +// priority, and upstream resolvers that forward what remains to configured +// nameservers. +// +// # Upstream resolution and the race model +// +// When two or more nameserver groups target the same domain, DefaultServer +// merges them into one upstream handler whose state is: +// +// upstreamResolverBase +// └── upstreamServers []upstreamRace // one entry per source NS group +// └── []netip.AddrPort // primary, fallback, ... +// +// Each source nameserver group contributes one upstreamRace. Within a race +// upstreams are tried in order: the next is used only on failure (timeout, +// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops +// the walk. When more than one race exists, ServeDNS fans out one +// goroutine per race and returns the first valid answer, cancelling the +// rest. A handler with a single race skips the fan-out. +// +// # Health projection +// +// Query outcomes are recorded per-upstream in UpstreamHealth. The server +// periodically merges these snapshots across handlers and projects them +// into peer.NSGroupState. There is no active probing: a group is marked +// unhealthy only when every seen upstream has a recent failure and none +// has a recent success. Healthy→unhealthy fires a single +// SystemEvent_WARNING; steady-state refreshes do not duplicate it. 
package dns import ( @@ -11,11 +40,8 @@ import ( "slices" "strings" "sync" - "sync/atomic" "time" - "github.com/cenkalti/backoff/v4" - "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/tun/netstack" @@ -25,7 +51,8 @@ import ( "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) var currentMTU uint16 = iface.DefaultMTU @@ -67,15 +94,17 @@ const ( // Set longer than UpstreamTimeout to ensure context timeout takes precedence ClientTimeout = 5 * time.Second - reactivatePeriod = 30 * time.Second - probeTimeout = 2 * time.Second - // ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP // payload from the tunnel MTU. ipUDPHeaderSize = 60 + 8 -) -const testRecord = "com." + // raceMaxTotalTimeout caps the combined time spent walking all upstreams + // within one race, so a slow primary can't eat the whole race budget. + raceMaxTotalTimeout = 5 * time.Second + // raceMinPerUpstreamTimeout is the floor applied when dividing + // raceMaxTotalTimeout across upstreams within a race. + raceMinPerUpstreamTimeout = 2 * time.Second +) const ( protoUDP = "udp" @@ -84,6 +113,69 @@ const ( type dnsProtocolKey struct{} +type upstreamProtocolKey struct{} + +// upstreamProtocolResult holds the protocol used for the upstream exchange. +// Stored as a pointer in context so the exchange function can set it. 
+type upstreamProtocolResult struct { + protocol string +} + +type upstreamClient interface { + exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) +} + +type UpstreamResolver interface { + serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error) + upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) +} + +// upstreamRace is an ordered list of upstreams derived from one configured +// nameserver group. Order matters: the first upstream is tried first, the +// second only on failure, and so on. Multiple upstreamRace values coexist +// inside one resolver when overlapping nameserver groups target the same +// domain; those races run in parallel and the first valid answer wins. +type upstreamRace []netip.AddrPort + +// UpstreamHealth is the last query-path outcome for a single upstream, +// consumed by nameserver-group status projection. +type UpstreamHealth struct { + LastOk time.Time + LastFail time.Time + LastErr string +} + +type upstreamResolverBase struct { + ctx context.Context + cancel context.CancelFunc + upstreamClient upstreamClient + upstreamServers []upstreamRace + domain domain.Domain + upstreamTimeout time.Duration + + healthMu sync.RWMutex + health map[netip.AddrPort]*UpstreamHealth + + statusRecorder *peer.Status + // selectedRoutes returns the current set of client routes the admin + // has enabled. Called lazily from the query hot path when an upstream + // might need a tunnel-bound client (iOS) and from health projection. + selectedRoutes func() route.HAMap +} + +type upstreamFailure struct { + upstream netip.AddrPort + reason string +} + +type raceResult struct { + msg *dns.Msg + upstream netip.AddrPort + protocol string + ede string + failures []upstreamFailure +} + // contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context. 
func contextWithDNSProtocol(ctx context.Context, network string) context.Context { return context.WithValue(ctx, dnsProtocolKey{}, network) @@ -100,16 +192,8 @@ func dnsProtocolFromContext(ctx context.Context) string { return "" } -type upstreamProtocolKey struct{} - -// upstreamProtocolResult holds the protocol used for the upstream exchange. -// Stored as a pointer in context so the exchange function can set it. -type upstreamProtocolResult struct { - protocol string -} - -// contextWithupstreamProtocolResult stores a mutable result holder in the context. -func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) { +// contextWithUpstreamProtocolResult stores a mutable result holder in the context. +func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) { r := &upstreamProtocolResult{} return context.WithValue(ctx, upstreamProtocolKey{}, r), r } @@ -124,67 +208,37 @@ func setUpstreamProtocol(ctx context.Context, protocol string) { } } -type upstreamClient interface { - exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) -} - -type UpstreamResolver interface { - serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error) - upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) -} - -type upstreamResolverBase struct { - ctx context.Context - cancel context.CancelFunc - upstreamClient upstreamClient - upstreamServers []netip.AddrPort - domain string - disabled bool - successCount atomic.Int32 - mutex sync.Mutex - reactivatePeriod time.Duration - upstreamTimeout time.Duration - wg sync.WaitGroup - - deactivate func(error) - reactivate func() - statusRecorder *peer.Status - routeMatch func(netip.Addr) bool -} - -type upstreamFailure struct { - upstream netip.AddrPort - reason string -} - -func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase { +func 
newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase { ctx, cancel := context.WithCancel(ctx) return &upstreamResolverBase{ - ctx: ctx, - cancel: cancel, - domain: domain, - upstreamTimeout: UpstreamTimeout, - reactivatePeriod: reactivatePeriod, - statusRecorder: statusRecorder, + ctx: ctx, + cancel: cancel, + domain: d, + upstreamTimeout: UpstreamTimeout, + statusRecorder: statusRecorder, } } // String returns a string representation of the upstream resolver func (u *upstreamResolverBase) String() string { - return fmt.Sprintf("Upstream %s", u.upstreamServers) + return fmt.Sprintf("Upstream %s", u.flatUpstreams()) } -// ID returns the unique handler ID +// ID returns the unique handler ID. Race groupings and within-race +// ordering are both part of the identity: [[A,B]] and [[A],[B]] query +// the same servers but with different semantics (serial fallback vs +// parallel race), so their handlers must not collide. func (u *upstreamResolverBase) ID() types.HandlerID { - servers := slices.Clone(u.upstreamServers) - slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) }) - hash := sha256.New() - hash.Write([]byte(u.domain + ":")) - for _, s := range servers { - hash.Write([]byte(s.String())) - hash.Write([]byte("|")) + hash.Write([]byte(u.domain.PunycodeString() + ":")) + for _, race := range u.upstreamServers { + hash.Write([]byte("[")) + for _, s := range race { + hash.Write([]byte(s.String())) + hash.Write([]byte("|")) + } + hash.Write([]byte("]")) } return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) } @@ -194,13 +248,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool { } func (u *upstreamResolverBase) Stop() { - log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) + log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams()) u.cancel() +} - u.mutex.Lock() - u.wg.Wait() - u.mutex.Unlock() +// flatUpstreams is for 
logging and ID hashing only, not for dispatch. +func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort { + var out []netip.AddrPort + for _, g := range u.upstreamServers { + out = append(out, g...) + } + return out +} +// setSelectedRoutes swaps the accessor used to classify overlay-routed +// upstreams. Called when route sources are wired after the handler was +// built (permanent / iOS constructors). +func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) { + u.selectedRoutes = selected +} + +func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) { + if len(servers) == 0 { + return + } + u.upstreamServers = append(u.upstreamServers, slices.Clone(servers)) } // ServeDNS handles a DNS request @@ -242,82 +314,201 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { } func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) { - timeout := u.upstreamTimeout - if len(u.upstreamServers) > 1 { - maxTotal := 5 * time.Second - minPerUpstream := 2 * time.Second - scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers)) - if scaledTimeout > minPerUpstream { - timeout = scaledTimeout - } else { - timeout = minPerUpstream - } + groups := u.upstreamServers + switch len(groups) { + case 0: + return false, nil + case 1: + return u.tryOnlyRace(ctx, w, r, groups[0], logger) + default: + return u.raceAll(ctx, w, r, groups, logger) + } +} + +func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) { + res := u.tryRace(ctx, r, group) + if res.msg == nil { + return false, res.failures + } + if res.ede != "" { + resutil.SetMeta(w, "ede", res.ede) + } + u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger) + return true, res.failures +} + +// raceAll runs one worker per group in parallel, taking 
the first valid +// answer and cancelling the rest. +func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) { + raceCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Buffer sized to len(groups) so workers never block on send, even + // after the coordinator has returned. + results := make(chan raceResult, len(groups)) + for _, g := range groups { + // tryRace clones the request per attempt, so workers never share + // a *dns.Msg and concurrent EDNS0 mutations can't race. + go func(g upstreamRace) { + results <- u.tryRace(raceCtx, r, g) + }(g) } var failures []upstreamFailure - for _, upstream := range u.upstreamServers { - if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil { - failures = append(failures, *failure) - } else { - return true, failures + for range groups { + select { + case res := <-results: + failures = append(failures, res.failures...) + if res.msg != nil { + if res.ede != "" { + resutil.SetMeta(w, "ede", res.ede) + } + u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger) + return true, failures + } + case <-ctx.Done(): + return false, failures } } return false, failures } -// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream. 
-func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure { - var rm *dns.Msg - var t time.Duration - var err error +func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult { + timeout := u.upstreamTimeout + if len(group) > 1 { + // Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts + // still honor raceMinPerUpstreamTimeout as a floor for correctness + // on slow links, but the outer context ensures the combined walk + // cannot exceed the cap regardless of group size. + timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout) + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout) + defer cancel() + } + + var failures []upstreamFailure + for _, upstream := range group { + if ctx.Err() != nil { + return raceResult{failures: failures} + } + // Clone the request per attempt: the exchange path mutates EDNS0 + // options in-place, so reusing the same *dns.Msg across sequential + // upstreams would carry those mutations (e.g. a reduced UDP size) + // into the next attempt. + res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout) + if failure != nil { + failures = append(failures, *failure) + continue + } + res.failures = failures + return res + } + return raceResult{failures: failures} +} + +func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) { + ctx, cancel := context.WithTimeout(parentCtx, timeout) + defer cancel() + ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx) // Advertise EDNS0 so the upstream may include Extended DNS Errors // (RFC 8914) in failure responses; we use those to short-circuit // failover for definitive answers like DNSSEC validation failures. 
- // Operate on a copy so the inbound request is unchanged: a client that - // did not advertise EDNS0 must not see an OPT in the response. + // The caller already passed a per-attempt copy, so we can mutate r + // directly; hadEdns reflects the original client request's state and + // controls whether we strip the OPT from the response. hadEdns := r.IsEdns0() != nil - reqUp := r if !hadEdns { - reqUp = r.Copy() - reqUp.SetEdns0(upstreamUDPSize(), false) + r.SetEdns0(upstreamUDPSize(), false) } - var startTime time.Time - var upstreamProto *upstreamProtocolResult - func() { - ctx, cancel := context.WithTimeout(parentCtx, timeout) - defer cancel() - ctx, upstreamProto = contextWithupstreamProtocolResult(ctx) - startTime = time.Now() - rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp) - }() + startTime := time.Now() + rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r) if err != nil { - return u.handleUpstreamError(err, upstream, startTime) + // A parent cancellation (e.g., another race won and the coordinator + // cancelled the losers) is not an upstream failure. Check both the + // error chain and the parent context: a transport may surface the + // cancellation as a read/deadline error rather than context.Canceled. 
+ if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) { + return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"} + } + failure := u.handleUpstreamError(err, upstream, startTime) + u.markUpstreamFail(upstream, failure.reason) + return raceResult{}, failure } if rm == nil || !rm.Response { - return &upstreamFailure{upstream: upstream, reason: "no response"} + u.markUpstreamFail(upstream, "no response") + return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"} + } + + proto := "" + if upstreamProto != nil { + proto = upstreamProto.protocol } if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { if code, ok := nonRetryableEDE(rm); ok { - resutil.SetMeta(w, "ede", edeName(code)) if !hadEdns { stripOPT(rm) } - u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) - return nil + u.markUpstreamOk(upstream) + return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil } - return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]} + reason := dns.RcodeToString[rm.Rcode] + u.markUpstreamFail(upstream, reason) + return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason} } if !hadEdns { stripOPT(rm) } - u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) - return nil + + u.markUpstreamOk(upstream) + return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil +} + +// healthEntry returns the mutable health record for addr, lazily creating +// the map and the entry. Caller must hold u.healthMu. 
+func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth { + if u.health == nil { + u.health = make(map[netip.AddrPort]*UpstreamHealth) + } + h := u.health[addr] + if h == nil { + h = &UpstreamHealth{} + u.health[addr] = h + } + return h +} + +func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) { + u.healthMu.Lock() + defer u.healthMu.Unlock() + h := u.healthEntry(addr) + h.LastOk = time.Now() + h.LastFail = time.Time{} + h.LastErr = "" +} + +func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) { + u.healthMu.Lock() + defer u.healthMu.Unlock() + h := u.healthEntry(addr) + h.LastFail = time.Now() + h.LastErr = reason +} + +// UpstreamHealth returns a snapshot of per-upstream query outcomes. +func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth { + u.healthMu.RLock() + defer u.healthMu.RUnlock() + out := make(map[netip.AddrPort]UpstreamHealth, len(u.health)) + for k, v := range u.health { + out[k] = *v + } + return out } // upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams, @@ -358,12 +549,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add return &upstreamFailure{upstream: upstream, reason: reason} } -func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool { - u.successCount.Add(1) +func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string { + if u.statusRecorder == nil { + return "" + } + peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) + if peerInfo == nil { + return "" + } + + return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) +} + +func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) 
{ resutil.SetMeta(w, "upstream", upstream.String()) - if upstreamProto != nil && upstreamProto.protocol != "" { - resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol) + if proto != "" { + resutil.SetMeta(w, "upstream_protocol", proto) } // Clear Zero bit from external responses to prevent upstream servers from @@ -372,14 +574,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn if err := w.WriteMsg(rm); err != nil { logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) - return true } - - return true } func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) { - totalUpstreams := len(u.upstreamServers) + totalUpstreams := len(u.flatUpstreams()) failedCount := len(failures) failureSummary := formatFailures(failures) @@ -434,119 +633,6 @@ func edeName(code uint16) string { return fmt.Sprintf("EDE %d", code) } -// ProbeAvailability tests all upstream servers simultaneously and -// disables the resolver if none work -func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) { - u.mutex.Lock() - defer u.mutex.Unlock() - - // avoid probe if upstreams could resolve at least one query - if u.successCount.Load() > 0 { - return - } - - var success bool - var mu sync.Mutex - var wg sync.WaitGroup - - var errs *multierror.Error - for _, upstream := range u.upstreamServers { - wg.Add(1) - go func(upstream netip.AddrPort) { - defer wg.Done() - err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond) - if err != nil { - mu.Lock() - errs = multierror.Append(errs, err) - mu.Unlock() - log.Warnf("probing upstream nameserver %s: %s", upstream, err) - return - } - - mu.Lock() - success = true - mu.Unlock() - }(upstream) - } - - wg.Wait() - - select { - case <-ctx.Done(): - return - case <-u.ctx.Done(): - return - default: - } - - // didn't find a working upstream server, let's disable and try later - if !success { - 
u.disable(errs.ErrorOrNil()) - - if u.statusRecorder == nil { - return - } - - u.statusRecorder.PublishEvent( - proto.SystemEvent_WARNING, - proto.SystemEvent_DNS, - "All upstream servers failed (probe failed)", - "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", - map[string]string{"upstreams": u.upstreamServersString()}, - ) - } -} - -// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response -func (u *upstreamResolverBase) waitUntilResponse() { - exponentialBackOff := &backoff.ExponentialBackOff{ - InitialInterval: 500 * time.Millisecond, - RandomizationFactor: 0.5, - Multiplier: 1.1, - MaxInterval: u.reactivatePeriod, - MaxElapsedTime: 0, - Stop: backoff.Stop, - Clock: backoff.SystemClock, - } - - operation := func() error { - select { - case <-u.ctx.Done(): - return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString())) - default: - } - - for _, upstream := range u.upstreamServers { - if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil { - log.Tracef("upstream check for %s: %s", upstream, err) - } else { - // at least one upstream server is available, stop probing - return nil - } - } - - log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff()) - return fmt.Errorf("upstream check call error") - } - - err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx)) - if err != nil { - if errors.Is(err, context.Canceled) { - log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString()) - } else { - log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err) - } - return - } - - log.Infof("upstreams %s are responsive again. 
Adding them back to system", u.upstreamServersString()) - u.successCount.Add(1) - u.reactivate() - u.mutex.Lock() - u.disabled = false - u.mutex.Unlock() -} - // isTimeout returns true if the given error is a network timeout error. // // Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout @@ -558,45 +644,6 @@ func isTimeout(err error) bool { return false } -func (u *upstreamResolverBase) disable(err error) { - if u.disabled { - return - } - - log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod) - u.successCount.Store(0) - u.deactivate(err) - u.disabled = true - u.wg.Add(1) - go func() { - defer u.wg.Done() - u.waitUntilResponse() - }() -} - -func (u *upstreamResolverBase) upstreamServersString() string { - var servers []string - for _, server := range u.upstreamServers { - servers = append(servers, server.String()) - } - return strings.Join(servers, ", ") -} - -func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error { - mergedCtx, cancel := context.WithTimeout(baseCtx, timeout) - defer cancel() - - if externalCtx != nil { - stop2 := context.AfterFunc(externalCtx, cancel) - defer stop2() - } - - r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) - - _, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r) - return err -} - // clientUDPMaxSize returns the maximum UDP response size the client accepts. func clientUDPMaxSize(r *dns.Msg) int { if opt := r.IsEdns0(); opt != nil { @@ -608,13 +655,10 @@ func clientUDPMaxSize(r *dns.Msg) int { // ExchangeWithFallback exchanges a DNS message with the upstream server. // It first tries to use UDP, and if it is truncated, it falls back to TCP. // If the inbound request came over TCP (via context), it skips the UDP attempt. -// If the passed context is nil, this will use Exchange instead of ExchangeContext. 
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) { // If the request came in over TCP, go straight to TCP upstream. if dnsProtocolFromContext(ctx) == protoTCP { - tcpClient := *client - tcpClient.Net = protoTCP - rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream) + rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream) if err != nil { return nil, t, fmt.Errorf("with tcp: %w", err) } @@ -634,18 +678,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u opt.SetUDPSize(maxUDPPayload) } - var ( - rm *dns.Msg - t time.Duration - err error - ) - - if ctx == nil { - rm, t, err = client.Exchange(r, upstream) - } else { - rm, t, err = client.ExchangeContext(ctx, r, upstream) - } - + rm, t, err := client.ExchangeContext(ctx, r, upstream) if err != nil { return nil, t, fmt.Errorf("with udp: %w", err) } @@ -659,15 +692,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u // data than the client's buffer, we could truncate locally and skip // the TCP retry. - tcpClient := *client - tcpClient.Net = protoTCP - - if ctx == nil { - rm, t, err = tcpClient.Exchange(r, upstream) - } else { - rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream) - } - + rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream) if err != nil { return nil, t, fmt.Errorf("with tcp: %w", err) } @@ -681,6 +706,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u return rm, t, nil } +// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a +// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on +// the tunnel interface), it is converted to the equivalent *net.TCPAddr +// so net.Dialer doesn't reject the TCP dial with "mismatched local +// address type". 
+func toTCPClient(c *dns.Client) *dns.Client { + tcp := *c + tcp.Net = protoTCP + if tcp.Dialer == nil { + return &tcp + } + d := *tcp.Dialer + if ua, ok := d.LocalAddr.(*net.UDPAddr); ok { + d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone} + } + tcp.Dialer = &d + return &tcp +} + // ExchangeWithNetstack performs a DNS exchange using netstack for dialing. // This is needed when netstack is enabled to reach peer IPs through the tunnel. func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) { @@ -822,15 +866,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State { return bestMatch } -func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string { - if u.statusRecorder == nil { - return "" +// haMapRouteCount returns the total number of routes across all HA +// groups in the map. route.HAMap is keyed by HAUniqueID with slices of +// routes per key, so len(hm) is the number of HA groups, not routes. +func haMapRouteCount(hm route.HAMap) int { + total := 0 + for _, routes := range hm { + total += len(routes) } - - peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) - if peerInfo == nil { - return "" - } - - return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) + return total +} + +// haMapContains checks whether ip is covered by any concrete prefix in +// the HA map. haveDynamic is reported separately: dynamic (domain-based) +// routes carry a placeholder Network that can't be prefix-checked, so we +// can't know at this point whether ip is reached through one. Callers +// decide how to interpret the unknown: health projection treats it as +// "possibly routed" to avoid emitting false-positive warnings during +// startup, while iOS dial selection requires a concrete match before +// binding to the tunnel. 
+func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) { + for _, routes := range hm { + for _, r := range routes { + if r.IsDynamic() { + haveDynamic = true + continue + } + if r.Network.Contains(ip) { + return true, haveDynamic + } + } + } + return false, haveDynamic } diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index 988adb7d2..f7ab48b10 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" nbnet "github.com/netbirdio/netbird/client/net" + "github.com/netbirdio/netbird/shared/management/domain" ) type upstreamResolver struct { @@ -26,9 +27,9 @@ func newUpstreamResolver( _ WGIface, statusRecorder *peer.Status, hostsDNSHolder *hostsDNSHolder, - domain string, + d domain.Domain, ) (*upstreamResolver, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d) c := &upstreamResolver{ upstreamResolverBase: upstreamResolverBase, hostsDNSHolder: hostsDNSHolder, diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 910c3779e..dc841757b 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -12,6 +12,7 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/shared/management/domain" ) type upstreamResolver struct { @@ -24,9 +25,9 @@ func newUpstreamResolver( wgIface WGIface, statusRecorder *peer.Status, _ *hostsDNSHolder, - domain string, + d domain.Domain, ) (*upstreamResolver, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d) nonIOS := &upstreamResolver{ upstreamResolverBase: 
upstreamResolverBase, nsNet: wgIface.GetNet(), diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 0e04742a0..b989bf0f9 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -15,6 +15,7 @@ import ( "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/shared/management/domain" ) type upstreamResolverIOS struct { @@ -27,9 +28,9 @@ func newUpstreamResolver( wgIface WGIface, statusRecorder *peer.Status, _ *hostsDNSHolder, - domain string, + d domain.Domain, ) (*upstreamResolverIOS, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d) ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, @@ -62,9 +63,16 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * upstreamIP = upstreamIP.Unmap() } addr := u.wgIface.Address() + var routed bool + if u.selectedRoutes != nil { + // Only a concrete prefix match binds to the tunnel: dialing + // through a private client for an upstream we can't prove is + // routed would break public resolvers. 
+ routed, _ = haMapContains(u.selectedRoutes(), upstreamIP) + } needsPrivate := addr.Network.Contains(upstreamIP) || addr.IPv6Net.Contains(upstreamIP) || - (u.routeMatch != nil && u.routeMatch(upstreamIP)) + routed if needsPrivate { log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream) client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout) @@ -73,8 +81,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } } - // Cannot use client.ExchangeContext because it overwrites our Dialer - return ExchangeWithFallback(nil, client, r, upstream) + return ExchangeWithFallback(ctx, client, r, upstream) } // GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface. diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index d6aec05ca..8b3c589f1 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "strings" + "sync/atomic" "testing" "time" @@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())) } } - resolver.upstreamServers = servers + resolver.addRace(servers) resolver.upstreamTimeout = testCase.timeout if testCase.cancelCTX { cancel() @@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) { return "", nil } -type mockUpstreamResolver struct { - r *dns.Msg - rtt time.Duration - err error -} - -// exchange mock implementation of exchange from upstreamResolver -func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) { - return c.r, c.rtt, c.err -} - type mockUpstreamResponse struct { - msg *dns.Msg - err error + msg *dns.Msg + err error + delay time.Duration } type mockUpstreamResolverPerServer struct { @@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer 
struct { rtt time.Duration } -func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) { - if r, ok := c.responses[upstream]; ok { - return r.msg, c.rtt, r.err +func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) { + r, ok := c.responses[upstream] + if !ok { + return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream) } - return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream) -} - -func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { - mockClient := &mockUpstreamResolver{ - err: dns.ErrTime, - r: new(dns.Msg), - rtt: time.Millisecond, - } - - resolver := &upstreamResolverBase{ - ctx: context.TODO(), - upstreamClient: mockClient, - upstreamTimeout: UpstreamTimeout, - reactivatePeriod: time.Microsecond * 100, - } - addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection - resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} - - failed := false - resolver.deactivate = func(error) { - failed = true - // After deactivation, make the mock client work again - mockClient.err = nil - } - - reactivated := false - resolver.reactivate = func() { - reactivated = true - } - - resolver.ProbeAvailability(context.TODO()) - - if !failed { - t.Errorf("expected that resolving was deactivated") - return - } - - if !resolver.disabled { - t.Errorf("resolver should be Disabled") - return - } - - time.Sleep(time.Millisecond * 200) - - if !reactivated { - t.Errorf("expected that resolving was reactivated") - return - } - - if resolver.disabled { - t.Errorf("should be enabled") + if r.delay > 0 { + select { + case <-time.After(r.delay): + case <-ctx.Done(): + return nil, c.rtt, ctx.Err() + } } + return r.msg, c.rtt, r.err } func TestUpstreamResolver_Failover(t *testing.T) { @@ -339,9 +286,9 @@ func 
TestUpstreamResolver_Failover(t *testing.T) { resolver := &upstreamResolverBase{ ctx: ctx, upstreamClient: trackingClient, - upstreamServers: []netip.AddrPort{upstream1, upstream2}, upstreamTimeout: UpstreamTimeout, } + resolver.addRace([]netip.AddrPort{upstream1, upstream2}) var responseMSG *dns.Msg responseWriter := &test.MockResponseWriter{ @@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) { resolver := &upstreamResolverBase{ ctx: ctx, upstreamClient: mockClient, - upstreamServers: []netip.AddrPort{upstream}, upstreamTimeout: UpstreamTimeout, } + resolver.addRace([]netip.AddrPort{upstream}) var responseMSG *dns.Msg responseWriter := &test.MockResponseWriter{ @@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) { assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL") } +// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups +// configured for the same domain, with one broken group. The merge+race +// path should answer as fast as the working group and not pay the timeout +// of the broken one on every query. +func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) { + broken := netip.MustParseAddrPort("192.0.2.1:53") + working := netip.MustParseAddrPort("192.0.2.2:53") + successAnswer := "192.0.2.100" + timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")} + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + // Force the broken upstream to only unblock via timeout / + // cancellation so the assertion below can't pass if races + // were run serially. 
+		broken.String():  {err: timeoutErr, delay: 500 * time.Millisecond},
+			working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
+		},
+		rtt: time.Millisecond,
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	resolver := &upstreamResolverBase{
+		ctx:             ctx,
+		upstreamClient:  mockClient,
+		upstreamTimeout: 250 * time.Millisecond,
+	}
+	resolver.addRace([]netip.AddrPort{broken})
+	resolver.addRace([]netip.AddrPort{working})
+
+	var responseMSG *dns.Msg
+	responseWriter := &test.MockResponseWriter{
+		WriteMsgFunc: func(m *dns.Msg) error {
+			responseMSG = m
+			return nil
+		},
+	}
+
+	inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
+	start := time.Now()
+	resolver.ServeDNS(responseWriter, inputMSG)
+	elapsed := time.Since(start)
+
+	require.NotNil(t, responseMSG, "should write a response")
+	assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
+	require.NotEmpty(t, responseMSG.Answer)
+	assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
+	// Working group answers in a single RTT; the broken group's
+	// timeout (250ms) must not block the response.
+	assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout")
+}
+
+// TestUpstreamResolver_AllGroupsFail checks that when every group fails the
+// resolver returns SERVFAIL rather than leaking a partial response. 
+func TestUpstreamResolver_AllGroupsFail(t *testing.T) { + a := netip.MustParseAddrPort("192.0.2.1:53") + b := netip.MustParseAddrPort("192.0.2.2:53") + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + }, + rtt: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamTimeout: UpstreamTimeout, + } + resolver.addRace([]netip.AddrPort{a}) + resolver.addRace([]netip.AddrPort{b}) + + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA)) + require.NotNil(t, responseMSG) + assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode) +} + +// TestUpstreamResolver_HealthTracking verifies that query-path results are +// recorded into per-upstream health, which is what projects back to +// NSGroupState for status reporting. 
+func TestUpstreamResolver_HealthTracking(t *testing.T) { + ok := netip.MustParseAddrPort("192.0.2.10:53") + bad := netip.MustParseAddrPort("192.0.2.11:53") + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")}, + bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + }, + rtt: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamTimeout: UpstreamTimeout, + } + resolver.addRace([]netip.AddrPort{ok, bad}) + + responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }} + resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA)) + + health := resolver.UpstreamHealth() + require.Contains(t, health, ok) + assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set") + assert.Empty(t, health[ok].LastErr) + + // bad upstream was never tried because ok answered first; its health + // should remain unset. + assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers") +} + func TestFormatFailures(t *testing.T) { testCases := []struct { name string @@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) { // Verify that a client EDNS0 larger than our MTU-derived limit gets // capped in the outgoing request so the upstream doesn't send a // response larger than our read buffer. 
- var receivedUDPSize uint16 + var receivedUDPSize atomic.Uint32 udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { if opt := r.IsEdns0(); opt != nil { - receivedUDPSize = opt.UDPSize() + receivedUDPSize.Store(uint32(opt.UDPSize())) } m := new(dns.Msg) m.SetReply(r) @@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) { require.NotNil(t, rm) expectedMax := uint16(currentMTU - ipUDPHeaderSize) - assert.Equal(t, expectedMax, receivedUDPSize, + assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()), "upstream should see capped EDNS0, not the client's 4096") } @@ -874,7 +951,7 @@ func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) { resolver := &upstreamResolverBase{ ctx: ctx, upstreamClient: tracking, - upstreamServers: []netip.AddrPort{upstream1, upstream2}, + upstreamServers: []upstreamRace{{upstream1, upstream2}}, upstreamTimeout: UpstreamTimeout, } diff --git a/client/internal/engine.go b/client/internal/engine.go index 66fe6056b..3bd0d4621 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -512,16 +512,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) - e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool { - for _, routes := range e.routeManager.GetSelectedClientRoutes() { - for _, r := range routes { - if r.Network.Contains(ip) { - return true - } - } - } - return false - }) + e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes) if err = e.wgInterfaceCreate(); err != nil { log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) @@ -1386,9 +1377,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.networkSerial = serial - // Test received (upstream) servers for availability right away instead of upon usage. 
- // If no server of a server group responds this will disable the respective handler and retry later. - go e.dnsServer.ProbeAvailability() return nil } @@ -1932,7 +1920,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { return dnsServer, nil case "ios": - dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS) + dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS) return dnsServer, nil default: diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 907f1f592..839ec14c0 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -53,6 +53,7 @@ type Manager interface { GetRouteSelector() *routeselector.RouteSelector GetClientRoutes() route.HAMap GetSelectedClientRoutes() route.HAMap + GetActiveClientRoutes() route.HAMap GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -485,6 +486,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap { return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes)) } +// GetActiveClientRoutes returns the subset of selected client routes +// that are currently reachable: the route's peer is Connected and is +// the one actively carrying the route (not just an HA sibling). 
+func (m *DefaultManager) GetActiveClientRoutes() route.HAMap { + m.mux.Lock() + selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes)) + recorder := m.statusRecorder + m.mux.Unlock() + + if recorder == nil { + return selected + } + + out := make(route.HAMap, len(selected)) + for id, routes := range selected { + for _, r := range routes { + st, err := recorder.GetPeer(r.Peer) + if err != nil { + continue + } + if st.ConnStatus != peer.StatusConnected { + continue + } + if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute { + continue + } + out[id] = routes + break + } + } + return out +} + // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { m.mux.Lock() diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 66b5e30dd..937314995 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -19,6 +19,7 @@ type MockManager struct { GetRouteSelectorFunc func() *routeselector.RouteSelector GetClientRoutesFunc func() route.HAMap GetSelectedClientRoutesFunc func() route.HAMap + GetActiveClientRoutesFunc func() route.HAMap GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route StopFunc func(manager *statemanager.Manager) } @@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap { return nil } +// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface +func (m *MockManager) GetActiveClientRoutes() route.HAMap { + if m.GetActiveClientRoutesFunc != nil { + return m.GetActiveClientRoutesFunc() + } + return nil +} + // GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { if m.GetClientRoutesWithNetIDFunc != nil { diff --git 
a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 33f5ab1b0..bafbb0031 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -162,11 +162,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { cfg.WgIface = interfaceName c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - hostDNS := []netip.AddrPort{ - netip.MustParseAddrPort("9.9.9.9:53"), - netip.MustParseAddrPort("149.112.112.112:53"), - } - return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile) + return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } // Stop the internal client and free the resources