From e916f12cca508dfea584e7b72cf99a135acebc2b Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 15 May 2026 19:13:44 +0200 Subject: [PATCH] [proxy] auth token generation on mapping (#6157) * [management / proxy] auth token generation on mapping * fix tests --- management/internals/shared/grpc/proxy.go | 15 +++--- .../shared/grpc/proxy_snapshot_test.go | 53 +++++++++++++++++++ .../shared/grpc/validate_session_test.go | 14 +++-- 3 files changed, 73 insertions(+), 9 deletions(-) diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 9e5027547..eada2d86a 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -394,6 +394,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec if end > len(mappings) { end = len(mappings) } + for _, m := range mappings[i:end] { + token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL()) + if err != nil { + return fmt.Errorf("generate auth token for service %s: %w", m.Id, err) + } + m.AuthToken = token + } if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ Mapping: mappings[i:end], InitialSyncComplete: end == len(mappings), @@ -425,18 +432,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * return nil, fmt.Errorf("get services from store: %w", err) } + oidcCfg := s.GetOIDCValidationConfig() var mappings []*proto.ProxyMapping for _, service := range services { if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address { continue } - token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL()) - if err != nil { - return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err) - } - - m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) + m := service.ToProtoMapping(rpservice.Create, "", oidcCfg) if !proxyAcceptsMapping(conn, m) { continue } diff --git a/management/internals/shared/grpc/proxy_snapshot_test.go b/management/internals/shared/grpc/proxy_snapshot_test.go index e0c7425c5..68d2ecfd1 100644 --- a/management/internals/shared/grpc/proxy_snapshot_test.go +++ b/management/internals/shared/grpc/proxy_snapshot_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -172,3 +173,55 @@ func TestSendSnapshot_EmptySnapshot(t *testing.T) { assert.Empty(t, stream.messages[0].Mapping) assert.True(t, stream.messages[0].InitialSyncComplete) } + +type hookingStream struct { + grpc.ServerStream + onSend func(*proto.GetMappingUpdateResponse) +} + +func (s *hookingStream) Send(m *proto.GetMappingUpdateResponse) error { + if s.onSend != nil { + s.onSend(m) + } + return nil +} + +func (s *hookingStream) Context() context.Context { return context.Background() } +func (s *hookingStream) SetHeader(metadata.MD) error { return nil } +func (s *hookingStream) SendHeader(metadata.MD) error { return nil } +func (s *hookingStream) SetTrailer(metadata.MD) {} +func (s *hookingStream) SendMsg(any) error { return nil } +func (s *hookingStream) RecvMsg(any) error { return nil } + +func TestSendSnapshot_TokensRemainValidUnderSlowSend(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 2 + const totalServices = 6 + const ttl = 100 * time.Millisecond + const sendDelay = 200 * time.Millisecond + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + s.tokenTTL = ttl + + var validateErrs []error + stream := &hookingStream{ + onSend: func(resp *proto.GetMappingUpdateResponse) { + for _, m := range resp.Mapping { + if err := s.tokenStore.ValidateAndConsume(m.AuthToken, m.AccountId, m.Id); err != nil { + validateErrs = append(validateErrs, fmt.Errorf("svc %s: %w", m.Id, err)) + } + } + time.Sleep(sendDelay) + }, + } + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Empty(t, validateErrs, + "tokens must remain valid even when batches are sent slowly: lazy per-batch generation guarantees freshness") +} diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 6cd95f988..7b7ffcfb2 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -326,17 +326,25 @@ func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, return nil, nil } +func (m *testValidateSessionServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error { + return nil +} + type testValidateSessionProxyManager struct{} -func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error { +func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _, _ string, _ *string, _ *proxy.Capabilities) (*proxy.Proxy, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error { +func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ *proxy.Proxy) error { return nil } -func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error { return nil }