mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
327 lines
9.4 KiB
Go
327 lines
9.4 KiB
Go
package grpc
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
|
"github.com/netbirdio/netbird/shared/management/proto"
|
|
)
|
|
|
|
// testProxyController is an in-memory proxy controller installed on
// ProxyServiceServer by the tests in this file (via SetProxyController). It only
// tracks which proxy IDs are registered under which cluster address.
type testProxyController struct {
	// mu guards clusterProxies.
	mu sync.Mutex
	// clusterProxies maps a cluster address to the set of proxy IDs registered to it.
	clusterProxies map[string]map[string]struct{}
}

// newTestProxyController returns a controller whose registration map is
// allocated and empty, so it can be used without further setup.
func newTestProxyController() *testProxyController {
	ctrl := new(testProxyController)
	ctrl.clusterProxies = map[string]map[string]struct{}{}
	return ctrl
}
|
|
|
|
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
|
|
}
|
|
|
|
func (c *testProxyController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
|
|
return proxy.OIDCValidationConfig{}
|
|
}
|
|
|
|
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, clusterAddr, proxyID string) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if _, ok := c.clusterProxies[clusterAddr]; !ok {
|
|
c.clusterProxies[clusterAddr] = make(map[string]struct{})
|
|
}
|
|
c.clusterProxies[clusterAddr][proxyID] = struct{}{}
|
|
return nil
|
|
}
|
|
|
|
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clusterAddr, proxyID string) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if proxies, ok := c.clusterProxies[clusterAddr]; ok {
|
|
delete(proxies, proxyID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
proxies, ok := c.clusterProxies[clusterAddr]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
result := make([]string, 0, len(proxies))
|
|
for id := range proxies {
|
|
result = append(result, id)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
|
// and returns the channel where messages will be received.
|
|
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
|
|
ch := make(chan *proto.GetMappingUpdateResponse, 10)
|
|
conn := &proxyConnection{
|
|
proxyID: proxyID,
|
|
address: clusterAddr,
|
|
sendChan: ch,
|
|
}
|
|
s.connectedProxies.Store(proxyID, conn)
|
|
|
|
_ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID)
|
|
|
|
return ch
|
|
}
|
|
|
|
func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
|
|
select {
|
|
case msg := <-ch:
|
|
return msg
|
|
case <-time.After(time.Second):
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
|
ctx := context.Background()
|
|
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
|
require.NoError(t, err)
|
|
|
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
|
require.NoError(t, err)
|
|
|
|
s := &ProxyServiceServer{
|
|
tokenStore: tokenStore,
|
|
pkceVerifierStore: pkceStore,
|
|
}
|
|
s.SetProxyController(newTestProxyController())
|
|
|
|
const cluster = "proxy.example.com"
|
|
const numProxies = 3
|
|
|
|
channels := make([]chan *proto.GetMappingUpdateResponse, numProxies)
|
|
for i := range numProxies {
|
|
id := "proxy-" + string(rune('a'+i))
|
|
channels[i] = registerFakeProxy(s, id, cluster)
|
|
}
|
|
|
|
mapping := &proto.ProxyMapping{
|
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
|
Id: "service-1",
|
|
AccountId: "account-1",
|
|
Domain: "test.example.com",
|
|
Path: []*proto.PathMapping{
|
|
{Path: "/", Target: "http://10.0.0.1:8080/"},
|
|
},
|
|
}
|
|
|
|
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
|
|
|
|
tokens := make([]string, numProxies)
|
|
for i, ch := range channels {
|
|
resp := drainChannel(ch)
|
|
require.NotNil(t, resp, "proxy %d should receive a message", i)
|
|
require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i)
|
|
msg := resp.Mapping[0]
|
|
assert.Equal(t, mapping.Domain, msg.Domain)
|
|
assert.Equal(t, mapping.Id, msg.Id)
|
|
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
|
tokens[i] = msg.AuthToken
|
|
}
|
|
|
|
// All tokens must be unique
|
|
tokenSet := make(map[string]struct{})
|
|
for i, tok := range tokens {
|
|
_, exists := tokenSet[tok]
|
|
assert.False(t, exists, "proxy %d got duplicate token", i)
|
|
tokenSet[tok] = struct{}{}
|
|
}
|
|
|
|
// Each token must be independently consumable
|
|
for i, tok := range tokens {
|
|
err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
|
|
assert.NoError(t, err, "proxy %d token should validate successfully", i)
|
|
}
|
|
}
|
|
|
|
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
|
ctx := context.Background()
|
|
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
|
require.NoError(t, err)
|
|
|
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
|
require.NoError(t, err)
|
|
|
|
s := &ProxyServiceServer{
|
|
tokenStore: tokenStore,
|
|
pkceVerifierStore: pkceStore,
|
|
}
|
|
s.SetProxyController(newTestProxyController())
|
|
|
|
const cluster = "proxy.example.com"
|
|
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
|
ch2 := registerFakeProxy(s, "proxy-b", cluster)
|
|
|
|
mapping := &proto.ProxyMapping{
|
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
|
Id: "service-1",
|
|
AccountId: "account-1",
|
|
Domain: "test.example.com",
|
|
}
|
|
|
|
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
|
|
|
|
resp1 := drainChannel(ch1)
|
|
resp2 := drainChannel(ch2)
|
|
require.NotNil(t, resp1)
|
|
require.NotNil(t, resp2)
|
|
require.Len(t, resp1.Mapping, 1)
|
|
require.Len(t, resp2.Mapping, 1)
|
|
|
|
// Delete operations should not generate tokens
|
|
assert.Empty(t, resp1.Mapping[0].AuthToken)
|
|
assert.Empty(t, resp2.Mapping[0].AuthToken)
|
|
}
|
|
|
|
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
|
ctx := context.Background()
|
|
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
|
require.NoError(t, err)
|
|
|
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
|
require.NoError(t, err)
|
|
|
|
s := &ProxyServiceServer{
|
|
tokenStore: tokenStore,
|
|
pkceVerifierStore: pkceStore,
|
|
}
|
|
s.SetProxyController(newTestProxyController())
|
|
|
|
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
|
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
|
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
|
|
|
|
mapping := &proto.ProxyMapping{
|
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
|
Id: "service-1",
|
|
AccountId: "account-1",
|
|
Domain: "test.example.com",
|
|
}
|
|
|
|
update := &proto.GetMappingUpdateResponse{
|
|
Mapping: []*proto.ProxyMapping{mapping},
|
|
}
|
|
|
|
s.SendServiceUpdate(update)
|
|
|
|
resp1 := drainChannel(ch1)
|
|
resp2 := drainChannel(ch2)
|
|
require.NotNil(t, resp1)
|
|
require.NotNil(t, resp2)
|
|
require.Len(t, resp1.Mapping, 1)
|
|
require.Len(t, resp2.Mapping, 1)
|
|
|
|
msg1 := resp1.Mapping[0]
|
|
msg2 := resp2.Mapping[0]
|
|
|
|
assert.NotEmpty(t, msg1.AuthToken)
|
|
assert.NotEmpty(t, msg2.AuthToken)
|
|
assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")
|
|
|
|
// Both tokens should validate
|
|
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
|
|
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
|
|
}
|
|
|
|
// generateState creates a state using the same format as GetOIDCURL.
|
|
func generateState(s *ProxyServiceServer, redirectURL string) string {
|
|
nonce := make([]byte, 16)
|
|
_, _ = rand.Read(nonce)
|
|
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
|
|
|
|
payload := redirectURL + "|" + nonceB64
|
|
hmacSum := s.generateHMAC(payload)
|
|
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
|
|
}
|
|
|
|
func TestOAuthState_NeverTheSame(t *testing.T) {
|
|
ctx := context.Background()
|
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
|
require.NoError(t, err)
|
|
|
|
s := &ProxyServiceServer{
|
|
oidcConfig: ProxyOIDCConfig{
|
|
HMACKey: []byte("test-hmac-key"),
|
|
},
|
|
pkceVerifierStore: pkceStore,
|
|
}
|
|
|
|
redirectURL := "https://app.example.com/callback"
|
|
|
|
// Generate 100 states for the same redirect URL
|
|
states := make(map[string]bool)
|
|
for i := 0; i < 100; i++ {
|
|
state := generateState(s, redirectURL)
|
|
|
|
// State must have 3 parts: base64(url)|nonce|hmac
|
|
parts := strings.Split(state, "|")
|
|
require.Equal(t, 3, len(parts), "state must have 3 parts")
|
|
|
|
// State must be unique
|
|
require.False(t, states[state], "state %d is a duplicate", i)
|
|
states[state] = true
|
|
}
|
|
}
|
|
|
|
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
|
ctx := context.Background()
|
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
|
require.NoError(t, err)
|
|
|
|
s := &ProxyServiceServer{
|
|
oidcConfig: ProxyOIDCConfig{
|
|
HMACKey: []byte("test-hmac-key"),
|
|
},
|
|
pkceVerifierStore: pkceStore,
|
|
}
|
|
|
|
// Old format had only 2 parts: base64(url)|hmac
|
|
err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute)
|
|
require.NoError(t, err)
|
|
|
|
_, _, err = s.ValidateState("base64url|hmac")
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "invalid state format")
|
|
}
|
|
|
|
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
|
ctx := context.Background()
|
|
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
|
require.NoError(t, err)
|
|
|
|
s := &ProxyServiceServer{
|
|
oidcConfig: ProxyOIDCConfig{
|
|
HMACKey: []byte("test-hmac-key"),
|
|
},
|
|
pkceVerifierStore: pkceStore,
|
|
}
|
|
|
|
// Store with tampered HMAC
|
|
err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute)
|
|
require.NoError(t, err)
|
|
|
|
_, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "invalid state signature")
|
|
}
|