mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-19 23:29:56 +00:00
[proxy] auth token generation on mapping (#6157)
* [management / proxy] auth token generation on mapping * fix tests
This commit is contained in:
@@ -394,6 +394,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
|||||||
if end > len(mappings) {
|
if end > len(mappings) {
|
||||||
end = len(mappings)
|
end = len(mappings)
|
||||||
}
|
}
|
||||||
|
for _, m := range mappings[i:end] {
|
||||||
|
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
|
||||||
|
}
|
||||||
|
m.AuthToken = token
|
||||||
|
}
|
||||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||||
Mapping: mappings[i:end],
|
Mapping: mappings[i:end],
|
||||||
InitialSyncComplete: end == len(mappings),
|
InitialSyncComplete: end == len(mappings),
|
||||||
@@ -425,18 +432,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
|||||||
return nil, fmt.Errorf("get services from store: %w", err)
|
return nil, fmt.Errorf("get services from store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oidcCfg := s.GetOIDCValidationConfig()
|
||||||
var mappings []*proto.ProxyMapping
|
var mappings []*proto.ProxyMapping
|
||||||
for _, service := range services {
|
for _, service := range services {
|
||||||
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
|
m := service.ToProtoMapping(rpservice.Create, "", oidcCfg)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
|
||||||
if !proxyAcceptsMapping(conn, m) {
|
if !proxyAcceptsMapping(conn, m) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -172,3 +173,55 @@ func TestSendSnapshot_EmptySnapshot(t *testing.T) {
|
|||||||
assert.Empty(t, stream.messages[0].Mapping)
|
assert.Empty(t, stream.messages[0].Mapping)
|
||||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type hookingStream struct {
|
||||||
|
grpc.ServerStream
|
||||||
|
onSend func(*proto.GetMappingUpdateResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *hookingStream) Send(m *proto.GetMappingUpdateResponse) error {
|
||||||
|
if s.onSend != nil {
|
||||||
|
s.onSend(m)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *hookingStream) Context() context.Context { return context.Background() }
|
||||||
|
func (s *hookingStream) SetHeader(metadata.MD) error { return nil }
|
||||||
|
func (s *hookingStream) SendHeader(metadata.MD) error { return nil }
|
||||||
|
func (s *hookingStream) SetTrailer(metadata.MD) {}
|
||||||
|
func (s *hookingStream) SendMsg(any) error { return nil }
|
||||||
|
func (s *hookingStream) RecvMsg(any) error { return nil }
|
||||||
|
|
||||||
|
func TestSendSnapshot_TokensRemainValidUnderSlowSend(t *testing.T) {
|
||||||
|
const cluster = "cluster.example.com"
|
||||||
|
const batchSize = 2
|
||||||
|
const totalServices = 6
|
||||||
|
const ttl = 100 * time.Millisecond
|
||||||
|
const sendDelay = 200 * time.Millisecond
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mgr := rpservice.NewMockManager(ctrl)
|
||||||
|
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||||
|
|
||||||
|
s := newSnapshotTestServer(t, batchSize)
|
||||||
|
s.serviceManager = mgr
|
||||||
|
s.tokenTTL = ttl
|
||||||
|
|
||||||
|
var validateErrs []error
|
||||||
|
stream := &hookingStream{
|
||||||
|
onSend: func(resp *proto.GetMappingUpdateResponse) {
|
||||||
|
for _, m := range resp.Mapping {
|
||||||
|
if err := s.tokenStore.ValidateAndConsume(m.AuthToken, m.AccountId, m.Id); err != nil {
|
||||||
|
validateErrs = append(validateErrs, fmt.Errorf("svc %s: %w", m.Id, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(sendDelay)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||||
|
|
||||||
|
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||||
|
require.Empty(t, validateErrs,
|
||||||
|
"tokens must remain valid even when batches are sent slowly: lazy per-batch generation guarantees freshness")
|
||||||
|
}
|
||||||
|
|||||||
@@ -326,17 +326,25 @@ func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context,
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type testValidateSessionProxyManager struct{}
|
type testValidateSessionProxyManager struct{}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error {
|
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _, _ string, _ *string, _ *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _, _ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
|
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ *proxy.Proxy) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
|
func (m *testValidateSessionProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user